diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ca6b382d9dbc68686689c5fda8a58348654ed43a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,121 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 filter=lfs diff=lfs merge=lfs -text +checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf filter=lfs diff=lfs merge=lfs -text +checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a filter=lfs diff=lfs merge=lfs -text +checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 filter=lfs diff=lfs merge=lfs -text +checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 filter=lfs diff=lfs merge=lfs -text +checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png filter=lfs diff=lfs merge=lfs -text +experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png filter=lfs diff=lfs merge=lfs -text diff --git a/Craftax_Baselines/.gitignore b/Craftax_Baselines/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dfa0dc993a3761abbcb3138721b366740183e49c --- /dev/null +++ b/Craftax_Baselines/.gitignore @@ -0,0 +1,169 @@ +tmp/ +wandb/ +res/ +runs/ + +play_data + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ +texture_cache.pbz2 +texture_cache*.pbz2 diff --git a/Craftax_Baselines/.pre-commit-config.yaml b/Craftax_Baselines/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad94cafd5029ff0caaa470d5a76143bd92d6150e --- /dev/null +++ b/Craftax_Baselines/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3 \ No newline at end of file diff --git a/Craftax_Baselines/Dockerfile b/Craftax_Baselines/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..55ef0a0d7afec9bfafdc33f8510d7e2ff69385a5 --- /dev/null +++ b/Craftax_Baselines/Dockerfile @@ -0,0 +1,41 @@ +FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04 + +ENV CUDA_PATH /usr/local/cuda +ENV CUDA_INCLUDE_PATH /usr/local/cuda/include +ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64 + +# Set timezone +ENV TZ=Europe/London DEBIAN_FRONTEND=noninteractive + +# Add Python 3.8 to Ubuntu 22.04 and install dependencies +RUN apt update +RUN apt install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa +RUN apt install -y \ + git \ + python3.8 \ + python3-pip \ + python3.8-venv \ + python3-setuptools \ + python3-wheel + +# Create local user +# https://jtreminio.com/blog/running-docker-containers-as-current-host-user/ +ARG UID +ARG GID +RUN if [ ${UID:-0} -ne 0 ] && [ ${GID:-0} -ne 0 ]; then \ + groupadd -g ${GID} duser &&\ + useradd -l -u ${UID} -g duser duser &&\ + install -d -m 0755 -o duser -g duser /home/duser &&\ + chown --changes --silent --no-dereference --recursive ${UID}:${GID} /home/duser \ + ;fi + +USER duser +WORKDIR /home/duser + +# Install Python packages +ENV PATH="/home/duser/.local/bin:$PATH" +RUN python3 -m pip install --upgrade pip +ARG REQS +RUN pip install $REQS -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +WORKDIR /home/duser/Craftax diff --git a/Craftax_Baselines/LICENSE b/Craftax_Baselines/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..70b0f5878e5bd2bc15aec41eb21b98f5a4571ebd --- /dev/null +++ b/Craftax_Baselines/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2024 Michael Matthews + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/Craftax_Baselines/README.md b/Craftax_Baselines/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0f521ba316623e508c55bb33531db221e7f1a64b --- /dev/null +++ b/Craftax_Baselines/README.md @@ -0,0 +1,46 @@ +

+ +

+ +# Craftax Baselines + +This repository contains the code for running the baselines from the [Craftax paper](https://arxiv.org/abs/2402.16801). +For packaging reasons, this is separate to the [main repository](https://github.com/MichaelTMatthews/Craftax/). + +# Installation +```commandline +git clone https://github.com/MichaelTMatthews/Craftax_Baselines.git +cd Craftax_Baselines +pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pre-commit install +``` + +# Run Experiments + +### PPO +```commandline +python ppo.py +``` + +### PPO-RNN +```commandline +python ppo_rnn.py +``` + +### ICM +```commandline +python ppo.py --train_icm +``` + +### E3B +```commandline +python ppo.py --train_icm --use_e3b --icm_reward_coeff 0 +``` + +### RND +```commandline +python ppo_rnd.py +``` + +# Visualisation +You can save trained policies with the `--save_policy` flag. These can then be viewed with the `view_ppo_agent` script (pass in the path up to the `files` directory). \ No newline at end of file diff --git a/Craftax_Baselines/analysis/__init__.py b/Craftax_Baselines/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Craftax_Baselines/analysis/view_ppo_agent.py b/Craftax_Baselines/analysis/view_ppo_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..c48da547f9b92025515c4c8d96e29881237e9e23 --- /dev/null +++ b/Craftax_Baselines/analysis/view_ppo_agent.py @@ -0,0 +1,151 @@ +import argparse +import os +import sys + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import yaml +from craftax.environment_base.wrappers import AutoResetEnvWrapper +from flax.training.train_state import TrainState +import orbax.checkpoint as ocp + +from ..models.actor_critic import ActorCriticConv, ActorCritic + + +def main(args): + + with open(os.path.join(args.path, "config.yaml")) as f: + raw_config = yaml.load(f, Loader=yaml.Loader) + + config = {} + for key, value in raw_config.items(): + if isinstance(value, dict) and "value" in value: + config[key] = value["value"] + + config["NUM_ENVS"] = 1 + + options = ocp.CheckpointManagerOptions(max_to_keep=1) + checkpoint_manager = ocp.CheckpointManager( + os.path.join(args.path, "policies"), + options=options + ) + + is_classic = False + + if config["ENV_NAME"] == "Craftax-Symbolic-v1": + from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv + from craftax.craftax.constants import Action + + env = CraftaxSymbolicEnv(CraftaxSymbolicEnv.default_static_params()) + network = ActorCritic(len(Action), config["LAYER_SIZE"]) + elif config["ENV_NAME"] == "Craftax-Pixels-v1": + from craftax.craftax.envs.craftax_pixels_env import CraftaxPixelsEnv + from craftax.craftax.constants import Action + + env = CraftaxPixelsEnv(CraftaxPixelsEnv.default_static_params()) + network = ActorCriticConv(len(Action), config["LAYER_SIZE"]) + elif config["ENV_NAME"] == "Craftax-Classic-Symbolic-v1": + from craftax.craftax_classic.envs.craftax_symbolic_env import ( + CraftaxClassicSymbolicEnv, + ) + from craftax.craftax_classic.constants import Action + + env = CraftaxClassicSymbolicEnv( + CraftaxClassicSymbolicEnv.default_static_params() + ) + network = ActorCritic(len(Action), config["LAYER_SIZE"]) + is_classic = True + elif config["ENV_NAME"] == "Craftax-Classic-Pixels-v1": + from craftax.craftax_classic.envs.craftax_pixels_env import ( + CraftaxClassicPixelsEnv, + ) + from craftax.craftax_classic.constants import Action + + env = CraftaxClassicPixelsEnv(CraftaxClassicPixelsEnv.default_static_params()) + network = ActorCriticConv(len(Action), config["LAYER_SIZE"]) + is_classic = True + else: + raise ValueError(f"Unknown env: {config['ENV_NAME']}") + + env = AutoResetEnvWrapper(env) + env_params = env.default_params + + init_x = jnp.zeros((config["NUM_ENVS"], *env.observation_space(env_params).shape)) + + rng = jax.random.PRNGKey(np.random.randint(2**31)) + rng, _rng, __rng = jax.random.split(rng, 3) + network_params = network.init(_rng, init_x) + + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + abstract_train_state = jax.eval_shape(lambda: train_state) + + train_state = checkpoint_manager.restore( + config["TOTAL_TIMESTEPS"], + args=ocp.args.StandardRestore(abstract_train_state) + ) + + obs, env_state = env.reset(key=__rng) + done = 0 + + if is_classic: + from craftax.craftax_classic.play_craftax_classic import CraftaxRenderer + from craftax.craftax_classic.constants import Achievement + else: + from craftax.craftax.play_craftax import CraftaxRenderer + from craftax.craftax.constants import Achievement + + renderer = CraftaxRenderer(env, env_params, pixel_render_size=1) + + while not renderer.is_quit_requested(): + done = np.array([done], dtype=bool) + obs = jnp.expand_dims(obs, axis=0) + + pi, value = network.apply(train_state.params, obs) + rng, _rng = jax.random.split(rng) + action = pi.sample(seed=_rng)[0] + # action = jnp.argmax(pi.probs[0, 0]) + + if action is not None: + rng, _rng = jax.random.split(rng) + old_achievements = env_state.achievements + obs, env_state, reward, done, info = env.step( + _rng, env_state, action, env_params + ) + new_achievements = env_state.achievements + print_new_achievements(Achievement, old_achievements, new_achievements) + if done: + print("\n") + renderer.render(env_state) + + +def print_new_achievements(achievements_cls, old_achievements, new_achievements): + for i in range(len(old_achievements)): + if old_achievements[i] == 0 and new_achievements[i] == 1: + print(f"{achievements_cls(i).name} ({new_achievements.sum()}/{22})") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str) + parser.add_argument("--debug", action="store_true") + + args, rest_args = parser.parse_known_args(sys.argv[1:]) + if rest_args: + raise ValueError(f"Unknown args {rest_args}") + + if args.debug: + with jax.disable_jit(): + main(args) + else: + main(args) diff --git a/Craftax_Baselines/build.sh b/Craftax_Baselines/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..ebc6ecce6ac49188a51e9e31ebbc632bec9622a4 --- /dev/null +++ b/Craftax_Baselines/build.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +echo 'Building Dockerfile with image name craftax' +docker build \ + --build-arg UID=$(id -u ${USER}) \ + --build-arg GID=1234 \ + --build-arg REQS="$(cat requirements.txt)" \ + -t craftax_baselines \ + --no-cache \ + . diff --git a/Craftax_Baselines/images/logo.png b/Craftax_Baselines/images/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..baebfeda0d5e6ce832bc0710afd8c1ab14859434 Binary files /dev/null and b/Craftax_Baselines/images/logo.png differ diff --git a/Craftax_Baselines/logz/__init__.py b/Craftax_Baselines/logz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Craftax_Baselines/logz/batch_logging.py b/Craftax_Baselines/logz/batch_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6c5c397478773a0f085abde3a72cc60cf5d217be --- /dev/null +++ b/Craftax_Baselines/logz/batch_logging.py @@ -0,0 +1,115 @@ +import time + +import jax.numpy as jnp +import numpy as np +import wandb + +batch_logs = {} +log_times = [] + + +def create_log_dict(info, config): + to_log = { + "episode_return": info["returned_episode_returns"], + "episode_length": info["returned_episode_lengths"], + } + + diffusion_keys = [ + "loss", "unweighted_loss", "accuracy", "mean_t", + "acc_t_low", "acc_t_mid", "acc_t_high", "grad_norm", + "action_entropy", "action_unique_frac" + ] + for k in diffusion_keys: + if k in info: + to_log[f"diffusion/{k}"] = info[k] + + sum_achievements = 0.0 + sum_val_achievements = 0.0 + has_val = False + + for k, v in info.items(): + if k.startswith("val/"): + has_val = True + to_log[k] = v + if "achievements" in k.lower() and k != "val/achievements": + sum_val_achievements += v / 100.0 + elif "achievements" in k.lower(): + to_log[k] = v + if k != "achievements": + sum_achievements += v / 100.0 + + to_log["achievements"] = sum_achievements + if has_val: + to_log["val/achievements"] = sum_val_achievements + + if config.get("TRAIN_ICM") or config.get("USE_RND"): + to_log["intrinsic_reward"] = info.get("reward_i", 0.0) + to_log["extrinsic_reward"] = info.get("reward_e", 0.0) + + if config.get("TRAIN_ICM"): + to_log["icm_inverse_loss"] = info.get("icm_inverse_loss", 0.0) + to_log["icm_forward_loss"] = info.get("icm_forward_loss", 0.0) + elif config.get("USE_RND"): + to_log["rnd_loss"] = info.get("rnd_loss", 0.0) + + return to_log + + +def batch_log(update_step, log, config): + update_step = int(update_step) + if update_step not in batch_logs: + batch_logs[update_step] = [] + + batch_logs[update_step].append(log) + + if len(batch_logs[update_step]) == config.get("NUM_REPEATS", 1): + agg_logs = {} + for key in batch_logs[update_step][0]: + agg = [] + if key in ["goal_heatmap"]: + agg = [batch_logs[update_step][0][key]] + else: + for i in range(config.get("NUM_REPEATS", 1)): + # Use .get() to prevent KeyErrors if repeats are out of sync + val = batch_logs[update_step][i].get(key, float("nan")) + if not jnp.isnan(val): + agg.append(val) + + if len(agg) > 0: + if key in [ + "episode_length", + "episode_return", + "exploration_bonus", + "e_mean", + "e_std", + "rnd_loss", + "diffusion/loss", + "diffusion/unweighted_loss", + "diffusion/accuracy", + "diffusion/acc_t_low", + "diffusion/acc_t_mid", + "diffusion/acc_t_high", + "diffusion/action_entropy", + "diffusion/grad_norm" + ] or key.startswith("val/") or "achievement" in key.lower(): + agg_logs[key] = np.mean(agg) + else: + agg_logs[key] = np.array(agg) + + log_times.append(time.time()) + + if config.get("DEBUG"): + if len(log_times) == 1: + print("Started logging") + elif len(log_times) > 1: + dt = log_times[-1] - log_times[-2] + steps_between_updates = ( + config["NUM_STEPS"] * config["NUM_ENVS"] * config.get("NUM_REPEATS", 1) + ) + sps = steps_between_updates / dt + agg_logs["sps"] = sps + + wandb.log(agg_logs) + + # Clear buffer to prevent memory leaks + del batch_logs[update_step] \ No newline at end of file diff --git a/Craftax_Baselines/models/__init__.py b/Craftax_Baselines/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Craftax_Baselines/models/actor_critic.py b/Craftax_Baselines/models/actor_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd3b7a31e81caf0cf6d17f225c8828f9cbe8a10 --- /dev/null +++ b/Craftax_Baselines/models/actor_critic.py @@ -0,0 +1,256 @@ +import jax.numpy as jnp +import flax.linen as nn +import numpy as np +from flax.linen.initializers import constant, orthogonal +from typing import Sequence + +import distrax + + +class ActorCriticConvSymbolicCraftax(nn.Module): + action_dim: int + map_obs_shape: Sequence[int] + layer_width: int + + @nn.compact + def __call__(self, obs): + # Split into map and flat obs + flat_map_obs_shape = ( + self.map_obs_shape[0] * self.map_obs_shape[1] * self.map_obs_shape[2] + ) + image_obs = obs[:, :flat_map_obs_shape] + image_dim = self.map_obs_shape + image_obs = image_obs.reshape((image_obs.shape[0], *image_dim)) + + flat_obs = obs[:, flat_map_obs_shape:] + + # Convolutions on map + image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_obs) + image_embedding = nn.relu(image_embedding) + image_embedding = nn.max_pool( + image_embedding, window_shape=(2, 2), strides=(1, 1) + ) + image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_embedding) + image_embedding = nn.relu(image_embedding) + image_embedding = nn.max_pool( + image_embedding, window_shape=(2, 2), strides=(1, 1) + ) + image_embedding = image_embedding.reshape(image_embedding.shape[0], -1) + # image_embedding = jnp.concatenate([image_embedding, obs[:, : CraftaxEnv.get_flat_map_obs_shape()]], axis=-1) + + # Combine embeddings + embedding = jnp.concatenate([image_embedding, flat_obs], axis=-1) + embedding = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(embedding) + embedding = nn.relu(embedding) + + actor_mean = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(embedding) + actor_mean = nn.relu(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + actor_mean = nn.relu(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + + pi = distrax.Categorical(logits=actor_mean) + + critic = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(embedding) + critic = nn.relu(critic) + critic = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(critic) + critic = nn.relu(critic) + critic = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(critic) + critic = nn.relu(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1) + + +class ActorCriticConv(nn.Module): + action_dim: int + layer_width: int + activation: str = "tanh" + + @nn.compact + def __call__(self, obs): + x = nn.Conv(features=32, kernel_size=(5, 5))(obs) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3)) + x = nn.Conv(features=32, kernel_size=(5, 5))(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3)) + x = nn.Conv(features=32, kernel_size=(5, 5))(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3)) + + embedding = x.reshape(x.shape[0], -1) + + actor_mean = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(embedding) + actor_mean = nn.relu(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + actor_mean = nn.relu(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + + pi = distrax.Categorical(logits=actor_mean) + + critic = nn.Dense( + self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0) + )(embedding) + critic = nn.relu(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1) + + +class ActorCritic(nn.Module): + action_dim: int + layer_width: int + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + else: + activation = nn.tanh + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(actor_mean) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(actor_mean) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + pi = distrax.Categorical(logits=actor_mean) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + critic = activation(critic) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic) + critic = activation(critic) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic) + critic = activation(critic) + + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1) + + +class ActorCriticWithEmbedding(nn.Module): + action_dim: int + layer_width: int + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + else: + activation = nn.tanh + + actor_emb = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + actor_emb = activation(actor_emb) + + actor_emb = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(actor_emb) + actor_emb = activation(actor_emb) + + actor_emb = nn.Dense( + 128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(actor_emb) + actor_emb = activation(actor_emb) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_emb) + pi = distrax.Categorical(logits=actor_mean) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + critic = activation(critic) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic) + critic = activation(critic) + + critic = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic) + critic = activation(critic) + + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1), actor_emb diff --git a/Craftax_Baselines/models/icm.py b/Craftax_Baselines/models/icm.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9154f417276c61c4f94122a94001aa7eecaec3 --- /dev/null +++ b/Craftax_Baselines/models/icm.py @@ -0,0 +1,72 @@ +import jax +import jax.numpy as jnp +import flax.linen as nn + + +class ICMEncoder(nn.Module): + layer_size: int + output_dim: int + num_layers: int + + @nn.compact + def __call__(self, obs): + activation = nn.relu + + # TODO Look at weight inits + + emb = obs + for _ in range(self.num_layers): + emb = nn.Dense( + self.layer_size, + )(emb) + emb = activation(emb) + + emb = nn.Dense(self.output_dim)(emb) + + return emb + + +class ICMForward(nn.Module): + layer_size: int + output_dim: int + num_layers: int + num_actions: int + + @nn.compact + def __call__(self, latent, action): + activation = nn.relu + + action1h = jax.nn.one_hot(action, num_classes=self.num_actions) + emb = jnp.concatenate((latent, action1h), axis=-1) + for _ in range(self.num_layers): + emb = nn.Dense( + self.layer_size, + )(emb) + emb = activation(emb) + + emb = nn.Dense(self.output_dim)(emb) + + return emb + + +class ICMInverse(nn.Module): + layer_size: int + output_dim: int + num_layers: int + + @nn.compact + def __call__(self, latent, next_latent): + activation = nn.relu + + emb = jnp.concatenate((latent, next_latent), axis=-1) + for _ in range(self.num_layers): + emb = nn.Dense( + self.layer_size, + )(emb) + emb = activation(emb) + + action_raw = nn.Dense(self.output_dim)(emb) + + action_logits = jax.nn.log_softmax(action_raw) + + return action_logits diff --git a/Craftax_Baselines/models/rnd.py b/Craftax_Baselines/models/rnd.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e5f11e8c9f565270ba16595058ec719a3016a1 --- /dev/null +++ b/Craftax_Baselines/models/rnd.py @@ -0,0 +1,120 @@ +import jax.numpy as jnp +import flax.linen as nn +import numpy as np +from flax.linen.initializers import constant, orthogonal + +import distrax + + +class RNDNetwork(nn.Module): + layer_size: int + output_dim: int + num_layers: int + + @nn.compact + def __call__(self, x): + activation = nn.relu + + emb = x + for _ in range(self.num_layers): + emb = nn.Dense( + self.layer_size, + )(emb) + emb = activation(emb) + + emb = nn.Dense(self.output_dim)(emb) + + return emb + + +class ActorCriticRND(nn.Module): + action_dim: int + layer_width: int + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + else: + activation = nn.tanh + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(actor_mean) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(actor_mean) + actor_mean = activation(actor_mean) + + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + pi = distrax.Categorical(logits=actor_mean) + + # Extrinsic reward + critic_e = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + critic_e = activation(critic_e) + + critic_e = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic_e) + critic_e = activation(critic_e) + + critic_e = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic_e) + critic_e = activation(critic_e) + + critic_e = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic_e + ) + + # Intrinsic reward + critic_i = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + critic_i = activation(critic_i) + + critic_i = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic_i) + critic_i = activation(critic_i) + + critic_i = nn.Dense( + self.layer_width, + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(critic_i) + critic_i = activation(critic_i) + + critic_i = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic_i + ) + + return pi, jnp.squeeze(critic_e, axis=-1), jnp.squeeze(critic_i, axis=-1) diff --git a/Craftax_Baselines/ppo.py b/Craftax_Baselines/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..403bcf8d9d989f050584aa996f600213853dad7e --- /dev/null +++ b/Craftax_Baselines/ppo.py @@ -0,0 +1,733 @@ +import argparse +import os +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from craftax.craftax_env import make_craftax_env_from_name + +import wandb +from typing import NamedTuple + +from flax.training.train_state import TrainState +import orbax.checkpoint as ocp + +from logz.batch_logging import batch_log, create_log_dict +from models.actor_critic import ( + ActorCritic, + ActorCriticConv, +) +from models.icm import ICMEncoder, ICMForward, ICMInverse +from wrappers import ( + LogWrapper, + OptimisticResetVecEnvWrapper, + BatchEnvWrapper, + AutoResetEnvWrapper, +) + +# Code adapted from the original implementation made by Chris Lu +# Original code located at https://github.com/luchris429/purejaxrl + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value: jnp.ndarray + reward_e: jnp.ndarray + reward_i: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + next_obs: jnp.ndarray + info: jnp.ndarray + + +def make_train(config): + config["NUM_UPDATES"] = ( + config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] + ) + config["MINIBATCH_SIZE"] = ( + config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] + ) + + env = make_craftax_env_from_name( + config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"] + ) + env_params = env.default_params + + env = LogWrapper(env) + if config["USE_OPTIMISTIC_RESETS"]: + env = OptimisticResetVecEnvWrapper( + env, + num_envs=config["NUM_ENVS"], + reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]), + ) + else: + env = AutoResetEnvWrapper(env) + env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"]) + + def linear_schedule(count): + frac = ( + 1.0 + - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) + / config["NUM_UPDATES"] + ) + return config["LR"] * frac + + def train(rng): + # INIT NETWORK + if "Symbolic" in config["ENV_NAME"]: + network = ActorCritic(env.action_space(env_params).n, config["LAYER_SIZE"]) + else: + network = ActorCriticConv( + env.action_space(env_params).n, config["LAYER_SIZE"] + ) + + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros((1, *env.observation_space(env_params).shape)) + network_params = network.init(_rng, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # Exploration state + ex_state = { + "icm_encoder": None, + "icm_forward": None, + "icm_inverse": None, + "e3b_matrix": None, + } + + if config["TRAIN_ICM"]: + obs_shape = env.observation_space(env_params).shape + assert len(obs_shape) == 1, "Only configured for 1D observations" + obs_shape = obs_shape[0] + + # Encoder + icm_encoder_network = ICMEncoder( + num_layers=3, + output_dim=config["ICM_LATENT_SIZE"], + layer_size=config["ICM_LAYER_SIZE"], + ) + rng, _rng = jax.random.split(rng) + icm_encoder_network_params = icm_encoder_network.init( + _rng, jnp.zeros((1, obs_shape)) + ) + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["ICM_LR"], eps=1e-5), + ) + ex_state["icm_encoder"] = TrainState.create( + apply_fn=icm_encoder_network.apply, + params=icm_encoder_network_params, + tx=tx, + ) + + # Forward + icm_forward_network = ICMForward( + num_layers=3, + output_dim=config["ICM_LATENT_SIZE"], + layer_size=config["ICM_LAYER_SIZE"], + num_actions=env.num_actions, + ) + rng, _rng = jax.random.split(rng) + icm_forward_network_params = icm_forward_network.init( + _rng, jnp.zeros((1, config["ICM_LATENT_SIZE"])), jnp.zeros((1,)) + ) + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["ICM_LR"], eps=1e-5), + ) + ex_state["icm_forward"] = TrainState.create( + apply_fn=icm_forward_network.apply, + params=icm_forward_network_params, + tx=tx, + ) + + # Inverse + icm_inverse_network = ICMInverse( + num_layers=3, + output_dim=env.num_actions, + layer_size=config["ICM_LAYER_SIZE"], + ) + rng, _rng = jax.random.split(rng) + icm_inverse_network_params = icm_inverse_network.init( + _rng, + jnp.zeros((1, config["ICM_LATENT_SIZE"])), + jnp.zeros((1, config["ICM_LATENT_SIZE"])), + ) + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["ICM_LR"], eps=1e-5), + ) + ex_state["icm_inverse"] = TrainState.create( + apply_fn=icm_inverse_network.apply, + params=icm_inverse_network_params, + tx=tx, + ) + + if config["USE_E3B"]: + ex_state["e3b_matrix"] = ( + jnp.repeat( + jnp.expand_dims( + jnp.identity(config["ICM_LATENT_SIZE"]), axis=0 + ), + config["NUM_ENVS"], + axis=0, + ) + / config["E3B_LAMBDA"] + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + obsv, env_state = env.reset(_rng, env_params) + + # TRAIN LOOP + def _update_step(runner_state, unused): + # COLLECT TRAJECTORIES + def _env_step(runner_state, unused): + ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step, + ) = runner_state + + # SELECT ACTION + rng, _rng = jax.random.split(rng) + pi, value = network.apply(train_state.params, last_obs) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + + # STEP ENV + rng, _rng = jax.random.split(rng) + obsv, env_state, reward_e, done, info = env.step( + _rng, env_state, action, env_params + ) + + reward_i = jnp.zeros(config["NUM_ENVS"]) + + if config["TRAIN_ICM"]: + latent_obs = ex_state["icm_encoder"].apply_fn( + ex_state["icm_encoder"].params, last_obs + ) + latent_next_obs = ex_state["icm_encoder"].apply_fn( + ex_state["icm_encoder"].params, obsv + ) + + latent_next_obs_pred = ex_state["icm_forward"].apply_fn( + ex_state["icm_forward"].params, latent_obs, action + ) + error = (latent_next_obs - latent_next_obs_pred) * ( + 1 - done[:, None] + ) + mse = jnp.square(error).mean(axis=-1) + + reward_i = mse * config["ICM_REWARD_COEFF"] + + if config["USE_E3B"]: + # Embedding is (NUM_ENVS, 128) + # e3b_matrix is (NUM_ENVS, 128, 128) + us = jax.vmap(jnp.matmul)(ex_state["e3b_matrix"], latent_obs) + bs = jax.vmap(jnp.dot)(latent_obs, us) + + def update_c(c, b, u): + return c - (1.0 / (1 + b)) * jnp.outer(u, u) + + updated_cs = jax.vmap(update_c)(ex_state["e3b_matrix"], bs, us) + new_cs = ( + jnp.repeat( + jnp.expand_dims( + jnp.identity(config["ICM_LATENT_SIZE"]), axis=0 + ), + config["NUM_ENVS"], + axis=0, + ) + / config["E3B_LAMBDA"] + ) + ex_state["e3b_matrix"] = jnp.where( + done[:, None, None], new_cs, updated_cs + ) + + e3b_bonus = jnp.where( + done, jnp.zeros((config["NUM_ENVS"],)), bs + ) + + reward_i = e3b_bonus * config["E3B_REWARD_COEFF"] + + reward = reward_e + reward_i + + transition = Transition( + done=done, + action=action, + value=value, + reward=reward, + reward_i=reward_i, + reward_e=reward_e, + log_prob=log_prob, + obs=last_obs, + next_obs=obsv, + info=info, + ) + runner_state = ( + train_state, + env_state, + obsv, + ex_state, + rng, + update_step, + ) + return runner_state, transition + + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["NUM_STEPS"] + ) + + # CALCULATE ADVANTAGE + ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step, + ) = runner_state + _, last_val = network.apply(train_state.params, last_obs) + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + traj_batch, advantages, targets = batch_info + + # Policy/value network + def _loss_fn(params, traj_batch, gae, targets): + # RERUN NETWORK + pi, value = network.apply(params, traj_batch.obs) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = ( + 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + ) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"], + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + total_loss = ( + loss_actor + + config["VF_COEF"] * value_loss + - config["ENT_COEF"] * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn( + train_state.params, traj_batch, advantages, targets + ) + train_state = train_state.apply_gradients(grads=grads) + + losses = (total_loss, 0) + return train_state, losses + + ( + train_state, + traj_batch, + advantages, + targets, + rng, + ) = update_state + rng, _rng = jax.random.split(rng) + batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] + assert ( + batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] + ), "batch size must be equal to number of steps * number of envs" + permutation = jax.random.permutation(_rng, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree.map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), batch + ) + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree.map( + lambda x: jnp.reshape( + x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + train_state, losses = jax.lax.scan( + _update_minbatch, train_state, minibatches + ) + update_state = ( + train_state, + traj_batch, + advantages, + targets, + rng, + ) + return update_state, losses + + update_state = ( + train_state, + traj_batch, + advantages, + targets, + rng, + ) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"] + ) + + train_state = update_state[0] + metric = jax.tree.map( + lambda x: (x * traj_batch.info["returned_episode"]).sum() + / traj_batch.info["returned_episode"].sum(), + traj_batch.info, + ) + + rng = update_state[-1] + + # UPDATE EXPLORATION STATE + def _update_ex_epoch(update_state, unused): + def _update_ex_minbatch(ex_state, traj_batch): + def _inverse_loss_fn( + icm_encoder_params, icm_inverse_params, traj_batch + ): + latent_obs = ex_state["icm_encoder"].apply_fn( + icm_encoder_params, traj_batch.obs + ) + latent_next_obs = ex_state["icm_encoder"].apply_fn( + icm_encoder_params, traj_batch.next_obs + ) + + action_pred_logits = ex_state["icm_inverse"].apply_fn( + icm_inverse_params, latent_obs, latent_next_obs + ) + true_action = jax.nn.one_hot( + traj_batch.action, num_classes=action_pred_logits.shape[-1] + ) + + bce = -jnp.mean( + jnp.sum( + action_pred_logits + * true_action + * (1 - traj_batch.done[:, None]), + axis=1, + ) + ) + + return bce * config["ICM_INVERSE_LOSS_COEF"] + + inverse_grad_fn = jax.value_and_grad( + _inverse_loss_fn, + has_aux=False, + argnums=( + 0, + 1, + ), + ) + inverse_loss, grads = inverse_grad_fn( + ex_state["icm_encoder"].params, + ex_state["icm_inverse"].params, + traj_batch, + ) + icm_encoder_grad, icm_inverse_grad = grads + ex_state["icm_encoder"] = ex_state["icm_encoder"].apply_gradients( + grads=icm_encoder_grad + ) + ex_state["icm_inverse"] = ex_state["icm_inverse"].apply_gradients( + grads=icm_inverse_grad + ) + + def _forward_loss_fn(icm_forward_params, traj_batch): + latent_obs = ex_state["icm_encoder"].apply_fn( + ex_state["icm_encoder"].params, traj_batch.obs + ) + latent_next_obs = ex_state["icm_encoder"].apply_fn( + ex_state["icm_encoder"].params, traj_batch.next_obs + ) + + latent_next_obs_pred = ex_state["icm_forward"].apply_fn( + icm_forward_params, latent_obs, traj_batch.action + ) + + error = (latent_next_obs - latent_next_obs_pred) * ( + 1 - traj_batch.done[:, None] + ) + return ( + jnp.square(error).mean() * config["ICM_FORWARD_LOSS_COEF"] + ) + + forward_grad_fn = jax.value_and_grad( + _forward_loss_fn, has_aux=False + ) + forward_loss, icm_forward_grad = forward_grad_fn( + ex_state["icm_forward"].params, traj_batch + ) + ex_state["icm_forward"] = ex_state["icm_forward"].apply_gradients( + grads=icm_forward_grad + ) + + losses = (inverse_loss, forward_loss) + return ex_state, losses + + (ex_state, traj_batch, rng) = update_state + rng, _rng = jax.random.split(rng) + batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] + assert ( + batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] + ), "batch size must be equal to number of steps * number of envs" + permutation = jax.random.permutation(_rng, batch_size) + batch = jax.tree.map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch + ) + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree.map( + lambda x: jnp.reshape( + x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + ex_state, losses = jax.lax.scan( + _update_ex_minbatch, ex_state, minibatches + ) + update_state = (ex_state, traj_batch, rng) + return update_state, losses + + if config["TRAIN_ICM"]: + ex_update_state = (ex_state, traj_batch, rng) + ex_update_state, ex_loss = jax.lax.scan( + _update_ex_epoch, + ex_update_state, + None, + config["EXPLORATION_UPDATE_EPOCHS"], + ) + metric["icm_inverse_loss"] = ex_loss[0].mean() + metric["icm_forward_loss"] = ex_loss[1].mean() + metric["reward_i"] = traj_batch.reward_i.mean() + metric["reward_e"] = traj_batch.reward_e.mean() + + ex_state = ex_update_state[0] + rng = ex_update_state[-1] + + # wandb logging + if config["DEBUG"] and config["USE_WANDB"]: + + def callback(metric, update_step): + to_log = create_log_dict(metric, config) + batch_log(update_step, to_log, config) + + jax.debug.callback( + callback, + metric, + update_step, + ) + + runner_state = ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step + 1, + ) + return runner_state, metric + + rng, _rng = jax.random.split(rng) + runner_state = ( + train_state, + env_state, + obsv, + ex_state, + _rng, + 0, + ) + runner_state, metric = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state} # , "info": metric} + + return train + + +def run_ppo(config): + config = {k.upper(): v for k, v in config.__dict__.items()} + + if config["USE_WANDB"]: + wandb.init( + project=config["WANDB_PROJECT"], + entity=config["WANDB_ENTITY"], + config=config, + name=config["ENV_NAME"] + + "-" + + str(int(config["TOTAL_TIMESTEPS"] // 1e6)) + + "M", + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_REPEATS"]) + + train_jit = jax.jit(make_train(config)) + train_vmap = jax.vmap(train_jit) + + t0 = time.time() + out = train_vmap(rngs) + t1 = time.time() + print("Time to run experiment", t1 - t0) + print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0)) + + if config["USE_WANDB"]: + + def _save_network(rs_index, dir_name): + train_states = out["runner_state"][rs_index] + train_state = jax.tree.map(lambda x: x[0], train_states) + + path = os.path.join(wandb.run.dir, dir_name) + options = ocp.CheckpointManagerOptions(max_to_keep=1) + + with ocp.CheckpointManager(path, options=options) as checkpoint_manager: + checkpoint_manager.save( + int(config["TOTAL_TIMESTEPS"]), + args=ocp.args.StandardSave(train_state) + ) + + print(f"saved runner state to {path}") + + if config["SAVE_POLICY"]: + _save_network(0, "policies") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1") + parser.add_argument( + "--num_envs", + type=int, + default=1024, + ) + parser.add_argument( + "--total_timesteps", type=lambda x: int(float(x)), default=1e9 + ) # Allow scientific notation + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--num_steps", type=int, default=64) + parser.add_argument("--update_epochs", type=int, default=4) + parser.add_argument("--num_minibatches", type=int, default=8) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--gae_lambda", type=float, default=0.8) + parser.add_argument("--clip_eps", type=float, default=0.2) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + parser.add_argument("--activation", type=str, default="tanh") + parser.add_argument( + "--anneal_lr", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--seed", type=int) + parser.add_argument( + "--use_wandb", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--save_policy", action="store_true") + parser.add_argument("--num_repeats", type=int, default=1) + parser.add_argument("--layer_size", type=int, default=512) + parser.add_argument("--wandb_project", type=str) + parser.add_argument("--wandb_entity", type=str) + parser.add_argument( + "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--optimistic_reset_ratio", type=int, default=16) + + # EXPLORATION + parser.add_argument("--exploration_update_epochs", type=int, default=4) + # ICM + parser.add_argument("--icm_reward_coeff", type=float, default=1.0) + parser.add_argument("--train_icm", action="store_true") + parser.add_argument("--icm_lr", type=float, default=3e-4) + parser.add_argument("--icm_forward_loss_coef", type=float, default=1.0) + parser.add_argument("--icm_inverse_loss_coef", type=float, default=1.0) + parser.add_argument("--icm_layer_size", type=int, default=256) + parser.add_argument("--icm_latent_size", type=int, default=32) + # E3B + parser.add_argument("--e3b_reward_coeff", type=float, default=1.0) + parser.add_argument("--use_e3b", action="store_true") + parser.add_argument("--e3b_lambda", type=float, default=0.1) + + args, rest_args = parser.parse_known_args(sys.argv[1:]) + if rest_args: + raise ValueError(f"Unknown args {rest_args}") + + if args.use_e3b: + assert args.train_icm + assert args.icm_reward_coeff == 0 + if args.seed is None: + args.seed = np.random.randint(2**31) + + if args.jit: + run_ppo(args) + else: + with jax.disable_jit(): + run_ppo(args) diff --git a/Craftax_Baselines/ppo_rnd.py b/Craftax_Baselines/ppo_rnd.py new file mode 100644 index 0000000000000000000000000000000000000000..7256b5ec716def468d05599a4e7ae68755765944 --- /dev/null +++ b/Craftax_Baselines/ppo_rnd.py @@ -0,0 +1,680 @@ +import argparse +import os +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from craftax.craftax_env import make_craftax_env_from_name + +import wandb +from typing import NamedTuple + +from flax.training.train_state import TrainState +import orbax.checkpoint as ocp + +from logz.batch_logging import batch_log, create_log_dict +from wrappers import ( + LogWrapper, + OptimisticResetVecEnvWrapper, + AutoResetEnvWrapper, + BatchEnvWrapper, +) +from models.rnd import RNDNetwork, ActorCriticRND + +# Code adapted from the original implementation made by Chris Lu +# Original code located at https://github.com/luchris429/purejaxrl + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value_e: jnp.ndarray + value_i: jnp.ndarray + reward_e: jnp.ndarray + reward_i: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + next_obs: jnp.ndarray + info: jnp.ndarray + + +def make_train(config): + config["NUM_UPDATES"] = ( + config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] + ) + config["MINIBATCH_SIZE"] = ( + config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] + ) + + env = make_craftax_env_from_name( + config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"] + ) + env_params = env.default_params + + env = LogWrapper(env) + if config["USE_OPTIMISTIC_RESETS"]: + env = OptimisticResetVecEnvWrapper( + env, + num_envs=config["NUM_ENVS"], + reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]), + ) + else: + env = AutoResetEnvWrapper(env) + env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"]) + + def linear_schedule(count): + frac = ( + 1.0 + - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) + / config["NUM_UPDATES"] + ) + return config["LR"] * frac + + def train(rng): + # INIT NETWORK + if "Symbolic" in config["ENV_NAME"]: + network = ActorCriticRND( + env.action_space(env_params).n, config["LAYER_SIZE"] + ) + else: + raise ValueError + # network = ActorCriticConv( + # env.action_space(env_params).n, config["LAYER_SIZE"] + # ) + + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros((1, *env.observation_space(env_params).shape)) + network_params = network.init(_rng, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # Exploration state + ex_state = { + "rnd_model": None, + } + + if config["USE_RND"]: + obs_shape = env.observation_space(env_params).shape + assert len(obs_shape) == 1, "Only configured for 1D observations" + obs_shape = obs_shape[0] + + # Random network + rnd_random_network = RNDNetwork( + num_layers=3, + output_dim=config["RND_OUTPUT_SIZE"], + layer_size=config["RND_LAYER_SIZE"], + ) + rng, _rng = jax.random.split(rng) + rnd_random_network_params = rnd_random_network.init( + _rng, jnp.zeros((1, obs_shape)) + ) + + # Distillation Network + rnd_distillation_network = RNDNetwork( + num_layers=3, + output_dim=config["RND_OUTPUT_SIZE"], + layer_size=config["RND_LAYER_SIZE"], + ) + rng, _rng = jax.random.split(rng) + rnd_distillation_network_params = rnd_distillation_network.init( + _rng, jnp.zeros((1, obs_shape)) + ) + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["RND_LR"], eps=1e-5), + ) + ex_state["rnd_distillation_network"] = TrainState.create( + apply_fn=rnd_distillation_network.apply, + params=rnd_distillation_network_params, + tx=tx, + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + obsv, env_state = env.reset(_rng, env_params) + + # TRAIN LOOP + def _update_step(runner_state, unused): + # COLLECT TRAJECTORIES + def _env_step(runner_state, unused): + ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step, + ) = runner_state + + # SELECT ACTION + rng, _rng = jax.random.split(rng) + pi, value_e, value_i = network.apply(train_state.params, last_obs) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + + # STEP ENV + rng, _rng = jax.random.split(rng) + obsv, env_state, reward_e, done, info = env.step( + _rng, env_state, action, env_params + ) + + reward_i = jnp.zeros(config["NUM_ENVS"]) + + if config["USE_RND"]: + random_pred = rnd_random_network.apply( + rnd_random_network_params, obsv + ) + + distill_pred = ex_state["rnd_distillation_network"].apply_fn( + ex_state["rnd_distillation_network"].params, obsv + ) + error = (random_pred - distill_pred) * (1 - done[:, None]) + mse = jnp.square(error).mean(axis=-1) + + reward_i = mse * config["RND_REWARD_COEFF"] + + reward = reward_e + reward_i + + transition = Transition( + done=done, + action=action, + value_e=value_e, + value_i=value_i, + reward=reward, + reward_i=reward_i, + reward_e=reward_e, + log_prob=log_prob, + obs=last_obs, + next_obs=obsv, + info=info, + ) + runner_state = ( + train_state, + env_state, + obsv, + ex_state, + rng, + update_step, + ) + return runner_state, transition + + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["NUM_STEPS"] + ) + + # CALCULATE ADVANTAGE + ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step, + ) = runner_state + _, last_val_e, last_val_i = network.apply(train_state.params, last_obs) + + def _calculate_gae(traj_batch, last_val, is_extrinsic): + def _get_advantages(gae_and_next_value, transition): + gae, next_value, is_extrinsic = gae_and_next_value + done, value, reward = ( + transition.done, + jax.lax.select( + is_extrinsic, transition.value_e, transition.value_i + ), + jax.lax.select( + is_extrinsic, transition.reward_e, transition.reward_i + ), + ) + done = jnp.logical_and( + done, jnp.logical_or(config["RND_IS_EPISODIC"], is_extrinsic) + ) + + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value, is_extrinsic), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, is_extrinsic), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + jax.lax.select( + is_extrinsic, traj_batch.value_e, traj_batch.value_i + ) + + advantages_e, targets_e = _calculate_gae(traj_batch, last_val_e, True) + advantages_i, targets_i = _calculate_gae(traj_batch, last_val_i, False) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + ( + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + ) = batch_info + + # Policy/value network + def _loss_fn( + params, traj_batch, gae_e, targets_e, gae_i, targets_i + ): + # RERUN NETWORK + pi, value_e, value_i = network.apply(params, traj_batch.obs) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE EXTRINSIC VALUE LOSS + value_pred_clipped_e = traj_batch.value_e + ( + value_e - traj_batch.value_e + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses_e = jnp.square(value_e - targets_e) + value_losses_clipped_e = jnp.square( + value_pred_clipped_e - targets_e + ) + value_loss_e = ( + 0.5 + * jnp.maximum(value_losses_e, value_losses_clipped_e).mean() + ) + + # CALCULATE INTRINSIC VALUE LOSS + value_pred_clipped_i = traj_batch.value_i + ( + value_i - traj_batch.value_i + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses_i = jnp.square(value_i - targets_i) + value_losses_clipped_i = jnp.square( + value_pred_clipped_i - targets_i + ) + value_loss_i = ( + 0.5 + * jnp.maximum(value_losses_i, value_losses_clipped_i).mean() + ) + + # CALCULATE ACTOR LOSS + gae = gae_e + if config["USE_RND"]: + gae += gae_i * config["RND_GAE_COEFF"] + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"], + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + value_loss = value_loss_e + if config["USE_RND"]: + value_loss += value_loss_i + + total_loss = ( + loss_actor + + config["VF_COEF"] * value_loss + - config["ENT_COEF"] * entropy + ) + return total_loss, ( + value_loss_e, + value_loss_i, + loss_actor, + entropy, + ) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn( + train_state.params, + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + ) + train_state = train_state.apply_gradients(grads=grads) + + losses = (total_loss, 0) + return train_state, losses + + ( + train_state, + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + rng, + ) = update_state + rng, _rng = jax.random.split(rng) + batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] + assert ( + batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] + ), "batch size must be equal to number of steps * number of envs" + permutation = jax.random.permutation(_rng, batch_size) + batch = ( + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + ) + batch = jax.tree.map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), batch + ) + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree.map( + lambda x: jnp.reshape( + x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + train_state, losses = jax.lax.scan( + _update_minbatch, train_state, minibatches + ) + update_state = ( + train_state, + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + rng, + ) + return update_state, losses + + update_state = ( + train_state, + traj_batch, + advantages_e, + targets_e, + advantages_i, + targets_i, + rng, + ) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"] + ) + + train_state = update_state[0] + metric = jax.tree.map( + lambda x: (x * traj_batch.info["returned_episode"]).sum() + / traj_batch.info["returned_episode"].sum(), + traj_batch.info, + ) + + rng = update_state[-1] + + # UPDATE EXPLORATION STATE + def _update_ex_epoch(update_state, unused): + def _update_ex_minbatch(ex_state, traj_batch): + rnd_loss = 0 + + if config["USE_RND"]: + + def _rnd_loss_fn(rnd_distillation_params, traj_batch): + random_network_out = rnd_random_network.apply( + rnd_random_network_params, traj_batch.next_obs + ) + + distillation_network_out = ex_state[ + "rnd_distillation_network" + ].apply_fn(rnd_distillation_params, traj_batch.next_obs) + + error = (random_network_out - distillation_network_out) * ( + 1 - traj_batch.done[:, None] + ) + return jnp.square(error).mean() * config["RND_LOSS_COEFF"] + + rnd_grad_fn = jax.value_and_grad(_rnd_loss_fn, has_aux=False) + rnd_loss, rnd_grad = rnd_grad_fn( + ex_state["rnd_distillation_network"].params, traj_batch + ) + ex_state["rnd_distillation_network"] = ex_state[ + "rnd_distillation_network" + ].apply_gradients(grads=rnd_grad) + + losses = (rnd_loss,) + return ex_state, losses + + (ex_state, traj_batch, rng) = update_state + rng, _rng = jax.random.split(rng) + batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] + assert ( + batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] + ), "batch size must be equal to number of steps * number of envs" + permutation = jax.random.permutation(_rng, batch_size) + batch = jax.tree.map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch + ) + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree.map( + lambda x: jnp.reshape( + x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + ex_state, losses = jax.lax.scan( + _update_ex_minbatch, ex_state, minibatches + ) + update_state = (ex_state, traj_batch, rng) + return update_state, losses + + if config["USE_RND"]: + ex_update_state = (ex_state, traj_batch, rng) + ex_update_state, ex_loss = jax.lax.scan( + _update_ex_epoch, + ex_update_state, + None, + config["EXPLORATION_UPDATE_EPOCHS"], + ) + metric["rnd_loss"] = ex_loss[0].mean() + metric["reward_i"] = traj_batch.reward_i.mean() + metric["reward_e"] = traj_batch.reward_e.mean() + + ex_state = ex_update_state[0] + rng = ex_update_state[-1] + + # wandb logging + if config["DEBUG"] and config["USE_WANDB"]: + + def callback( + metric, update_step + ): # , loss_info, traj_batch, ex_state, advantages_i, targets_i): + to_log = create_log_dict(metric, config) + batch_log(update_step, to_log, config) + + jax.debug.callback( + callback, + metric, + update_step, + # loss_info, traj_batch, ex_state, advantages_i, targets_i + ) + + runner_state = ( + train_state, + env_state, + last_obs, + ex_state, + rng, + update_step + 1, + ) + return runner_state, metric + + rng, _rng = jax.random.split(rng) + runner_state = ( + train_state, + env_state, + obsv, + ex_state, + _rng, + 0, + ) + runner_state, metric = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state} # , "info": metric} + + return train + + +def run_ppo(config): + config = {k.upper(): v for k, v in config.__dict__.items()} + + if config["USE_WANDB"]: + wandb.init( + project=config["WANDB_PROJECT"], + entity=config["WANDB_ENTITY"], + config=config, + name=config["ENV_NAME"] + + "-PPO_RND-" + + str(int(config["TOTAL_TIMESTEPS"] // 1e6)) + + "M", + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_REPEATS"]) + + train_jit = jax.jit(make_train(config)) + train_vmap = jax.vmap(train_jit) + + t0 = time.time() + out = train_vmap(rngs) + t1 = time.time() + print("Time to run experiment", t1 - t0) + print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0)) + # t1 = time.time() + # out = train_vmap(rngs) + # t2 = time.time() + # print("t2", t2 - t1) + # print("SPS2: ", config["TOTAL_TIMESTEPS"] / (t2 - t1)) + + if config["USE_WANDB"]: + # if config["DEBUG"] == "end": + # info = out["info"] + # for update in range(info["timestep"].shape[1]): + # if update % 10 == 0: + # for repeat in range(info["timestep"].shape[0]): + # update_info = jax.tree.map(lambda x: x[repeat, update], info) + # to_log = create_log_dict(update_info) + # batch_log(update, to_log, config) + # + # t2 = time.time() + # print("Time to log to wandb", t2 - t1) + + def _save_network(rs_index, dir_name): + train_states = out["runner_state"][rs_index] + train_state = jax.tree.map(lambda x: x[0], train_states) + + path = os.path.join(wandb.run.dir, dir_name) + options = ocp.CheckpointManagerOptions(max_to_keep=1) + + with ocp.CheckpointManager(path, options=options) as checkpoint_manager: + checkpoint_manager.save( + int(config["TOTAL_TIMESTEPS"]), + args=ocp.args.StandardSave(train_state) + ) + + print(f"saved runner state to {path}") + + if config["SAVE_POLICY"]: + _save_network(0, "policies") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1") + parser.add_argument( + "--num_envs", + type=int, + default=1024, + ) + parser.add_argument( + "--total_timesteps", type=lambda x: int(float(x)), default=1e9 + ) # Allow scientific notation + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--num_steps", type=int, default=64) + parser.add_argument("--update_epochs", type=int, default=4) + parser.add_argument("--num_minibatches", type=int, default=8) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--gae_lambda", type=float, default=0.8) + parser.add_argument("--clip_eps", type=float, default=0.2) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + parser.add_argument("--activation", type=str, default="tanh") + parser.add_argument( + "--anneal_lr", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--seed", type=int) + parser.add_argument( + "--use_wandb", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--save_policy", action="store_true") + parser.add_argument("--num_repeats", type=int, default=1) + parser.add_argument("--layer_size", type=int, default=512) + parser.add_argument("--wandb_project", type=str) + parser.add_argument("--wandb_entity", type=str) + parser.add_argument( + "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--optimistic_reset_ratio", type=int, default=16) + + # EXPLORATION + parser.add_argument("--exploration_update_epochs", type=int, default=1) + # RND + parser.add_argument( + "--use_rnd", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--rnd_layer_size", type=int, default=256) + parser.add_argument("--rnd_output_size", type=int, default=512) + parser.add_argument("--rnd_lr", type=float, default=3e-4) + parser.add_argument("--rnd_reward_coeff", type=float, default=1.0) + parser.add_argument("--rnd_loss_coeff", type=float, default=0.01) + parser.add_argument("--rnd_gae_coeff", type=float, default=0.01) + parser.add_argument( + "--rnd_is_episodic", action=argparse.BooleanOptionalAction, default=False + ) + + args, rest_args = parser.parse_known_args(sys.argv[1:]) + if rest_args: + raise ValueError(f"Unknown args {rest_args}") + + if args.seed is None: + args.seed = np.random.randint(2**31) + + if args.jit: + run_ppo(args) + else: + with jax.disable_jit(): + run_ppo(args) diff --git a/Craftax_Baselines/ppo_rnn.py b/Craftax_Baselines/ppo_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7868e559e14563d05e1b5c7b9429bd38ad1fbedc --- /dev/null +++ b/Craftax_Baselines/ppo_rnn.py @@ -0,0 +1,542 @@ +import argparse +import os +import sys + +import jax +import jax.numpy as jnp +import flax.linen as nn +import numpy as np +import optax +import time + +import orbax.checkpoint as ocp + +import wandb +from flax.linen.initializers import constant, orthogonal +from typing import NamedTuple, Dict +from flax.training.train_state import TrainState +import distrax +import functools + +from wrappers import ( + LogWrapper, + OptimisticResetVecEnvWrapper, + BatchEnvWrapper, + AutoResetEnvWrapper, +) +from logz.batch_logging import create_log_dict, batch_log + +from craftax.craftax_env import make_craftax_env_from_name + +# Code adapted from the original implementation made by Chris Lu +# Original code located at https://github.com/luchris429/purejaxrl + + +class ScannedRNN(nn.Module): + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=0, + out_axes=0, + split_rngs={"params": False}, + ) + @nn.compact + def __call__(self, carry, x): + """Applies the module.""" + rnn_state = carry + ins, resets = x + rnn_state = jnp.where( + resets[:, np.newaxis], + self.initialize_carry(ins.shape[0], ins.shape[1]), + rnn_state, + ) + new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins) + return new_rnn_state, y + + @staticmethod + def initialize_carry(batch_size, hidden_size): + # Use a dummy key since the default state init fn is just zeros. + cell = nn.GRUCell(features=hidden_size) + return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size)) + + +class ActorCriticRNN(nn.Module): + action_dim: int + config: Dict + + @nn.compact + def __call__(self, hidden, x): + obs, dones = x + embedding = nn.Dense( + self.config["LAYER_SIZE"], + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(obs) + embedding = nn.relu(embedding) + + rnn_in = (embedding, dones) + hidden, embedding = ScannedRNN()(hidden, rnn_in) + + actor_mean = nn.Dense( + self.config["LAYER_SIZE"], + kernel_init=orthogonal(2), + bias_init=constant(0.0), + )(embedding) + actor_mean = nn.relu(actor_mean) + actor_mean = nn.Dense( + self.config["LAYER_SIZE"], + kernel_init=orthogonal(2), + bias_init=constant(0.0), + )(actor_mean) + actor_mean = nn.relu(actor_mean) + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + + pi = distrax.Categorical(logits=actor_mean) + + critic = nn.Dense( + self.config["LAYER_SIZE"], + kernel_init=orthogonal(2), + bias_init=constant(0.0), + )(embedding) + critic = nn.relu(critic) + critic = nn.Dense( + self.config["LAYER_SIZE"], + kernel_init=orthogonal(2), + bias_init=constant(0.0), + )(critic) + critic = nn.relu(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return hidden, pi, jnp.squeeze(critic, axis=-1) + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + info: jnp.ndarray + + +def make_train(config): + config["NUM_UPDATES"] = ( + config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] + ) + config["MINIBATCH_SIZE"] = ( + config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] + ) + + # Create environment + env = make_craftax_env_from_name( + config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"] + ) + env_params = env.default_params + + # Wrap with some extra logging + env = LogWrapper(env) + + # Wrap with a batcher, maybe using optimistic resets + if config["USE_OPTIMISTIC_RESETS"]: + env = OptimisticResetVecEnvWrapper( + env, + num_envs=config["NUM_ENVS"], + reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]), + ) + else: + env = AutoResetEnvWrapper(env) + env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"]) + + def linear_schedule(count): + frac = ( + 1.0 + - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) + / config["NUM_UPDATES"] + ) + return config["LR"] * frac + + def train(rng): + # INIT NETWORK + network = ActorCriticRNN(env.action_space(env_params).n, config=config) + rng, _rng = jax.random.split(rng) + init_x = ( + jnp.zeros( + (1, config["NUM_ENVS"], *env.observation_space(env_params).shape) + ), + jnp.zeros((1, config["NUM_ENVS"])), + ) + init_hstate = ScannedRNN.initialize_carry( + config["NUM_ENVS"], config["LAYER_SIZE"] + ) + network_params = network.init(_rng, init_hstate, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + obsv, env_state = env.reset(_rng, env_params) + init_hstate = ScannedRNN.initialize_carry( + config["NUM_ENVS"], config["LAYER_SIZE"] + ) + + # TRAIN LOOP + def _update_step(runner_state, unused): + # COLLECT TRAJECTORIES + def _env_step(runner_state, unused): + ( + train_state, + env_state, + last_obs, + last_done, + hstate, + rng, + update_step, + ) = runner_state + rng, _rng = jax.random.split(rng) + + # SELECT ACTION + ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) + hstate, pi, value = network.apply(train_state.params, hstate, ac_in) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + value, action, log_prob = ( + value.squeeze(0), + action.squeeze(0), + log_prob.squeeze(0), + ) + + # STEP ENV + rng, _rng = jax.random.split(rng) + obsv, env_state, reward, done, info = env.step( + _rng, env_state, action, env_params + ) + transition = Transition( + last_done, action, value, reward, log_prob, last_obs, info + ) + runner_state = ( + train_state, + env_state, + obsv, + done, + hstate, + rng, + update_step, + ) + return runner_state, transition + + initial_hstate = runner_state[-3] + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["NUM_STEPS"] + ) + + # CALCULATE ADVANTAGE + ( + train_state, + env_state, + last_obs, + last_done, + hstate, + rng, + update_step, + ) = runner_state + ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :]) + _, _, last_val = network.apply(train_state.params, hstate, ac_in) + last_val = last_val.squeeze(0) + + def _calculate_gae(traj_batch, last_val, last_done): + def _get_advantages(carry, transition): + gae, next_value, next_done = carry + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = ( + reward + config["GAMMA"] * next_value * (1 - next_done) - value + ) + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae + ) + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + init_hstate, traj_batch, advantages, targets = batch_info + + def _loss_fn(params, init_hstate, traj_batch, gae, targets): + # RERUN NETWORK + _, pi, value = network.apply( + params, init_hstate[0], (traj_batch.obs, traj_batch.done) + ) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = ( + 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + ) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"], + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + total_loss = ( + loss_actor + + config["VF_COEF"] * value_loss + - config["ENT_COEF"] * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn( + train_state.params, init_hstate, traj_batch, advantages, targets + ) + train_state = train_state.apply_gradients(grads=grads) + return train_state, total_loss + + ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) = update_state + + rng, _rng = jax.random.split(rng) + permutation = jax.random.permutation(_rng, config["NUM_ENVS"]) + batch = (init_hstate, traj_batch, advantages, targets) + + shuffled_batch = jax.tree.map( + lambda x: jnp.take(x, permutation, axis=1), batch + ) + + minibatches = jax.tree.map( + lambda x: jnp.swapaxes( + jnp.reshape( + x, + [x.shape[0], config["NUM_MINIBATCHES"], -1] + + list(x.shape[2:]), + ), + 1, + 0, + ), + shuffled_batch, + ) + + train_state, total_loss = jax.lax.scan( + _update_minbatch, train_state, minibatches + ) + update_state = ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) + return update_state, total_loss + + init_hstate = initial_hstate[None, :] # TBH + update_state = ( + train_state, + init_hstate, + traj_batch, + advantages, + targets, + rng, + ) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"] + ) + train_state = update_state[0] + metric = jax.tree.map( + lambda x: (x * traj_batch.info["returned_episode"]).sum() + / traj_batch.info["returned_episode"].sum(), + traj_batch.info, + ) + rng = update_state[-1] + if config["DEBUG"] and config["USE_WANDB"]: + + def callback(metric, update_step): + to_log = create_log_dict(metric, config) + batch_log(update_step, to_log, config) + + jax.debug.callback(callback, metric, update_step) + + runner_state = ( + train_state, + env_state, + last_obs, + last_done, + hstate, + rng, + update_step + 1, + ) + return runner_state, metric + + rng, _rng = jax.random.split(rng) + runner_state = ( + train_state, + env_state, + obsv, + jnp.zeros((config["NUM_ENVS"]), dtype=bool), + init_hstate, + _rng, + 0, + ) + runner_state, metric = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state, "metric": metric} + + return train + + +def run_ppo(config): + config = {k.upper(): v for k, v in config.__dict__.items()} + + if config["USE_WANDB"]: + wandb.init( + project=config["WANDB_PROJECT"], + entity=config["WANDB_ENTITY"], + config=config, + name=config["ENV_NAME"] + + "-PPO_RNN-" + + str(int(config["TOTAL_TIMESTEPS"] // 1e6)) + + "M", + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_REPEATS"]) + + train_jit = jax.jit(make_train(config)) + train_vmap = jax.vmap(train_jit) + + t0 = time.time() + out = train_vmap(rngs) + t1 = time.time() + print("Time to run experiment", t1 - t0) + print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0)) + + if config["USE_WANDB"]: + + def _save_network(rs_index, dir_name): + train_states = out["runner_state"][rs_index] + train_state = jax.tree.map(lambda x: x[0], train_states) + + path = os.path.join(wandb.run.dir, dir_name) + options = ocp.CheckpointManagerOptions(max_to_keep=1) + + with ocp.CheckpointManager(path, options=options) as checkpoint_manager: + checkpoint_manager.save( + int(config["TOTAL_TIMESTEPS"]), + args=ocp.args.StandardSave(train_state) + ) + + print(f"saved runner state to {path}") + + if config["SAVE_POLICY"]: + _save_network(0, "policies") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1") + parser.add_argument( + "--num_envs", + type=int, + default=1024, + ) + parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=1e9) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--num_steps", type=int, default=64) + parser.add_argument("--update_epochs", type=int, default=4) + parser.add_argument("--num_minibatches", type=int, default=8) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--gae_lambda", type=float, default=0.8) + parser.add_argument("--clip_eps", type=float, default=0.2) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--max_grad_norm", type=float, default=1.0) + parser.add_argument("--activation", type=str, default="tanh") + parser.add_argument( + "--anneal_lr", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True) + parser.add_argument("--seed", type=int, default=np.random.randint(2**31)) + parser.add_argument( + "--use_wandb", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument( + "--save_policy", action=argparse.BooleanOptionalAction, default=False + ) + parser.add_argument("--num_repeats", type=int, default=1) + parser.add_argument("--layer_size", type=int, default=512) + parser.add_argument("--wandb_project", type=str) + parser.add_argument("--wandb_entity", type=str) + parser.add_argument( + "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--optimistic_reset_ratio", type=int, default=16) + + args, rest_args = parser.parse_known_args(sys.argv[1:]) + if rest_args: + raise ValueError(f"Unknown args {rest_args}") + + if args.seed is None: + args.seed = np.random.randint(2**31) + + if args.jit: + run_ppo(args) + else: + with jax.disable_jit(): + run_ppo(args) diff --git a/Craftax_Baselines/requirements.txt b/Craftax_Baselines/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed907da69758eef3abbe6c9b9fcca670e9794633 --- /dev/null +++ b/Craftax_Baselines/requirements.txt @@ -0,0 +1,16 @@ +jax[cuda12_pip] +distrax +optax +flax +numpy +black +pre-commit +argparse +wandb +orbax-checkpoint==0.5.0 +pygame +gymnax +chex +matplotlib +imageio +craftax diff --git a/Craftax_Baselines/run_docker.sh b/Craftax_Baselines/run_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..f22e47879f9809ecbfa3f055e7ff4726ea06c613 --- /dev/null +++ b/Craftax_Baselines/run_docker.sh @@ -0,0 +1,24 @@ +#!/bin/bash +WANDB_API_KEY=$(cat ./wandb_key) +# git pull + +script_and_args="${@:2}" +if [ $1 == "all" ]; then + gpus="0 1 2 3 4 5 6 7" +else + gpus=$1 +fi + +for gpu in $gpus; do + echo "Launching container craftax_$gpu on GPU $gpu" + docker run \ + --gpus device=$gpu \ + -e WANDB_API_KEY=$WANDB_API_KEY \ + -v $(pwd):/home/duser/Craftax \ + --name craftax_$gpu \ + --user $(id -u) \ + --rm \ + -d \ + -t craftax_baselines \ + /bin/bash -c "$script_and_args" +done diff --git a/Craftax_Baselines/wrappers.py b/Craftax_Baselines/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..43d601e184f2abd2b6b4c0a1ddc6839460d7db63 --- /dev/null +++ b/Craftax_Baselines/wrappers.py @@ -0,0 +1,200 @@ +import jax +import jax.numpy as jnp +import chex +import numpy as np +from flax import struct +from functools import partial +from typing import Optional, Tuple, Union, Any + + +class GymnaxWrapper(object): + """Base class for Gymnax wrappers.""" + + def __init__(self, env): + self._env = env + + # provide proxy access to regular attributes of wrapped object + def __getattr__(self, name): + return getattr(self._env, name) + + +class BatchEnvWrapper(GymnaxWrapper): + """Batches reset and step functions""" + + def __init__(self, env, num_envs: int): + super().__init__(env) + + self.num_envs = num_envs + + self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None)) + self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset(self, rng, params=None): + rng, _rng = jax.random.split(rng) + rngs = jax.random.split(_rng, self.num_envs) + obs, env_state = self.reset_fn(rngs, params) + return obs, env_state + + @partial(jax.jit, static_argnums=(0, 4)) + def step(self, rng, state, action, params=None): + rng, _rng = jax.random.split(rng) + rngs = jax.random.split(_rng, self.num_envs) + obs, state, reward, done, info = self.step_fn(rngs, state, action, params) + + return obs, state, reward, done, info + + +class AutoResetEnvWrapper(GymnaxWrapper): + """Provides standard auto-reset functionality, providing the same behaviour as Gymnax-default.""" + + def __init__(self, env): + super().__init__(env) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset(self, key, params=None): + return self._env.reset(key, params) + + @partial(jax.jit, static_argnums=(0, 4)) + def step(self, rng, state, action, params=None): + + rng, _rng = jax.random.split(rng) + obs_st, state_st, reward, done, info = self._env.step( + _rng, state, action, params + ) + + rng, _rng = jax.random.split(rng) + obs_re, state_re = self._env.reset(_rng, params) + + # Auto-reset environment based on termination + def auto_reset(done, state_re, state_st, obs_re, obs_st): + state = jax.tree.map( + lambda x, y: jax.lax.select(done, x, y), state_re, state_st + ) + obs = jax.lax.select(done, obs_re, obs_st) + + return obs, state + + obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st) + + return obs, state, reward, done, info + + +class OptimisticResetVecEnvWrapper(GymnaxWrapper): + """ + Provides efficient 'optimistic' resets. + The wrapper also necessarily handles the batching of environment steps and resetting. + reset_ratio: the number of environment workers per environment reset. Higher means more efficient but a higher + chance of duplicate resets. + """ + + def __init__(self, env, num_envs: int, reset_ratio: int): + super().__init__(env) + + self.num_envs = num_envs + self.reset_ratio = reset_ratio + assert ( + num_envs % reset_ratio == 0 + ), "Reset ratio must perfectly divide num envs." + self.num_resets = self.num_envs // reset_ratio + + self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None)) + self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset(self, rng, params=None): + rng, _rng = jax.random.split(rng) + rngs = jax.random.split(_rng, self.num_envs) + obs, env_state = self.reset_fn(rngs, params) + return obs, env_state + + @partial(jax.jit, static_argnums=(0, 4)) + def step(self, rng, state, action, params=None): + + rng, _rng = jax.random.split(rng) + rngs = jax.random.split(_rng, self.num_envs) + obs_st, state_st, reward, done, info = self.step_fn(rngs, state, action, params) + + rng, _rng = jax.random.split(rng) + rngs = jax.random.split(_rng, self.num_resets) + obs_re, state_re = self.reset_fn(rngs, params) + + rng, _rng = jax.random.split(rng) + reset_indexes = jnp.arange(self.num_resets).repeat(self.reset_ratio) + + being_reset = jax.random.choice( + _rng, + jnp.arange(self.num_envs), + shape=(self.num_resets,), + p=done, + replace=False, + ) + reset_indexes = reset_indexes.at[being_reset].set(jnp.arange(self.num_resets)) + + obs_re = obs_re[reset_indexes] + state_re = jax.tree.map(lambda x: x[reset_indexes], state_re) + + # Auto-reset environment based on termination + def auto_reset(done, state_re, state_st, obs_re, obs_st): + state = jax.tree.map( + lambda x, y: jax.lax.select(done, x, y), state_re, state_st + ) + obs = jax.lax.select(done, obs_re, obs_st) + + return state, obs + + state, obs = jax.vmap(auto_reset)(done, state_re, state_st, obs_re, obs_st) + + return obs, state, reward, done, info + + +@struct.dataclass +class LogEnvState: + env_state: Any + episode_returns: float + episode_lengths: int + returned_episode_returns: float + returned_episode_lengths: int + timestep: int + + +class LogWrapper(GymnaxWrapper): + """Log the episode returns and lengths.""" + + def __init__(self, env): + super().__init__(env) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset(self, key: chex.PRNGKey, params=None): + obs, env_state = self._env.reset(key, params) + state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0) + return obs, state + + @partial(jax.jit, static_argnums=(0, 4)) + def step( + self, + key: chex.PRNGKey, + state, + action: Union[int, float], + params=None, + ): + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + new_episode_return = state.episode_returns + reward + new_episode_length = state.episode_lengths + 1 + state = LogEnvState( + env_state=env_state, + episode_returns=new_episode_return * (1 - done), + episode_lengths=new_episode_length * (1 - done), + returned_episode_returns=state.returned_episode_returns * (1 - done) + + new_episode_return * done, + returned_episode_lengths=state.returned_episode_lengths * (1 - done) + + new_episode_length * done, + timestep=state.timestep + 1, + ) + info["returned_episode_returns"] = state.returned_episode_returns + info["returned_episode_lengths"] = state.returned_episode_lengths + info["timestep"] = state.timestep + info["returned_episode"] = done + return obs, state, reward, done, info diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b43bf2c29f63c5bd4fff8ace05c72e9ace33f8e5 --- /dev/null +++ b/README.md @@ -0,0 +1,547 @@ +# ReMDM Planner — Discrete Diffusion Planning on Craftax + +A JAX implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in the [Craftax](https://github.com/MichaelTMatthews/Craftax) environment. A bidirectional transformer learns to generate action plans by iteratively denoising masked token sequences, conditioned on the current environment observation. + +--- + +## Description + +The planner starts from a fully-masked action sequence and iteratively unmasks tokens over `T` denoising steps, producing a `plan_horizon`-length plan. The ReMDM framework extends standard Masked Discrete Language Modelling (MDLM) with remasking strategies that allow committed tokens to be re-predicted, improving plan coherence. + +Two independent training pipelines are available — **Offline BC** and **Online DAgger** — both supervised by a pre-trained PPO expert but otherwise separate. Neither depends on the other; the paper compares them head-to-head. + +``` +[Shared] Train PPO agent Craftax_Baselines/ppo_rnn.py | ppo_rnd.py + | + v checkpoint + ┌───────┴────────┐ + │ │ + [Offline BC] [Online DAgger] + main.py main.py + --mode offline --mode online + (train on live (train from scratch; + PPO rollouts) mixed policy + expert + │ labels into replay buffer) + v v + checkpoint checkpoint + │ │ + └───────┬────────┘ + v +[Evaluate] main.py --mode inference --checkpoint_path ... + +Optional: an offline BC checkpoint can warm-start DAgger +via --offline_checkpoint_path (not used in the paper). + + [Offline BC] ──checkpoint──> [Online DAgger] +``` + +**Optional utility modes:** +``` +[Collect] Save PPO rollouts to disk main.py --mode collect +[Smoke test] Quick end-to-end check main.py --mode smoke +``` + +--- + +## Installation + +### Prerequisites (system-level) + +`uv` manages Python packages only. The following must be installed at the OS level before +running on a GPU node — they are **not** in `pyproject.toml`: + +- **CUDA 13** driver and toolkit (`libcuda.so`, `libcudnn`) + +On HPC clusters these are typically loaded via `module load cuda/13.x`. + +### 1. Create the virtual environment + +```bash +# CPU-only (local development / macOS) +uv sync + +# NVIDIA CUDA 13 (GPU node — Linux only) +uv sync --extra cuda + +# Activate +source .venv/bin/activate +``` + +`uv sync` reads `pyproject.toml`, resolves a fully-reproducible lockfile (`uv.lock`), +and installs into `.venv/`. Commit `uv.lock` to pin the exact dependency graph. + +### 2. Initialise the submodule + +```bash +git submodule update --init --recursive +``` + +--- + +## Dependencies + +| Package | Version | Role | +|---------|---------|------| +| `jax` | >=0.9.2 | JIT compilation and functional arrays | +| `flax` | >=0.12.6 | Neural network definitions | +| `optax` | >=0.2.8 | Adam optimiser and gradient clipping | +| `craftax` | >=1.5.0 | Procedurally-generated Minecraft-like environment | +| `chex` | >=0.1.91 | JAX testing and assertion utilities | +| `distrax` | >=0.1.7 | Probability distributions | +| `orbax` | >=0.1.9 | Model checkpointing | +| `wandb` | >=0.25.1 | Experiment logging | +| `numpy` | >=2.4.4 | Array operations | +| `matplotlib` | >=3.10.8 | Plotting | +| `polars` | >=1.39.3 | DataFrame analysis | +| `orjson` | >=3.11.8 | Fast JSON serialisation | +| `pyyaml` | >=6.0.3 | Config file parsing | + +Full specification in `pyproject.toml`. Exact transitive pins are in `uv.lock`. + +--- + +## Usage + +All modes share the same entry point. Defaults are loaded from `configs/defaults.yaml`; any value can be overridden on the command line. + +```bash +python main.py --mode [--config PATH] [OVERRIDES...] +``` + +Pass `--no-jit` to disable JIT compilation (useful for debugging): + +```bash +python main.py --mode offline --no-jit --num_envs 4 +``` + +### Stage 1 — Train a PPO agent + +PPO training is handled by the `Craftax_Baselines` submodule and produces the checkpoint consumed by all downstream stages. + +```bash +cd Craftax_Baselines + +# PPO with GRU hidden state (recommended) +python ppo_rnn.py \ + --env_name Craftax-Classic-Symbolic-v1 \ + --total_timesteps 500000000 \ + --save_policy --use_wandb + +# PPO with Random Network Distillation +python ppo_rnd.py \ + --env_name Craftax-Classic-Symbolic-v1 \ + --total_timesteps 500000000 \ + --save_policy --use_wandb + +cd .. +``` + +### Stage 2a — Collect trajectories to disk + +Roll out the PPO checkpoint and save `(obs, actions, rewards, dones)` as a `.npz` file for reuse across multiple diffusion training runs. + +```bash +python main.py --mode collect \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --offline_data_path data/trajectories.npz \ + --collect_num_steps 1000000 \ + --collect_num_envs 128 +``` + +The file stores arrays shaped `[num_envs, num_iters, ...]`, preserving per-environment contiguity so episode boundaries are respected during window sampling. + +### Stage 2b — Train offline from live PPO rollouts + +Roll out the PPO agent live at each update step and train the diffusion model on the collected windows. Windows that cross episode boundaries are masked out; windows with higher cumulative reward receive proportionally larger gradient contributions (clipped to `[0.1, return_weight_cap]`). + +```bash +python main.py --mode offline \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --offline_total_timesteps 100000000 \ + --save_policy +``` + +### Online DAgger Training + +The diffusion model is trained **from scratch** via DAgger (Dataset Aggregation). At each iteration a mixed policy blends the PPO expert and the diffusion learner (controlled by an exponentially decaying `beta`). The mixed policy rolls out trajectories; the expert labels every visited state with the action it would take. These `(obs, expert_plan)` pairs are appended to a growing circular replay buffer, and the diffusion model is trained on the full buffer with the standard MDLM ELBO loss (pure behavioural cloning — no reward weighting). + +```bash +# From scratch (requires PPO expert checkpoint) +python main.py --mode online \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --online_num_updates 1000 \ + --save_policy + +# Optional: warm-start from a pre-trained offline checkpoint +# (not used in the paper — both methods are compared independently) +python main.py --mode online \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --offline_checkpoint_path /path/to/offline_checkpoint \ + --online_num_updates 1000 \ + --save_policy +``` + +When `save_policy=true`, online training uploads **two** W&B artifacts: `{env_name}-policy` (final weights) and `{env_name}-policy-best` (weights from the validation iteration with the highest return). Either artifact can be consumed downstream by `--checkpoint_path wandb:…`. + +### Stage 4 — Evaluate + +```bash +python main.py --mode inference \ + --checkpoint_path /path/to/checkpoint \ + --eval_steps 10000 \ + --eval_num_envs 32 +``` + +Prints mean episode return, completed episodes, steps per second, and per-achievement unlock counts. Uses historical inpainting: the first `hist_len` plan positions are locked to observed history. + +### Loading checkpoints from W&B artifacts + +Any checkpoint path argument (`--checkpoint_path`, `--offline_checkpoint_path`, `--ppo_checkpoint_path`) accepts a W&B artifact reference prefixed with `wandb:`. The artifact is downloaded automatically before training or evaluation begins. + +```bash +# Fully qualified: entity/project/artifact_name:version_or_alias +python main.py --mode inference \ + --checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:latest + +# Online fine-tuning from a W&B offline checkpoint +python main.py --mode online \ + --offline_checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:v3 + +# PPO checkpoint from W&B +python main.py --mode offline \ + --ppo_checkpoint_path wandb:my-team/ppo-craftax/ppo-rnn-policy:best +``` + +Control the download location with `--wandb_download_dir` (defaults to `./artifacts/`). + +### Resuming a Training Run + +A completed training checkpoint can be used as the starting point for a new run that continues where the previous one left off. This is useful when extending the training budget or when a preempted job needs to be restarted. + +**Offline resume:** + +```bash +# Auto-detect step and wandb run ID from checkpoint metadata +python main.py --mode offline \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --resume_checkpoint_path /path/to/completed_offline_checkpoint \ + --offline_total_timesteps 200000000 \ + --save_policy + +# Explicit step and wandb run ID override +python main.py --mode offline \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --resume_checkpoint_path /path/to/completed_offline_checkpoint \ + --resume_step 1525 \ + --resume_wandb_run_id abc123xyz \ + --offline_total_timesteps 200000000 \ + --save_policy + +# Resume from a W&B artifact +python main.py --mode offline \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --resume_checkpoint_path wandb:my-team/remdm-craftax/policy:latest \ + --offline_total_timesteps 200000000 \ + --save_policy +``` + +**Online resume:** + +```bash +python main.py --mode online \ + --ppo_checkpoint_path /path/to/ppo_checkpoint \ + --resume_checkpoint_path /path/to/completed_online_checkpoint \ + --online_num_updates 2000 \ + --save_policy +``` + +**Notes:** +- The DAgger replay buffer is **not** persisted across resumes. It starts empty and refills within the first few iterations. +- JIT compilation is fully preserved. Resume only affects initialisation outside `jax.jit` (loading checkpoint, setting the optimizer step counter, adjusting scan length). +- The cosine LR schedule is constructed for the full `num_updates` range. The optimizer step counter is set to the resume offset so the learning rate picks up exactly where the previous run stopped. +- When `resume_checkpoint_path` points to a checkpoint with a metadata sidecar, `resume_step` and `resume_wandb_run_id` are auto-detected. Explicit CLI flags override the metadata values. +- Checkpoints without a metadata sidecar (created before this feature) still load; provide `--resume_step` explicitly. + + +--- + +## Configuration + +All hyperparameters are in `configs/defaults.yaml`. Override any value on the command line: + +```bash +python main.py --mode offline --lr 1e-4 --plan_horizon 64 --num_minibatches 16 +``` + +Point to a custom config file: + +```bash +python main.py --mode online --config configs/my_experiment.yaml +``` + +Preset configs for larger runs are provided in `configs/`: + +| File | Purpose | +|------|---------| +| `configs/defaults.yaml` | Base defaults for all modes | +| `configs/classic_exp_a_beta_fix.yaml` | Craftax Classic DAgger — beta decay fix only (isolates data quality) | +| `configs/classic_exp_b_beta_big_model.yaml` | Craftax Classic DAgger — beta fix + 3.5× larger transformer | +| `configs/classic_exp_c_full_recipe.yaml` | Craftax Classic DAgger — beta + big model + training dynamics | +| `configs/craftax_exp_a_beta_fix.yaml` | Full Craftax DAgger — beta decay fix only | +| `configs/craftax_exp_b_beta_big_model.yaml` | Full Craftax DAgger — beta fix + larger transformer | +| `configs/craftax_exp_c_full_recipe.yaml` | Full Craftax DAgger — full recipe | +| `configs/final_classic_ucl.yaml` | Final Craftax Classic DAgger — UCL 3090 Ti, seed 42 (produces the Classic checkpoint consumed by the ablation suite) | +| `configs/final_classic_qmul.yaml` | Env-frame-matched second seed of `final_classic_ucl.yaml` (QMUL H200, seed 43) | +| `configs/final_craftax_ucl.yaml` | Final Full Craftax DAgger — UCL 4090, seed 42 (produces the Full Craftax checkpoint consumed by the ablation suite) | +| `configs/final_craftax_qmul.yaml` | Env-frame-matched second seed of `final_craftax_ucl.yaml` (QMUL H200, seed 43) | + +RL fine-tuning ablation hyperparameters live under `experiments/rl_finetuning/configs/` and are loaded by `run_ablations.py`, not by `main.py`. See `experiments/README.md`. + +The `final_*_qmul.yaml` presets differ from their UCL counterparts only in `num_envs` (smaller partition) and `seed`. All fairness-critical hyperparameters are denominated in env frames or update cycles and automatically rescaled by `resolve_scaled_hyperparams()` at load time, so no manual derivation is needed when moving between hardware tiers. + +### Key hyperparameters + +**Environment** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `env_name` | `Craftax-Classic-Symbolic-v1` | Craftax environment ID. Use `Craftax-Symbolic-v1` for Full Craftax. | +| `use_optimistic_resets` | `false` | Use `OptimisticResetVecEnvWrapper` instead of `AutoResetEnvWrapper` | +| `optimistic_reset_ratio` | 16 | Fraction of envs reset per step when optimistic resets are enabled | + +**Diffusion model** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `plan_horizon` | 32 | Action plan length H | +| `diffusion_steps` | 15 | Denoising steps T at inference | +| `diffusion_schedule` | `cosine` | Noise schedule: `cosine` or `linear` | +| `remask_strategy` | `rescale` | Remasking strategy: `rescale`, `cap`, or `conf` | +| `train_sigma` | 0.0 | Per-token remasking correction during training (0 = standard MDLM) | +| `label_smoothing` | 0.0 | Cross-entropy label smoothing epsilon (0 = exact ELBO) | +| `eta` | 0.5 | Remasking strength | +| `use_loop` | `true` | Three-phase loop remasking (Algorithm 3) | +| `t_on` / `t_off` | 0.7 / 0.3 | Time window boundaries for loop remasking | +| `temperature` | 0.5 | Softmax temperature for token sampling | +| `top_p` | 0.95 | Nucleus sampling threshold | + +**Transformer architecture** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `d_model` | 256 | Hidden dimension | +| `n_heads` | 4 | Attention heads | +| `n_layers` | 4 | Transformer blocks | +| `d_ff` | 512 | FFN inner dimension | +| `obs_encoder_layers` | 2 | MLP layers in the observation encoder | +| `obs_encoder_width` | 512 | Observation encoder hidden width | +| `dropout_rate` | 0.1 | Dropout rate (disabled at inference) | + +**Offline training** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `offline_total_timesteps` | 1e8 | **PRIMARY** env-frame budget for live-PPO data collection. Derives `num_updates` as `offline_total_timesteps // (num_envs * num_steps)`, making the run hardware-portable across `num_envs` changes. | +| `offline_num_updates` | `null` | **LEGACY** outer update count; used only when `offline_total_timesteps` is unset. | +| `num_envs` | 1024 | Parallel environments | +| `num_steps` | 64 | Environment steps collected per update | +| `num_minibatches` | 8 | Gradient minibatches per epoch | +| `update_epochs` | 4 | SGD epochs per update step | +| `num_repeats` | 1 | Independent training seeds (vmapped) | +| `lr` | 3e-4 | Adam learning rate (cosine-decayed to 10% over all gradient steps) | +| `lr_warmup_frames` | `null` | **PRIMARY** env-frame warm-up budget. Derives `lr_warmup_steps` as `lr_warmup_frames // (num_envs * num_steps)`. | +| `lr_warmup_steps` | 0 | **LEGACY** linear warm-up steps before cosine decay (used when `lr_warmup_frames` is unset; 0 = disabled). | +| `max_grad_norm` | 1.0 | Global gradient clipping norm | +| `return_weight_cap` | 5.0 | Clip ceiling for per-window return weights (lower clip is fixed at 0.1) | +| `collect_temperature` | 1.0 | Softmax temperature on PPO logits during live data collection | +| `val_interval_frames` | `null` | **PRIMARY** env-frames between validation rollouts. Overrides `val_interval` via `val_interval = val_interval_frames // (num_envs * num_steps)`. | +| `val_interval` | 50 | **LEGACY** validation frequency in update steps (used when `val_interval_frames` is unset). | +| `val_diffusion_steps` | 50 | Denoising steps used during validation rollouts | +| `val_replan_every` | 4 | Environment steps executed per diffusion plan during validation | +| `val_steps` | 128 | Total environment steps per validation rollout | + +**Online DAgger training** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `online_total_timesteps` | `null` | **PRIMARY** env-frame budget for online DAgger (hardware-portable). Derives `num_updates` as `online_total_timesteps // (num_envs * num_steps)`. | +| `online_num_updates` | 1000 | **LEGACY** outer DAgger iterations (used when `online_total_timesteps` is unset). | +| `dagger_beta_init` | 1.0 | Initial expert mixing probability `beta_1` (1.0 = pure expert on the first iteration). | +| `dagger_beta_final` | `null` | **PRIMARY** target mixing ratio at the end of training. Overrides `dagger_beta_decay` via `decay = (beta_final / beta_init) ** (1 / num_updates)`. | +| `dagger_beta_decay` | 0.95 | **LEGACY** per-update decay: `beta_i = beta_init * decay^i` (used when `dagger_beta_final` is unset). | +| `dagger_buffer_cycles` | `null` | **PRIMARY** buffer capacity denominated in update cycles of history (1 cycle = `num_envs * num_steps` frames). Overrides `dagger_buffer_max` via `buffer_max = cycles * (num_envs * num_steps)`. | +| `dagger_buffer_max` | 100000 | **LEGACY** max samples in the DAgger replay buffer (circular eviction when full). | +| `dagger_train_passes` | `null` | Passes per update over the aggregated buffer. `null` = 1 pass (matches offline BC per-update gradient work exactly for fair compute comparison). Raise to >1 to trade BC fairness for wider per-update buffer coverage. | +| `dagger_expert_deterministic` | `true` | If `true`, the PPO expert takes the argmax action (fixed `s → a*` map); if `false`, it samples categorically. Deterministic removes label noise from the aggregated dataset. | + +**Data collection** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `collect_num_steps` | 10000000 | Total environment steps to collect | +| `collect_num_envs` | 128 | Parallel environments during collection | +| `ppo_model_type` | `ppo_rnn` | PPO architecture: `ppo`, `ppo_rnn`, or `ppo_rnd` | +| `layer_size` | 512 | PPO actor-critic hidden layer width | + +**Inference** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `eval_steps` | 10000 | Environment steps for evaluation | +| `eval_num_envs` | 32 | Parallel agents during evaluation (independent of `num_envs`) | +| `diffusion_steps_eval` | 10 | Denoising steps T used at evaluation time | + +**Checkpointing** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `save_policy` | `true` | Save final checkpoint at end of training and upload it as a W&B artifact | + +**Resume** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `resume_checkpoint_path` | `null` | Path to a completed checkpoint to resume from (accepts `wandb:` refs) | +| `resume_wandb_run_id` | `null` | W&B run ID to resume logging into (auto-read from checkpoint metadata) | +| `resume_step` | `null` | Update step the checkpoint was saved at (auto-read from checkpoint metadata) | + +**Logging** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `use_wandb` | `true` | Enable Weights & Biases logging | +| `wandb_project` | `remdm-craftax` | W&B project name | +| `wandb_entity` | `"mathis-weil-university-college-london-ucl-"` | W&B entity (team or username) | +| `wandb_download_dir` | `null` | Download directory for W&B artifacts; null = `./artifacts/` | +| `seed` | `null` | RNG seed (random if null) | + +--- + +## Remasking Strategies + +Controlled by `--remask_strategy`. All strategies operate on top of the three-phase loop controlled by `--use_loop`, `--t_on`, and `--t_off`. + +| Strategy | Formula | Description | +|----------|---------|-------------| +| `rescale` | `sigma = eta * sigma_max` | Scales maximum remasking probability proportionally | +| `cap` | `sigma = min(eta, sigma_max)` | Caps remasking at a fixed rate | +| `conf` | `sigma = eta * sigma_max * (1 - confidence)` | High-confidence tokens are remasked less | + +--- + +## Environment Wrappers + +**From `Craftax_Baselines/wrappers.py`** (submodule): + +| Wrapper | Purpose | +|---------|---------| +| `LogWrapper` | Tracks episode returns and lengths; adds stats to the info dict | +| `AutoResetEnvWrapper` | Automatically resets episodes on `done` | +| `BatchEnvWrapper` | Vmaps `reset` and `step` over `num_envs` environments | +| `OptimisticResetVecEnvWrapper` | Batched resets with reduced overhead; enable via `--use_optimistic_resets` | + +**From `src/envs/wrappers.py`**: + +| Wrapper | Purpose | +|---------|---------| +| `SequenceHistoryWrapper` | Maintains a sliding window of past observations and actions in the env state | +| `DiscreteTokenizationWrapper` | Quantizes continuous observations into discrete token indices | +| `PlannerWrapper` | Manages the plan/replan cycle for the diffusion planner | +| `OfflineTrajectoryWrapper` | Accumulates transitions into a fixed-size circular replay buffer | + +**Wrapper stacks:** + +``` +Training: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper +Inference: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper +``` + +--- + +## Project Structure + +``` +craftax-ReMDM-planner/ +├── Craftax_Baselines/ # Git submodule — PPO agents and standard wrappers +│ ├── wrappers.py # LogWrapper, BatchEnvWrapper, AutoResetEnvWrapper, etc. +│ ├── ppo_rnn.py # PPO-RNN training script +│ ├── ppo_rnd.py # PPO-RND training script +│ ├── ppo.py # PPO model definitions +│ └── models/ +│ ├── actor_critic.py # ActorCritic variants +│ ├── rnd.py # RND network +│ └── icm.py # ICM encoder, forward, and inverse networks +├── configs/ +│ ├── defaults.yaml # Base hyperparameters (CLI-overridable) +│ ├── classic_exp_a_beta_fix.yaml # Classic DAgger — beta decay fix only +│ ├── classic_exp_b_beta_big_model.yaml # Classic DAgger — beta fix + big model +│ ├── classic_exp_c_full_recipe.yaml # Classic DAgger — full recipe +│ ├── craftax_exp_a_beta_fix.yaml # Full Craftax DAgger — beta fix +│ ├── craftax_exp_b_beta_big_model.yaml # Full Craftax DAgger — beta + big model +│ ├── craftax_exp_c_full_recipe.yaml # Full Craftax DAgger — full recipe +│ ├── final_classic_ucl.yaml # Classic DAgger — UCL 3090 Ti, seed 42 +│ ├── final_classic_qmul.yaml # Classic DAgger — QMUL H200, seed 43 +│ ├── final_craftax_ucl.yaml # Full Craftax DAgger — UCL 4090, seed 42 +│ └── final_craftax_qmul.yaml # Full Craftax DAgger — QMUL H200, seed 43 +├── src/ +│ ├── diffusion/ +│ │ ├── forward.py # Forward masking process q(z_t | x_0) +│ │ ├── loss.py # Continuous-time MDLM ELBO loss +│ │ ├── sampling.py # Reverse diffusion with ReMDM remasking +│ │ └── schedules.py # Linear and cosine noise schedules +│ ├── models/ +│ │ └── denoiser.py # DenoisingTransformer (obs encoder + transformer) +│ ├── envs/ +│ │ └── wrappers.py # Sequence, tokenization, planner, and trajectory wrappers +│ └── planners/ +│ ├── collect.py # --mode collect: PPO rollouts -> .npz +│ ├── common.py # Shared utilities +│ ├── env.py # Environment construction +│ ├── inference.py # --mode inference: MPC evaluation with inpainting +│ ├── logging.py # Centralised W&B logging utilities +│ ├── model.py # Diffusion model lifecycle +│ ├── offline.py # --mode offline: make_train (live PPO rollouts) +│ ├── online.py # --mode online: DAgger fine-tuning +│ └── ppo.py # PPO agent adapter and checkpoint loading utilities +├── experiments/ +│ └── rl_finetuning/ # RL fine-tuning ablation suite (see experiments/README.md) +│ ├── run_ablations.py # CLI entry point +│ ├── ablations/ # Loss, optimizer, registry, and training modules +│ ├── diagnostics/ # Gradient, representation, and timestep diagnostics +│ ├── analysis/ # Plots, tables, and report generation +│ └── configs/ # ablations_default.yaml, ablations_fast.yaml, +│ # ablations_final_{classic,craftax}_{ucl,qmul}.yaml +├── main.py # CLI entry point +├── pyproject.toml # uv project — direct deps + tool config +└── uv.lock # Reproducible lockfile (commit this) +``` + +--- + +## Implementation Notes + +**JAX functional purity**: training closures (`make_train`, `make_train_dagger`) are fully JIT-compatible. Environment construction and checkpoint I/O happen outside `jax.jit`. + +**Offline training**: `--mode offline` rolls out the PPO agent live at each update step via `make_train`. Use `--mode collect` to save a trajectory `.npz` for inspection or analysis; re-feeding it to `--mode offline` is not supported — pass `--ppo_checkpoint_path` instead. + +**Episode-boundary masking**: the offline sampler pre-computes a validity mask over all `(env, time)` positions. A window at `(e, t)` is valid only if `dones[e, t+1:t+H-1]` are all `False`. + +**Return weighting**: valid windows are weighted by their cumulative reward, normalised by the batch mean and clipped to `[0.1, RETURN_WEIGHT_CAP]`. Weights are passed as per-sample multipliers into the MDLM loss before reduction, so they correctly scale each sample's gradient contribution. + +**LR schedule**: cosine decay from `lr` to `lr * 0.1` over all gradient steps. Set `lr_warmup_frames > 0` (env-frame-invariant, PRIMARY) or `lr_warmup_steps > 0` (LEGACY) to prepend a linear warm-up phase. + +**Env-frame-invariant hyperparameters**: the PRIMARY keys `offline_total_timesteps`, `online_total_timesteps`, `lr_warmup_frames`, `val_interval_frames`, `dagger_beta_final`, and `dagger_buffer_cycles` are denominated in env frames (or update cycles). At config load time, `resolve_scaled_hyperparams()` in `src/planners/common.py` converts them to the equivalent update-step-denominated quantities (`num_updates`, `lr_warmup_steps`, `val_interval`, `dagger_beta_decay`, `dagger_buffer_max`) using the current `num_envs * num_steps` frames-per-update. This lets the same config run on different hardware tiers without re-tuning. + +**Loss weight clipping**: the MDLM SUBS weight `-alpha'(t) / (1 - alpha_t)` is clipped to 1000 to prevent numerical instability when `alpha_t ≈ 1`. + +**Validation rollouts**: during offline training, a held-out rollout runs every `val_interval` steps. It uses the same sampling parameters as inference (`remask_strategy`, `eta`, `use_loop`, `t_on`, `t_off`, `temperature`, `top_p`) with `val_diffusion_steps` denoising steps and `val_replan_every` env steps per plan, for a total of `val_steps` environment steps. + +**W&B logging**: all metric aggregation is centralised in `src/planners/logging.py`. Metric namespaces: `diffusion/` (loss, accuracy), `train/` (data quality, throughput), `env/` (episode returns, achievements), `val/` (validation rollouts, emitted every `val_interval` steps), `dagger/` (online DAgger training: beta, buffer fill, reward mean, valid fraction). `train/sps` (environment frames/sec) is only logged in modes that perform live environment interaction. + +**DAgger dataset aggregation**: online training (`--mode online`) implements DAgger (Ross et al., 2011). A circular replay buffer accumulates `(obs, expert_plan)` pairs across all iterations. Each update samples uniformly from the full buffer, not just the latest batch. Training samples that cross episode boundaries (any `done` within the plan-horizon window) are marked invalid. The expert (PPO agent) receives correct `done` flags so its RNN hidden state resets on episode boundaries. Windows are extracted with a sliding stride (one per env-time position) rather than stepping the buffer in plan-horizon chunks, so every visited state contributes a label. + +**Best-checkpoint tracking**: during online training, the parameters from the validation iteration with the highest validation return are preserved alongside the current live parameters. The final checkpoint and the best-validation checkpoint are both uploaded as separate W&B artifacts (`{env_name}-policy` and `{env_name}-policy-best`). + +**Denoising step indexing**: the reverse scan runs from `step_idx = 0` to `T-1`, mapping to diffusion time `t = (T - step_idx) / T` (high noise to low noise). + +**Submodule PPO agents**: PPO training lives entirely in `Craftax_Baselines/`. Planner scripts only consume pre-trained checkpoints via `--ppo_checkpoint_path`. diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..af9deeec510ac5142d3a4f9cdbae10bb10eb761a --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA @@ -0,0 +1 @@ +{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775663434533263974, "commit_timestamp_nsecs": 1775663435644779625, "custom_metadata": {}} \ No newline at end of file diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..8d64d23ddc877e9c3a263d5a2d60556a82bc8ff2 --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA @@ -0,0 +1 @@ +{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('params', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('params', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} \ No newline at end of file diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding new file mode 100644 index 0000000000000000000000000000000000000000..c3e2898e48ce0d01971669ae6a0dc6bd5228e39f --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding @@ -0,0 +1 @@ +{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} \ No newline at end of file diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 new file mode 100644 index 0000000000000000000000000000000000000000..af0fb84132e9bc0aa2fed6958e07a83773a12ae5 --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 @@ -0,0 +1 @@ +{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]} \ No newline at end of file diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af new file mode 100644 index 0000000000000000000000000000000000000000..40063e0bb0125db03d6fa58bf65a3cb8328f1805 Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af differ diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..b12a2bab77e77df915993ace51a110585d5eeb2c Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt differ diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 new file mode 100644 index 0000000000000000000000000000000000000000..a930731853d772682631bb44ded1cbd4e369f83f --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbff61a18e9475d72fae302d4748615daf5fc6b87cc0e0a338c96b8a781d6c0f +size 101199872 diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a new file mode 100644 index 0000000000000000000000000000000000000000..04a1c2a5e61500bb83ea8fee15a98bdda281a09b Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a differ diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 new file mode 100644 index 0000000000000000000000000000000000000000..32aeb3b90d60b0a7055cd0d594f6225851c8357d Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 differ diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf new file mode 100644 index 0000000000000000000000000000000000000000..4ee8641d9ce591743e3d806da9a6038ed872704c --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66e7df58a5ad39030e5631943ffa5d45164b91f283a2b7b34d4265c6bbf08be4 +size 448037 diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..de1c9f551f4ad5ea78d886518e8ad52d1e0d0997 Binary files /dev/null and b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt differ diff --git a/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4938279b4e3dbb56ee812a59ca9f0c415381ee7b --- /dev/null +++ b/checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json @@ -0,0 +1,68 @@ +{ + "mode": "offline", + "update_step": 1525, + "total_gradient_steps_completed": 97600, + "wandb_run_id": "6opvce2t", + "config_snapshot": { + "ENV_NAME": "Craftax-Classic-Symbolic-v1", + "USE_OPTIMISTIC_RESETS": false, + "OPTIMISTIC_RESET_RATIO": 16, + "D_MODEL": 384, + "N_HEADS": 8, + "N_LAYERS": 6, + "D_FF": 768, + "OBS_ENCODER_LAYERS": 2, + "OBS_ENCODER_WIDTH": 768, + "DROPOUT_RATE": 0.1, + "PLAN_HORIZON": 32, + "DIFFUSION_SCHEDULE": "cosine", + "TRAIN_SIGMA": 0.0, + "LABEL_SMOOTHING": 0.0, + "DIFFUSION_STEPS": 15, + "DIFFUSION_STEPS_EVAL": 10, + "REMASK_STRATEGY": "rescale", + "ETA": 0.5, + "USE_LOOP": true, + "T_ON": 0.7, + "T_OFF": 0.3, + "TEMPERATURE": 0.5, + "TOP_P": 0.95, + "LR": 0.0003, + "MAX_GRAD_NORM": 1.0, + "LR_WARMUP_FRAMES": "1.048576e8", + "NUM_ENVS": 512, + "NUM_STEPS": 128, + "NUM_MINIBATCHES": 8, + "UPDATE_EPOCHS": 8, + "NUM_REPEATS": 1, + "OFFLINE_TOTAL_TIMESTEPS": 99942400, + "COLLECT_TEMPERATURE": 1.0, + "RETURN_WEIGHT_CAP": 5.0, + "ONLINE_TOTAL_TIMESTEPS": 100000000.0, + "DAGGER_BETA_INIT": 1.0, + "DAGGER_BETA_FINAL": 0.344, + "DAGGER_BUFFER_CYCLES": 1.90735, + "VAL_INTERVAL_FRAMES": 1000000.0, + "VAL_DIFFUSION_STEPS": 50, + "VAL_REPLAN_EVERY": 4, + "VAL_STEPS": 256, + "COLLECT_NUM_STEPS": 10000000, + "COLLECT_NUM_ENVS": 128, + "PPO_MODEL_TYPE": "ppo_rnn", + "LAYER_SIZE": 512, + "EVAL_STEPS": 10000, + "EVAL_NUM_ENVS": 32, + "SAVE_POLICY": true, + "SEED": 42, + "USE_WANDB": true, + "WANDB_PROJECT": "remdm-craftax", + "WANDB_ENTITY": "mathis-weil-university-college-london-ucl-", + "MODE": "offline", + "JIT": true, + "PPO_CHECKPOINT_PATH": "checkpoints/ppo_agents/policies/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M", + "NUM_UPDATES": 1525, + "LR_WARMUP_STEPS": 1600, + "VAL_INTERVAL": 15, + "MINIBATCH_SIZE": 6208 + } +} \ No newline at end of file diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..0986df91a3ab4b9643bdb9f4279a648481c330bf --- /dev/null +++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA @@ -0,0 +1 @@ +{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775623858059636986, "commit_timestamp_nsecs": 1775623858516125466, "custom_metadata": {}} \ No newline at end of file diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..de3d5139d201d649880d3082c460bbb0f7b57ca0 --- /dev/null +++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA @@ -0,0 +1 @@ +{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "scalar", "skip_deserialize": false}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('params', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('params', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('params', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'mu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Embed_0', 'embedding')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Embed_0", "key_type": 2}, {"key": "embedding", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [18, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_0', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_0", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_1', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_1", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_2', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_2", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_3', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_3", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_4', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_4", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 768]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [768, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_0', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_0", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'LayerNorm_1', 'scale')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "LayerNorm_1", "key_type": 2}, {"key": "scale", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'key', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "key", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'out', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "out", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48, 384]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'query', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "query", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8, 48]}}, "('opt_state', '1', '0', 'nu', 'params', 'TransformerBlock_5', 'MultiHeadDotProductAttention_0', 'value', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "TransformerBlock_5", "key_type": 2}, {"key": "MultiHeadDotProductAttention_0", "key_type": 2}, {"key": "value", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [384, 8, 48]}}, "('opt_state', '1', '1')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} \ No newline at end of file diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding new file mode 100644 index 0000000000000000000000000000000000000000..0018d0fe5602c2ae1f48d1ba9055fa7e1936ae8c --- /dev/null +++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding @@ -0,0 +1 @@ +{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} \ No newline at end of file diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 new file mode 100644 index 0000000000000000000000000000000000000000..6dc09f7d71ab2572709c42a52d44ad5b59e2d1ea --- /dev/null +++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 @@ -0,0 +1 @@ +{"array_metadatas": [{"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}]} \ No newline at end of file diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 new file mode 100644 index 0000000000000000000000000000000000000000..fae357a26ebb3db90b38a3447d1d9da8e86e4e15 Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..22b3d5b45e54da7718c063cb92310f13a16d915e Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 new file mode 100644 index 0000000000000000000000000000000000000000..804a604796c2a9823bc945bd872bcbe7cbab66a3 Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 new file mode 100644 index 0000000000000000000000000000000000000000..74c3bcb593785b1fb3634f7100051a528ef9a785 Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc new file mode 100644 index 0000000000000000000000000000000000000000..c67c899fef975fe77b32faa9470bd7f540359e9e Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 new file mode 100644 index 0000000000000000000000000000000000000000..bffb28ab5a843d5164c98e7877cadcfab401f9a4 Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 differ diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a new file mode 100644 index 0000000000000000000000000000000000000000..bf7a6fd370cd12e6dbe30df5d9edbff2e79691eb --- /dev/null +++ b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c27dc63cbdd625c2b62fac311fd37e14406b411ac848847ca4bd4e99f333419 +size 34631680 diff --git a/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..fd80a355e5475f60ba25b97606b58246ac1d3956 Binary files /dev/null and b/checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..7e036dbbbbb68b3f50bcad514a9ffea587b5f20b --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA @@ -0,0 +1 @@ +{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1773173340517772966, "commit_timestamp_nsecs": 1773173340998852009, "custom_metadata": {}} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..4e46b06ae0446645451fbf801d54c3264da0ff25 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA @@ -0,0 +1 @@ +{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding new file mode 100644 index 0000000000000000000000000000000000000000..429946c94c750adf4803fc2959c02aa1f8cf68a1 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding @@ -0,0 +1 @@ +{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 new file mode 100644 index 0000000000000000000000000000000000000000..8791b9fe04d1a3996ac3d3d5d86a46145f675956 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 @@ -0,0 +1 @@ +{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 512], "chunk_shape": [1345, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [512, 17], "chunk_shape": [512, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121 new file mode 100644 index 0000000000000000000000000000000000000000..82f4fad7fcb7fb60d275db8afe17c2b599c4127a Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/3d817e301205eacc425259e9b57de121 differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..2100c67f1592ac1d222a76a33db3d076ffb9485d Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 new file mode 100644 index 0000000000000000000000000000000000000000..923b34d85e6d1c9fa2ee825866041e90b6720a54 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25bead7246dcfff317ab909012529b50f77ce42d2b976cd3d66baf4e3013ef6d +size 31166464 diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6 new file mode 100644 index 0000000000000000000000000000000000000000..8d0bbc1b3c5cb6e70fe084effae122d5917185a5 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/553952417e25cc3b880b7c458a1b4fa6 differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91 new file mode 100644 index 0000000000000000000000000000000000000000..7d05015519f121430dced55f611012e5d6052ad7 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/592d44df588422e75968df016db43e91 differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 new file mode 100644 index 0000000000000000000000000000000000000000..5320f50b3e7cb2eaeb7867db10bd5e1f09ab96c1 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2e86dd7454e40b6286af72bd0c2c1a1df1914c39f1fef20c0a37f32accff16b +size 5246976 diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee new file mode 100644 index 0000000000000000000000000000000000000000..32aeb3b90d60b0a7055cd0d594f6225851c8357d Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/ca2a63995e2d2ca4dffb0ca4171ab0ee differ diff --git a/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..d342ef752f388a0dc1c6ddfe900594f94a6f61b0 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..1fc6ed7f6c28d1a129092315199a11b4692de31d --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA @@ -0,0 +1 @@ +{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1774217044031399688, "commit_timestamp_nsecs": 1774217044369893130, "custom_metadata": {}} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA new file mode 100644 index 0000000000000000000000000000000000000000..fdb5a1b7aa2860843779d07bdcb8229313c52d71 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA @@ -0,0 +1 @@ +{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [8268, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [43]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 43]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding new file mode 100644 index 0000000000000000000000000000000000000000..429946c94c750adf4803fc2959c02aa1f8cf68a1 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding @@ -0,0 +1 @@ +{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 new file mode 100644 index 0000000000000000000000000000000000000000..d9161e15d6b3e9297e6700b5e515328ea8a34296 --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/array_metadatas/process_0 @@ -0,0 +1 @@ +{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [8268, 512], "chunk_shape": [8268, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [43], "chunk_shape": [43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [512, 43], "chunk_shape": [512, 43], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.bias", "write_shape": [1], "chunk_shape": [1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_6.kernel", "write_shape": [512, 1], "chunk_shape": [512, 1], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hn.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hr.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.hz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.in.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.ir.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.bias", "write_shape": [512], "chunk_shape": [512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.ScannedRNN_0.GRUCell_1.iz.kernel", "write_shape": [512, 512], "chunk_shape": [512, 512], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]} \ No newline at end of file diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a new file mode 100644 index 0000000000000000000000000000000000000000..fb6ebf9f33d1d8aa4170f17f6aee74ee7814bd7b Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/d/4c4db8700ec36cc1416c034bfbe0f71a differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..02fec15038b36ed47875f7f6884eaee855ab03cd Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/manifest.ocdbt differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349 new file mode 100644 index 0000000000000000000000000000000000000000..d34c873b8c69918f68938c5d2826dacc4a096b65 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/0f5cd6bf71d63c0de8e8764ba0de8349 differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6 b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6 new file mode 100644 index 0000000000000000000000000000000000000000..a00b64eed77a32cfa481a8bed9767b8ef3754688 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3258293b820e7cc97e3af7b3d69255b6 differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f new file mode 100644 index 0000000000000000000000000000000000000000..2df53ebd75e606fb7c7b6d9d9146fa05f586ee6b Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/3566d5252a486c26c607b2637742f32f differ diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb new file mode 100644 index 0000000000000000000000000000000000000000..083f9210c8836dbcfd9fe44f005a306d268a5b6c --- /dev/null +++ b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fba36a05ef8ebbfda38b7f622137b5e9c9bd56cf6d6a534c4acb34e1897082f1 +size 51884032 diff --git a/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt new file mode 100644 index 0000000000000000000000000000000000000000..ca6db89a1a9e9a6e8b93af317e0a23ec59be1d36 Binary files /dev/null and b/checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/manifest.ocdbt differ diff --git a/configs/classic_exp_a_beta_fix.yaml b/configs/classic_exp_a_beta_fix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5fc8cc11e14e988cb387b642798cb2fc232e2cbf --- /dev/null +++ b/configs/classic_exp_a_beta_fix.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment A — Beta decay fix only (isolate data quality) +# Target hardware: NVIDIA RTX 3090 Ti (24 GB) +# Slows dagger_beta_decay so the expert stays present roughly 2x longer. +# Everything else identical to the baseline. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 256 +n_heads: 4 +n_layers: 4 +d_ff: 512 +obs_encoder_layers: 2 +obs_encoder_width: 512 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_steps: 200 + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 4096 +num_steps: 128 +num_minibatches: 16 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_num_updates: 1500 # val plateaus by step 800 +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 # beta > 0.5 at step 990, > 0.1 at step 3290 +dagger_buffer_max: 1000000 + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval: 30 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_b_beta_big_model.yaml b/configs/classic_exp_b_beta_big_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb6c0ea47cc28faf6478f4e869376855f3e77e64 --- /dev/null +++ b/configs/classic_exp_b_beta_big_model.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment B — Beta fix + bigger model (isolate model capacity) +# Target hardware: NVIDIA RTX 3090 Ti (24 GB) +# Cumulative with Experiment A: same beta decay fix, plus a 3.5x larger +# transformer (~18 GB peak VRAM). +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_steps: 200 + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 4096 +num_steps: 128 +num_minibatches: 8 # minibatch 2048, better gradients than Exp A +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_num_updates: 1500 +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 +dagger_buffer_max: 1000000 + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval: 30 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_c_full_recipe.yaml b/configs/classic_exp_c_full_recipe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6565bcf16ceccf225abc786ab6a0c2cb3adc7a35 --- /dev/null +++ b/configs/classic_exp_c_full_recipe.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment C — Full recipe (beta + big model + training dynamics) +# Target hardware: NVIDIA RTX 4090 (24 GB) +# Cumulative with Experiments A + B: also bumps diffusion_steps, lowers +# temperature, raises lr, and tightens validation to match training settings. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 # more denoising for cleaner plans +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 # sharper sampling than baseline +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 # scale with 2x effective batch +max_grad_norm: 1.0 +lr_warmup_steps: 300 # stabilize higher LR + bigger model + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 448 +num_steps: 128 +num_minibatches: 8 # minibatch 2048 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 1.0e+8 +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 +dagger_buffer_max: 1000000 + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval: 30 +val_diffusion_steps: 25 # match diffusion_steps +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_d_100K_model.yaml b/configs/classic_exp_d_100K_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ee0c12c9b2814352ece10c5c3005262a1939837 --- /dev/null +++ b/configs/classic_exp_d_100K_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.1M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 32 +n_heads: 2 +n_layers: 1 +d_ff: 64 +obs_encoder_layers: 1 +obs_encoder_width: 64 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 3072 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_d_250K_model.yaml b/configs/classic_exp_d_250K_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d96a18a41bc13bb472747769c9d2d109628ef0a --- /dev/null +++ b/configs/classic_exp_d_250K_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.25M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 64 +n_heads: 2 +n_layers: 2 +d_ff: 128 +obs_encoder_layers: 1 +obs_encoder_width: 128 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 6144 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_d_3M_model.yaml b/configs/classic_exp_d_3M_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..feb140804795139458315344007014cc5566db65 --- /dev/null +++ b/configs/classic_exp_d_3M_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Craftax Classic DAgger — UCL RTX 3090 Ti 24 GB - EXP D - 3.2M model +# Big transformer + slow beta decay. Produces the Classic DAgger checkpoint +# used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 256 +n_heads: 4 +n_layers: 4 +d_ff: 512 +obs_encoder_layers: 2 +obs_encoder_width: 512 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 1.048576e8 # = 200 update steps at this hardware (200 * 4096 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 1120 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.344 # 0.9993^1525 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 1.90735 # ~1M samples on UCL = ~1.91 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/classic_exp_d_850K_model.yaml b/configs/classic_exp_d_850K_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e1c2c49c3faccefb98985734a7c0a2e2a1459c9 --- /dev/null +++ b/configs/classic_exp_d_850K_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.85M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 128 +n_heads: 4 +n_layers: 3 +d_ff: 256 +obs_encoder_layers: 2 +obs_encoder_width: 256 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 5120 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_a_beta_fix.yaml b/configs/craftax_exp_a_beta_fix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..20b2048e3c63d54e839dae1a564d3577da058fa4 --- /dev/null +++ b/configs/craftax_exp_a_beta_fix.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment A-Full — Beta fix only on Full Craftax +# Target hardware: NVIDIA RTX 4090 (24 GB) +# Same as exp_a_beta_fix.yaml but on Craftax-Symbolic-v1 (obs_dim 8268, 43 +# actions). num_envs and dagger_buffer_max are dropped to fit the larger obs. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 256 +n_heads: 4 +n_layers: 4 +d_ff: 512 +obs_encoder_layers: 2 +obs_encoder_width: 512 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_steps: 200 + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 2048 # halved from 4096: larger obs per env +num_steps: 128 +num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_num_updates: 2000 # harder task, more room to learn +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 +dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval: 30 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_b_beta_big_model.yaml b/configs/craftax_exp_b_beta_big_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ccf979e9a4d98859034de4cd38ad1d992699cbd7 --- /dev/null +++ b/configs/craftax_exp_b_beta_big_model.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment C-Full — Full recipe on Full Craftax +# Target hardware: NVIDIA RTX 3090 Ti / 4090 (24 GB) +# Same as exp_c_full_recipe.yaml but on Craftax-Symbolic-v1 (obs_dim 8268, 43 +# actions). num_envs and dagger_buffer_max are dropped to fit the larger obs. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_steps: 300 + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 448 # halved from 4096: larger obs per env +num_steps: 128 +num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 1.0e+8 # harder task, more room to learn +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 +dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 # match diffusion_steps +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_c_full_recipe.yaml b/configs/craftax_exp_c_full_recipe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26b97abf03f74f25e025a454958af4ce74b924d5 --- /dev/null +++ b/configs/craftax_exp_c_full_recipe.yaml @@ -0,0 +1,90 @@ +# ============================================================================= +# Experiment B-Full — Beta fix + big model on Full Craftax +# Target hardware: NVIDIA RTX 3090 Ti (24 GB) +# Same as exp_b_beta_big_model.yaml but on Craftax-Symbolic-v1 (obs_dim 8268, +# 43 actions). num_envs and dagger_buffer_max are dropped to fit the larger obs. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_steps: 200 + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 2048 # halved from 4096: larger obs per env +num_steps: 128 +num_minibatches: 8 # minibatch = 2048 * 128 / 8 = 32768 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_num_updates: 2000 # harder task, more room to learn +dagger_beta_init: 1.0 +dagger_beta_decay: 0.9993 +dagger_buffer_max: 200000 # 8268 obs -> ~6.6 GB at 200k + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval: 30 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_d_1M_model.yaml b/configs/craftax_exp_d_1M_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..034e858eca3d5552b098c39632aa1ef798439a83 --- /dev/null +++ b/configs/craftax_exp_d_1M_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 1.1M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 64 +n_heads: 2 +n_layers: 2 +d_ff: 128 +obs_encoder_layers: 1 +obs_encoder_width: 128 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 1536 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_d_3M_model.yaml b/configs/craftax_exp_d_3M_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1192b9ac0578f64e101479b45926b5a329c73503 --- /dev/null +++ b/configs/craftax_exp_d_3M_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 2.6M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 128 +n_heads: 4 +n_layers: 3 +d_ff: 256 +obs_encoder_layers: 2 +obs_encoder_width: 256 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 1240 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_d_500K_model.yaml b/configs/craftax_exp_d_500K_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae5cf7146cadcd6d61d8547ec81591f561a415d2 --- /dev/null +++ b/configs/craftax_exp_d_500K_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 0.5M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 32 +n_heads: 2 +n_layers: 1 +d_ff: 64 +obs_encoder_layers: 1 +obs_encoder_width: 64 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 2976 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/craftax_exp_d_7M_model.yaml b/configs/craftax_exp_d_7M_model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..922ac05b772428036378e97996eddb1d1103b50a --- /dev/null +++ b/configs/craftax_exp_d_7M_model.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP D - 6.8M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 256 +n_heads: 4 +n_layers: 4 +d_ff: 512 +obs_encoder_layers: 2 +obs_encoder_width: 512 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 832 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/defaults.yaml b/configs/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c08682df1800528c53199c9e59c73b22a875b24a --- /dev/null +++ b/configs/defaults.yaml @@ -0,0 +1,120 @@ +# ============================================================================= +# ReMDM default configuration +# Override any value via CLI: python main.py --mode offline --lr 1e-4 +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 256 +n_heads: 4 +n_layers: 4 +d_ff: 512 +obs_encoder_layers: 2 +obs_encoder_width: 512 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" # cosine | linear +train_sigma: 0.0 # remasking correction during training (0 = standard MDLM) +label_smoothing: 0.0 # cross-entropy label smoothing (0 = exact ELBO) + +# ── Reverse sampling (training, validation, inference) ────────────────────── +diffusion_steps: 15 # denoising steps T during training +diffusion_steps_eval: 10 # denoising steps T at inference (--mode inference) +remask_strategy: "rescale" # rescale | cap | conf +eta: 0.5 # remasking strength +use_loop: true # three-phase loop (Algorithm 3) +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +# lr_warmup_frames is the env-frame-denominated source of truth (invariant +# under num_envs changes). When set, it overrides lr_warmup_steps via: +# lr_warmup_steps = lr_warmup_frames // (num_envs * num_steps) +lr_warmup_frames: null # PRIMARY: env-frame warmup budget +lr_warmup_steps: 0 # LEGACY: linear warm-up before cosine LR decay (0 = disabled) + +# ── Rollout (shared by offline + online) ───────────────────────────────────── +num_envs: 1024 +num_steps: 64 +num_minibatches: 8 +update_epochs: 4 +num_repeats: 1 # independent training seeds (vmapped) + +# ── Offline training (--mode offline) ──────────────────────────────────────── +# num_updates is derived as offline_total_timesteps // (num_envs * num_steps), +# so the env-frame budget is hardware-portable across num_envs changes. +offline_total_timesteps: 1.0e+8 # PRIMARY: env-frame budget +offline_num_updates: null # LEGACY: used only when offline_total_timesteps is unset +collect_temperature: 1.0 # softmax temperature on PPO logits during data collection +return_weight_cap: 5.0 # clip ceiling for per-window return weights + +# ── Online DAgger training (--mode online) ─────────────────────────────────── +# num_updates is derived as online_total_timesteps // (num_envs * num_steps). +online_total_timesteps: null # PRIMARY: env-frame budget +online_num_updates: 1000 # LEGACY: used only when online_total_timesteps is unset +dagger_beta_init: 1.0 # initial expert mixing probability beta_1 +# dagger_beta_final is the env-frame-invariant source of truth: the target +# mixing ratio at the end of training. When set, it overrides +# dagger_beta_decay via: decay = (beta_final / beta_init) ** (1 / num_updates) +dagger_beta_final: null # PRIMARY: target final beta +dagger_beta_decay: 0.95 # LEGACY: per-update decay beta_i = beta_init * decay^i +# dagger_buffer_cycles measures buffer capacity in *update cycles of history* +# (1 cycle = num_envs * num_steps frames), invariant under num_envs. When +# set, it overrides dagger_buffer_max via: buffer_max = cycles * fpu +dagger_buffer_cycles: null # PRIMARY: buffer capacity in update cycles +dagger_buffer_max: 100000 # LEGACY: max samples in DAgger replay buffer +# B1: passes per update over the aggregated buffer. Each pass redraws a +# fresh sample of size samples_per_update from D, so total samples seen +# per update ≈ n_passes * samples_per_update. null = 1 pass, which +# matches offline BC's per-update gradient work exactly (fair compute +# comparison: same num_updates * update_epochs * num_minibatches grad +# steps). Raise to >1 to trade BC fairness for higher per-update D +# coverage (e.g. ~⌊|D|/samples_per_update⌋ for full-buffer coverage). +dagger_train_passes: null +# B2: deterministic argmax expert (true) vs categorical sampling (false). +# True keeps the expert mapping s -> a* fixed, removing label noise from D. +dagger_expert_deterministic: true + +# ── Validation rollouts ────────────────────────────────────────────────────── +# val_interval_frames is the env-frame-invariant source of truth. When set, +# it overrides val_interval via: val_interval = val_interval_frames // fpu +val_interval_frames: null # PRIMARY: env-frames between validations +val_interval: 50 # LEGACY: update steps between validation rollouts +val_diffusion_steps: 50 # denoising steps during validation +val_replan_every: 4 # env steps executed per diffusion plan +val_steps: 128 # total env steps per validation rollout + +# ── Data collection (--mode collect) ───────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" # ppo | ppo_rnn | ppo_rnd +layer_size: 512 + +# ── Inference (--mode inference) ───────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null # accepts wandb: refs +resume_wandb_run_id: null # auto-read from checkpoint metadata if null +resume_step: null # auto-read from checkpoint metadata if null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: null # random if not set +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" +wandb_download_dir: null # null = wandb default ./artifacts/ diff --git a/configs/final_classic_qmul.yaml b/configs/final_classic_qmul.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d93464d885797ee0b1f1e2404a83d9328cb0df2 --- /dev/null +++ b/configs/final_classic_qmul.yaml @@ -0,0 +1,93 @@ +# ============================================================================= +# Final Craftax Classic DAgger — QMUL H200 8 GB partition (seed 43) - EXP B - 9M model +# Env-frame-matched second seed of final_classic_ucl.yaml. +# ----------------------------------------------------------------------------- +# Identical to final_classic_ucl.yaml except for num_envs (96 vs 4096) and +# seed (43 vs 42). All fairness-critical hyperparameters are now denominated +# in env frames or update cycles, so they are AUTOMATICALLY scaled to this +# hardware tier by resolve_scaled_hyperparams() — no manual derivation. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 1.048576e8 # IDENTICAL to UCL; resolves to 8533 update steps on 96 envs + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 96 # VRAM-tuned to ~6.5–7.5 GB; minibatch = 96*128/8 = 1536 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 1.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.344 +dagger_buffer_cycles: 1.90735 + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 43 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/final_classic_ucl.yaml b/configs/final_classic_ucl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9619cda71db09d0a11c5e1ada69b28c605afcf5f --- /dev/null +++ b/configs/final_classic_ucl.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Craftax Classic DAgger — UCL RTX 3090 Ti 24 GB - EXP B - 9M model +# Big transformer + slow beta decay. Produces the Classic DAgger checkpoint +# used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Classic-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 15 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.5 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 3.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 1.048576e8 # = 200 update steps at this hardware (200 * 4096 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 512 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 1.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 1.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.344 # 0.9993^1525 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 1.90735 # ~1M samples on UCL = ~1.91 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 50 +val_replan_every: 4 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/final_craftax_qmul.yaml b/configs/final_craftax_qmul.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b0ed8951ff8b4db08c41f7b5c78dcf8813cdc52 --- /dev/null +++ b/configs/final_craftax_qmul.yaml @@ -0,0 +1,94 @@ +# ============================================================================= +# Final Full Craftax DAgger — QMUL H200 8 GB partition - EXP C - 15M model +# Env-frame-matched second seed of final_craftax_full_ucl.yaml. +# ----------------------------------------------------------------------------- +# Identical to final_craftax_full_ucl.yaml except for num_envs (384 vs 2048) +# and seed (43 vs 42). All fairness-critical hyperparameters are now +# denominated in env frames or update cycles, so they are AUTOMATICALLY +# scaled to this hardware tier by resolve_scaled_hyperparams() — no manual +# derivation. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # IDENTICAL to UCL; resolves to 1600 update steps on 384 envs + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 64 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 # IDENTICAL to UCL +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 # IDENTICAL to UCL +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # IDENTICAL to UCL; resolves to 0.99990623 decay on 384 envs +dagger_buffer_cycles: 0.76294 # IDENTICAL to UCL; resolves to ~37500 samples on 384 envs + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Resume ─────────────────────────────────────────────────────────────────── +resume_checkpoint_path: null +resume_wandb_run_id: null +resume_step: null + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 43 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/configs/final_craftax_ucl.yaml b/configs/final_craftax_ucl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a26a872b45f1843b095e3d199ff550996f071e6 --- /dev/null +++ b/configs/final_craftax_ucl.yaml @@ -0,0 +1,84 @@ +# ============================================================================= +# Final Full Craftax DAgger — UCL RTX 4090 24 GB - EXP C - 15M model +# Big transformer + full sampling recipe + slow beta decay. Produces the Full +# Craftax DAgger checkpoint used by the RL fine-tuning ablation suite. +# ============================================================================= + +# ── Environment ────────────────────────────────────────────────────────────── +env_name: "Craftax-Symbolic-v1" +use_optimistic_resets: false +optimistic_reset_ratio: 16 + +# ── Transformer architecture ───────────────────────────────────────────────── +d_model: 384 +n_heads: 8 +n_layers: 6 +d_ff: 768 +obs_encoder_layers: 2 +obs_encoder_width: 768 +dropout_rate: 0.1 + +# ── Diffusion ──────────────────────────────────────────────────────────────── +plan_horizon: 32 +diffusion_schedule: "cosine" +train_sigma: 0.0 +label_smoothing: 0.0 + +# ── Reverse sampling ───────────────────────────────────────────────────────── +diffusion_steps: 25 +diffusion_steps_eval: 10 +remask_strategy: "rescale" +eta: 0.5 +use_loop: true +t_on: 0.7 +t_off: 0.3 +temperature: 0.3 +top_p: 0.95 + +# ── Optimisation ───────────────────────────────────────────────────────────── +lr: 5.0e-4 +max_grad_norm: 1.0 +lr_warmup_frames: 7.86432e7 # = 300 update steps at this hardware (300 * 2048 * 128) + +# ── Rollout ────────────────────────────────────────────────────────────────── +num_envs: 448 +num_steps: 128 +num_minibatches: 8 +update_epochs: 8 +num_repeats: 1 + +# ── Offline training ───────────────────────────────────────────────────────── +offline_total_timesteps: 2.0e+8 +collect_temperature: 1.0 +return_weight_cap: 5.0 + +# ── Online DAgger training ─────────────────────────────────────────────────── +online_total_timesteps: 2.0e+8 +dagger_beta_init: 1.0 +dagger_beta_final: 0.385 # 0.9995^1907 at this hardware; auto-resolves per fpu +dagger_buffer_cycles: 0.76294 # ~200K samples on UCL = ~0.76 update cycles of history + +# ── Validation rollouts ────────────────────────────────────────────────────── +val_interval_frames: 1.0e+6 +val_diffusion_steps: 25 +val_replan_every: 2 +val_steps: 256 + +# ── Data collection ────────────────────────────────────────────────────────── +collect_num_steps: 10000000 +collect_num_envs: 128 +ppo_model_type: "ppo_rnn" +layer_size: 512 + +# ── Inference ──────────────────────────────────────────────────────────────── +eval_steps: 10000 +eval_num_envs: 32 + +# ── Checkpointing ──────────────────────────────────────────────────────────── +save_policy: true + +# ── Logging ────────────────────────────────────────────────────────────────── +seed: 42 +use_wandb: true +wandb_project: "remdm-craftax" +wandb_entity: "mathis-weil-university-college-london-ucl-" diff --git a/demo_craftax.ipynb b/demo_craftax.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e0f5fdb18e02260e540480dcec1bc896b47a5622 --- /dev/null +++ b/demo_craftax.ipynb @@ -0,0 +1,803 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "02c2db887e2d", + "metadata": {}, + "source": [ + "# COMP0258 \u2014 ReMDM Discrete Diffusion Planning in Craftax\n", + "\n", + "> **Self-contained Colab notebook.** Loads pre-trained checkpoints from a public\n", + "> HuggingFace repo and demonstrates live inference, agent visualisation and the\n", + "> ReMDM denoising loop. No training is performed inside the notebook.\n", + "\n", + "**How to use this notebook**\n", + "\n", + "1. Open in Google Colab (preferably with a GPU runtime).\n", + "2. Run all cells top-to-bottom.\n", + "3. To test on **unseen inputs**, edit the constants in the *Configuration*\n", + " cell (`SEED`, `ENV_NAME`, `EVAL_STEPS`, `EVAL_NUM_ENVS`, `DIFFUSION_STEPS_EVAL`)\n", + " and re-run from there. Every Craftax seed produces a procedurally generated\n", + " world the agent has never seen.\n", + "\n", + "**Submission compliance**\n", + "\n", + "- Cell 1 downloads everything from a single public HuggingFace repo\n", + " (`HF_REPO_ID`) \u2014 no authentication required.\n", + "- The pre-trained checkpoint is loaded; no training happens here.\n", + "- Live inference (Cells 5\u20137) demonstrates that reported numbers reproduce.\n", + "- Pre-computed ablation figures (Cells 9\u201310) demonstrate the research finding.\n" + ] + }, + { + "cell_type": "code", + "id": "03d921f981da", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CONFIGURATION \u2014 marker can edit any of these and re-run from here\n", + "# =============================================================================\n", + "\n", + "# Public HuggingFace repo containing src/, Craftax_Baselines/, configs/,\n", + "# checkpoints/ and pre-computed ablation outputs.\n", + "HF_REPO_ID = \"TODO_HF_REPO_ID\"\n", + "LOCAL_DIR = \"remdm-craftax\"\n", + "\n", + "# --- Reproducibility & evaluation knobs ---\n", + "SEED = 42\n", + "\n", + "# \"Craftax-Classic-Symbolic-v1\": 22 achievements, 17 actions (DAgger checkpoint)\n", + "# \"Craftax-Symbolic-v1\": 65 achievements, 43 actions (PPO expert only)\n", + "ENV_NAME = \"Craftax-Classic-Symbolic-v1\"\n", + "\n", + "EVAL_STEPS = 2000 # Per-env step budget for live diffusion eval\n", + "EVAL_NUM_ENVS = 16 # Parallel environments (lower if Colab OOMs)\n", + "DIFFUSION_STEPS_EVAL = 10 # Reverse denoising steps at inference time\n" + ] + }, + { + "cell_type": "markdown", + "id": "81c4103fc614", + "metadata": {}, + "source": [ + "## 1. Project overview\n", + "\n", + "**Problem.** Plan action sequences in *Craftax*, a JAX-accelerated procedurally\n", + "generated open-world survival game (a Crafter reimplementation extended with\n", + "NetHack-like mechanics).\n", + "\n", + "**Approach.** A bidirectional **denoising transformer** generates a\n", + "`plan_horizon = 32` action plan by iteratively denoising masked discrete tokens\n", + "(MDLM / **ReMDM** \u2014 Wang et al.) conditioned on the current symbolic\n", + "observation. At inference, MPC executes one action per plan and re-plans every\n", + "step with **historical inpainting**: positions `0 .. hist_len - 1` are locked to\n", + "the actions actually taken.\n", + "\n", + "**Pipeline (offline, before this notebook).**\n", + "```\n", + "[1] PPO-RNN expert (Craftax_Baselines/ppo_rnn.py)\n", + "[2] Offline behaviour cloning (main.py --mode offline)\n", + "[3] Online DAgger fine-tuning (main.py --mode online)\n", + "[4] RL fine-tuning ablation suite \u2014 25 ablations\n", + "```\n", + "\n", + "**Research question.** Can RL fine-tuning improve a DAgger-pretrained masked\n", + "diffusion planner?\n", + "\n", + "**Headline finding.** DAgger faithfully imitates the PPO expert, but **no RL\n", + "ablation** (25 tested across regularisation, optimisation, data and capacity)\n", + "meaningfully improves on the DAgger checkpoint. This matches the parallel\n", + "finding from our MiniHack codebase, indicating the obstruction is\n", + "framework-independent.\n", + "\n", + "The notebook demonstrates the three things the marker needs to verify:\n", + "\n", + "| Cell | Demonstrates |\n", + "|---|---|\n", + "| 5 | Live inference numbers reproduce reported results |\n", + "| 6 | The agent meaningfully interacts with Craftax (rewards, achievements) |\n", + "| 7 | The ReMDM iterative-unmasking sampler in action |\n", + "| 8 | DAgger \u2248 PPO expert (live PPO eval) |\n", + "| 9\u201310 | Pre-computed ablation results \u2014 RL doesn't help |\n" + ] + }, + { + "cell_type": "code", + "id": "449dfd350169", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Install pinned dependencies, download the HF repo, set up sys.path\n", + "# =============================================================================\n", + "\n", + "import importlib\n", + "import os\n", + "import subprocess\n", + "import sys\n", + "\n", + "\n", + "def _pip(*pkgs: str) -> None:\n", + " subprocess.check_call(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs],\n", + " stdout=subprocess.DEVNULL,\n", + " )\n", + "\n", + "\n", + "# Core dependencies (versions match pyproject.toml).\n", + "_pip(\n", + " \"huggingface_hub>=1.9.1\",\n", + " \"craftax>=1.5.0\",\n", + " \"flax>=0.12.6\",\n", + " \"optax>=0.2.8\",\n", + " \"orbax-checkpoint>=0.5\",\n", + " \"distrax>=0.1.7\",\n", + " \"chex>=0.1.91\",\n", + " \"polars>=1.39.3\",\n", + " \"orjson>=3.11.8\",\n", + " \"pyyaml>=6.0\",\n", + ")\n", + "\n", + "# JAX: keep Colab's preinstalled GPU build if present, otherwise install CPU.\n", + "try:\n", + " import jax # noqa: F401\n", + " if jax.default_backend() == \"gpu\":\n", + " pass\n", + " else:\n", + " raise RuntimeError(\"re-install needed\")\n", + "except Exception:\n", + " _pip(\"jax>=0.9.2\")\n", + " importlib.invalidate_caches()\n", + " import jax # noqa: F401\n", + "\n", + "import craftax # noqa: F401\n", + "\n", + "backend = jax.default_backend()\n", + "device = jax.devices()[0]\n", + "print(f\"JAX {jax.__version__} | backend={backend} | device={device}\")\n", + "print(f\"Craftax {craftax.__version__}\")\n", + "if backend != \"gpu\":\n", + " print(\n", + " \"WARNING: JAX is running on CPU. Inference will be ~10x slower than GPU.\"\n", + " \"\\n In Colab, switch to a GPU runtime via Runtime -> Change runtime type.\"\n", + " )\n", + "\n", + "# ----------------------------------------------------------------------------\n", + "# Pull the project tree (code, configs, checkpoints, pre-computed results)\n", + "# from a single public HuggingFace repo. No authentication required.\n", + "# ----------------------------------------------------------------------------\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "snapshot_path = snapshot_download(repo_id=HF_REPO_ID, local_dir=LOCAL_DIR)\n", + "print(f\"Snapshot downloaded to: {snapshot_path}\")\n", + "\n", + "# Both the parent (for `import src.planners...`) and the Craftax_Baselines/\n", + "# subdir (for `from wrappers import ...` inside Craftax_Baselines/ppo_rnn.py)\n", + "# must be on sys.path. main.py does the same thing.\n", + "for p in (snapshot_path, os.path.join(snapshot_path, \"Craftax_Baselines\")):\n", + " if p not in sys.path:\n", + " sys.path.insert(0, p)\n", + "\n", + "os.chdir(snapshot_path)\n", + "print(f\"cwd: {os.getcwd()}\")\n" + ] + }, + { + "cell_type": "code", + "id": "e71fb8bee8cc", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Load the pre-trained DAgger diffusion checkpoint\n", + "# =============================================================================\n", + "\n", + "import json\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from craftax.craftax_env import make_craftax_env_from_name\n", + "\n", + "from src.planners.model import build_model, load_checkpoint, make_apply_fns\n", + "\n", + "# Checkpoint inventory bundled with the HF repo.\n", + "DIFFUSION_OFFLINE_CKPT = (\n", + " \"checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M\"\n", + ")\n", + "DIFFUSION_ONLINE_CKPT = (\n", + " \"checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M\"\n", + ")\n", + "PPO_CKPT = {\n", + " \"Craftax-Classic-Symbolic-v1\": (\n", + " \"checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M\"\n", + " ),\n", + " \"Craftax-Symbolic-v1\": (\n", + " \"checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M\"\n", + " ),\n", + "}\n", + "\n", + "# We pin to the DAgger (online) checkpoint \u2014 that is the headline result.\n", + "# The architecture hyperparameters are stored alongside the offline checkpoint\n", + "# in `resume_metadata.json` (online and offline share the same architecture).\n", + "with open(os.path.join(DIFFUSION_OFFLINE_CKPT, \"resume_metadata.json\")) as f:\n", + " META = json.load(f)\n", + "ARCH_CFG = META[\"config_snapshot\"]\n", + "\n", + "# The diffusion checkpoints are trained on Classic Craftax. The denoiser's\n", + "# action head dimension is fixed at training time (num_actions = 17 for Classic),\n", + "# so we cannot transfer it to Full Craftax (43 actions). PPO covers both.\n", + "DIFFUSION_ENV_NAME = \"Craftax-Classic-Symbolic-v1\"\n", + "\n", + "env_init = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n", + "env_params_init = env_init.default_params\n", + "NUM_ACTIONS = int(env_init.action_space(env_params_init).n)\n", + "OBS_DIM = int(env_init.observation_space(env_params_init).shape[0])\n", + "PLAN_HORIZON = int(ARCH_CFG[\"PLAN_HORIZON\"])\n", + "\n", + "print(f\"DiffusionEnv : {DIFFUSION_ENV_NAME}\")\n", + "print(f\" obs_dim : {OBS_DIM}\")\n", + "print(f\" num_actions: {NUM_ACTIONS}\")\n", + "print(f\"Architecture : d_model={ARCH_CFG['D_MODEL']} n_layers={ARCH_CFG['N_LAYERS']} \"\n", + " f\"n_heads={ARCH_CFG['N_HEADS']} plan_horizon={PLAN_HORIZON}\")\n", + "\n", + "model = build_model(ARCH_CFG, NUM_ACTIONS)\n", + "apply_eval, _ = make_apply_fns(model)\n", + "diffusion_params = load_checkpoint(\n", + " model,\n", + " jax.random.PRNGKey(SEED),\n", + " OBS_DIM,\n", + " PLAN_HORIZON,\n", + " DIFFUSION_ONLINE_CKPT,\n", + ")\n", + "\n", + "n_params = sum(int(p.size) for p in jax.tree.leaves(diffusion_params))\n", + "print(f\"Loaded DAgger checkpoint: {n_params / 1e6:.2f}M parameters\")\n" + ] + }, + { + "cell_type": "code", + "id": "e27e94345edf", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 5 (PRIORITY) \u2014 live inference on `EVAL_NUM_ENVS` parallel envs\n", + "# =============================================================================\n", + "#\n", + "# This calls the *exact* production inference code path used by\n", + "# `python main.py --mode inference`. The marker can change ENV_NAME / SEED /\n", + "# EVAL_STEPS in Cell 1 to test on entirely fresh procedurally generated worlds.\n", + "#\n", + "# Reported numbers (DAgger online checkpoint, EVAL_STEPS=10 000, 32 envs,\n", + "# Craftax-Classic-Symbolic-v1):\n", + "# mean episode return \u2248 10.4 (compare with PPO expert \u2248 10.4)\n", + "# score ceiling: 22 (one point per achievement)\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "from src.planners.inference import run_inference\n", + "\n", + "if \"Classic\" not in ENV_NAME:\n", + " print(\n", + " f\"ENV_NAME={ENV_NAME!r}: no diffusion checkpoint exists for Full Craftax\\n\"\n", + " f\"(action vocabularies differ: 17 vs 43). Skipping diffusion eval \u2014 see\\n\"\n", + " f\"Cell 8 for the matching PPO expert evaluation on this environment.\"\n", + " )\n", + "else:\n", + " inference_cfg = {\n", + " # Architecture (read from checkpoint metadata)\n", + " **{k: v for k, v in ARCH_CFG.items() if k.isupper()},\n", + " # Marker-controlled\n", + " \"ENV_NAME\": ENV_NAME,\n", + " \"SEED\": SEED,\n", + " \"EVAL_STEPS\": EVAL_STEPS,\n", + " \"EVAL_NUM_ENVS\": EVAL_NUM_ENVS,\n", + " \"DIFFUSION_STEPS_EVAL\": DIFFUSION_STEPS_EVAL,\n", + " # Always disable W&B inside the notebook\n", + " \"USE_WANDB\": False,\n", + " \"CHECKPOINT_PATH\": DIFFUSION_ONLINE_CKPT,\n", + " }\n", + " run_inference(inference_cfg)\n" + ] + }, + { + "cell_type": "code", + "id": "fa0d0926061a", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 6 (PRIORITY) \u2014 visualise the diffusion planner acting in Craftax\n", + "# =============================================================================\n", + "#\n", + "# Re-runs the MPC + historical-inpainting loop on a small batch of envs and\n", + "# captures per-step rewards / achievements / actions for plotting.\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from src.diffusion.sampling import sample_plan_inpainting\n", + "\n", + "VIZ_NUM_ENVS = 4\n", + "VIZ_STEPS = 600\n", + "\n", + "if \"Classic\" not in ENV_NAME:\n", + " print(f\"Skipping behaviour viz: no diffusion checkpoint for {ENV_NAME!r}.\")\n", + "else:\n", + " viz_env = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n", + " viz_params = viz_env.default_params\n", + "\n", + " rng = jax.random.PRNGKey(SEED + 1)\n", + " rng, env_rng = jax.random.split(rng)\n", + " obs0, state0 = jax.vmap(viz_env.reset, in_axes=(0, None))(\n", + " jax.random.split(env_rng, VIZ_NUM_ENVS), viz_params,\n", + " )\n", + " history0 = jnp.full((VIZ_NUM_ENVS, PLAN_HORIZON), NUM_ACTIONS, dtype=jnp.int32)\n", + " hist_len0 = jnp.zeros(VIZ_NUM_ENVS, dtype=jnp.int32)\n", + " env_indices = jnp.arange(VIZ_NUM_ENVS)\n", + "\n", + " @jax.jit\n", + " def viz_step(carry, _):\n", + " obs, state, rng, history, hist_len = carry\n", + " rng, plan_rng, env_rng = jax.random.split(rng, 3)\n", + "\n", + " # Reset history when plan window is exhausted (matches inference.py).\n", + " seq_full = hist_len >= PLAN_HORIZON\n", + " hist_len = jnp.where(seq_full, 0, hist_len)\n", + " history = jnp.where(seq_full[:, None], NUM_ACTIONS, history)\n", + "\n", + " plan = sample_plan_inpainting(\n", + " apply_eval, diffusion_params, plan_rng, obs,\n", + " history, hist_len, NUM_ACTIONS, PLAN_HORIZON,\n", + " DIFFUSION_STEPS_EVAL,\n", + " ARCH_CFG[\"TEMPERATURE\"], ARCH_CFG[\"TOP_P\"],\n", + " )\n", + " action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1)\n", + " history = history.at[env_indices, hist_len].set(action)\n", + " hist_len = hist_len + 1\n", + "\n", + " obs_next, state_next, reward, done, _ = jax.vmap(\n", + " viz_env.step, in_axes=(0, 0, 0, None),\n", + " )(jax.random.split(env_rng, VIZ_NUM_ENVS), state, action, viz_params)\n", + "\n", + " hist_len = jnp.where(done, 0, hist_len)\n", + " history = jnp.where(done[:, None], NUM_ACTIONS, history)\n", + " return (\n", + " (obs_next, state_next, rng, history, hist_len),\n", + " (action, reward, done, state_next.achievements),\n", + " )\n", + "\n", + " print(f\"Rolling out {VIZ_NUM_ENVS} agents for {VIZ_STEPS} steps...\")\n", + " _, (acts, rews, dones, achs) = jax.lax.scan(\n", + " viz_step, (obs0, state0, rng, history0, hist_len0), jnp.arange(VIZ_STEPS),\n", + " )\n", + " acts_np = np.array(acts) # [T, E]\n", + " rews_np = np.array(rews) # [T, E]\n", + " achs_np = np.array(achs) # [T, E, num_ach]\n", + " dones_np = np.array(dones) # [T, E]\n", + "\n", + " # First-life cumulative reward + unlock count per agent\n", + " fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", + " cum_reward = np.cumsum(rews_np, axis=0)\n", + " unlock_count = achs_np.sum(axis=-1)\n", + " for e in range(VIZ_NUM_ENVS):\n", + " end = np.where(dones_np[:, e])[0]\n", + " cutoff = int(end[0]) + 1 if len(end) > 0 else VIZ_STEPS\n", + " axes[0].plot(np.arange(cutoff), cum_reward[:cutoff, e], label=f\"agent {e}\")\n", + " axes[1].plot(np.arange(cutoff), unlock_count[:cutoff, e], label=f\"agent {e}\")\n", + " axes[0].set_title(\"Cumulative reward (first life)\")\n", + " axes[0].set_xlabel(\"env step\"); axes[0].set_ylabel(\"cumulative reward\")\n", + " axes[1].set_title(\"Achievements unlocked\")\n", + " axes[1].set_xlabel(\"env step\"); axes[1].set_ylabel(\"# unique achievements\")\n", + " for ax in axes:\n", + " ax.legend(loc=\"best\", fontsize=8); ax.grid(alpha=0.3)\n", + " plt.tight_layout(); plt.show()\n", + "\n", + " # Action histogram (which actions does the planner actually use?)\n", + " from craftax.craftax_classic.constants import Action as ClassicAction\n", + " action_names = [a.name for a in ClassicAction]\n", + " counts = np.bincount(acts_np.flatten(), minlength=NUM_ACTIONS)\n", + " order = np.argsort(-counts)\n", + " fig, ax = plt.subplots(figsize=(11, 3.5))\n", + " ax.bar(range(NUM_ACTIONS), counts[order])\n", + " ax.set_xticks(range(NUM_ACTIONS))\n", + " ax.set_xticklabels([action_names[i] for i in order], rotation=70, ha=\"right\", fontsize=8)\n", + " ax.set_title(f\"Action usage over {VIZ_STEPS * VIZ_NUM_ENVS} env steps\")\n", + " ax.set_ylabel(\"count\"); ax.grid(alpha=0.3, axis=\"y\")\n", + " plt.tight_layout(); plt.show()\n" + ] + }, + { + "cell_type": "code", + "id": "522c414d6a5a", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 7 (PRIORITY) \u2014 visualise the ReMDM iterative unmasking process\n", + "# =============================================================================\n", + "#\n", + "# `sample_plan_inpainting` is JIT-compiled via lax.scan, so to capture the\n", + "# intermediate token sequences we re-implement its body as a Python loop. This\n", + "# is faithful to the production sampler \u2014 only the loop construct differs.\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.colors import ListedColormap\n", + "\n", + "if \"Classic\" not in ENV_NAME:\n", + " print(f\"Skipping denoising viz: no diffusion checkpoint for {ENV_NAME!r}.\")\n", + "else:\n", + " DENOISE_STEPS = 12 # show enough steps to see iterative unmasking\n", + " MASK_ID = NUM_ACTIONS\n", + "\n", + " # Sample one observation from the env to condition on.\n", + " viz_env = make_craftax_env_from_name(DIFFUSION_ENV_NAME, auto_reset=True)\n", + " viz_params = viz_env.default_params\n", + " one_rng = jax.random.PRNGKey(SEED + 2)\n", + " obs1, _state1 = viz_env.reset(one_rng, viz_params)\n", + " obs_b = obs1[None, :] # [1, obs_dim]\n", + "\n", + " seq = jnp.full((1, PLAN_HORIZON), MASK_ID, dtype=jnp.int32)\n", + " history = jnp.full((1, PLAN_HORIZON), MASK_ID, dtype=jnp.int32)\n", + " hist_len = jnp.zeros((1,), dtype=jnp.int32)\n", + " rng = jax.random.PRNGKey(SEED + 3)\n", + " temperature = float(ARCH_CFG[\"TEMPERATURE\"])\n", + " top_p = float(ARCH_CFG[\"TOP_P\"])\n", + "\n", + " trace = [np.array(seq[0])] # row 0 = fully masked\n", + "\n", + " # Mirror the body of src.diffusion.sampling.sample_plan_inpainting._step\n", + " for step in range(1, DENOISE_STEPS + 1):\n", + " rng, model_rng, sample_rng, remask_rng = jax.random.split(rng, 4)\n", + " ratio = step / DENOISE_STEPS\n", + " t_tensor = jnp.full((1,), 1.0 - ratio)\n", + " logits = apply_eval(diffusion_params, obs_b, seq, t_tensor, model_rng) / max(temperature, 1e-8)\n", + "\n", + " # Nucleus filtering\n", + " probs = jax.nn.softmax(logits, axis=-1)\n", + " sorted_idx = jnp.argsort(-probs, axis=-1)\n", + " sorted_p = jnp.take_along_axis(probs, sorted_idx, axis=-1)\n", + " cutoff = jnp.cumsum(sorted_p, axis=-1) - sorted_p\n", + " inv_idx = jnp.argsort(sorted_idx, axis=-1)\n", + " nucleus_mask = jnp.take_along_axis(cutoff >= top_p, inv_idx, axis=-1)\n", + " logits = jnp.where(nucleus_mask, -jnp.inf, logits)\n", + "\n", + " preds = jax.random.categorical(sample_rng, logits, axis=-1)\n", + " conf = jnp.take_along_axis(\n", + " jax.nn.softmax(logits, axis=-1), preds[..., None], axis=-1,\n", + " ).squeeze(-1)\n", + " num_unmask = max(1, int(PLAN_HORIZON * ratio))\n", + " sorted_conf = jnp.sort(conf, axis=-1)[..., ::-1]\n", + " thresh = sorted_conf[0, num_unmask - 1]\n", + " seq_new = jnp.where(conf < thresh, MASK_ID, preds)\n", + "\n", + " # ReMDM-style remasking (matches sample_plan_inpainting)\n", + " remask_prob = 0.15 * (1.0 - ratio)\n", + " do_remask = (\n", + " (jax.random.uniform(remask_rng, seq_new.shape) < remask_prob)\n", + " & (seq_new != MASK_ID)\n", + " )\n", + " seq_new = jnp.where(do_remask, MASK_ID, seq_new)\n", + "\n", + " # Lock historical prefix (here: empty, so no-op)\n", + " pos = jnp.broadcast_to(jnp.arange(PLAN_HORIZON)[None, :], (1, PLAN_HORIZON))\n", + " seq_new = jnp.where(pos < hist_len[:, None], history, seq_new)\n", + "\n", + " seq = seq_new\n", + " trace.append(np.array(seq[0]))\n", + "\n", + " trace = np.stack(trace) # [steps+1, plan_horizon]\n", + "\n", + " # Heatmap: rows = denoising step, columns = action position\n", + " # Mask cells coloured grey, action cells coloured by token id (viridis).\n", + " masked = trace == MASK_ID\n", + " fig, ax = plt.subplots(figsize=(12, 5))\n", + " cmap = plt.cm.viridis.copy()\n", + " display = np.where(masked, np.nan, trace.astype(float))\n", + " im = ax.imshow(\n", + " display, aspect=\"auto\", cmap=cmap, vmin=0, vmax=NUM_ACTIONS - 1,\n", + " interpolation=\"nearest\",\n", + " )\n", + " # Overlay grey for masked cells\n", + " ax.imshow(\n", + " np.where(masked, 1.0, np.nan), aspect=\"auto\",\n", + " cmap=ListedColormap([\"#dddddd\"]), vmin=0, vmax=1, interpolation=\"nearest\",\n", + " )\n", + " ax.set_xlabel(\"plan position (action token)\")\n", + " ax.set_ylabel(\"denoising step\")\n", + " ax.set_title(\n", + " f\"ReMDM iterative unmasking ({DENOISE_STEPS} steps, plan_horizon={PLAN_HORIZON})\"\n", + " \"\\nGrey = MASK token, colour = sampled action ID\"\n", + " )\n", + " cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02)\n", + " cbar.set_label(\"action ID\")\n", + " plt.tight_layout(); plt.show()\n", + "\n", + " n_masked = masked.sum(axis=1)\n", + " print(f\"Masked tokens per step: {list(n_masked)}\")\n", + " print(f\" step 0 (start): {int(n_masked[0])}/{PLAN_HORIZON} masked\")\n", + " print(f\" step {DENOISE_STEPS} (final): {int(n_masked[-1])}/{PLAN_HORIZON} masked\")\n" + ] + }, + { + "cell_type": "code", + "id": "8bc67b9b704a", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 8 \u2014 PPO expert baseline (loaded live, evaluated on `ENV_NAME`)\n", + "# =============================================================================\n", + "#\n", + "# DAgger trains the diffusion planner to imitate this expert. We expect the\n", + "# DAgger return (Cell 5) to be close to the PPO expert return on Classic.\n", + "# For Full Craftax (where no diffusion checkpoint exists), this is the only\n", + "# evaluation that can run.\n", + "# -----------------------------------------------------------------------------\n", + "\n", + "import os\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import orbax.checkpoint as ocp\n", + "from orbax.checkpoint import checkpoint_utils\n", + "from craftax.craftax_env import make_craftax_env_from_name\n", + "\n", + "from src.planners.ppo import build_ppo_network, PPOAgent\n", + "\n", + "PPO_NUM_ENVS = EVAL_NUM_ENVS\n", + "PPO_EVAL_STEPS = min(EVAL_STEPS, 1500) # PPO eval is fast \u2014 bound it for snappy demo\n", + "\n", + "ppo_path = os.path.abspath(PPO_CKPT[ENV_NAME])\n", + "ppo_env = make_craftax_env_from_name(ENV_NAME, auto_reset=True)\n", + "ppo_params_env = ppo_env.default_params\n", + "ppo_num_actions = int(ppo_env.action_space(ppo_params_env).n)\n", + "ppo_obs_dim = int(ppo_env.observation_space(ppo_params_env).shape[0])\n", + "\n", + "# Build the PPO-RNN network and an abstract param pytree from a dummy init.\n", + "PPO_LAYER_SIZE = 512\n", + "ppo_net = build_ppo_network(\"ppo_rnn\", ppo_num_actions, PPO_LAYER_SIZE,\n", + " {\"LAYER_SIZE\": PPO_LAYER_SIZE})\n", + "_dummy_x = (jnp.zeros((1, PPO_NUM_ENVS, ppo_obs_dim)),\n", + " jnp.zeros((1, PPO_NUM_ENVS)))\n", + "_abstract_params = ppo_net.init(\n", + " jax.random.PRNGKey(0),\n", + " jnp.zeros((PPO_NUM_ENVS, PPO_LAYER_SIZE)),\n", + " _dummy_x,\n", + ")\n", + "\n", + "# Restore params only (the on-disk checkpoint also contains opt_state which we\n", + "# don't need for inference). `partial_restore=True` lets us read just `params`,\n", + "# and `construct_restore_args` is required on orbax >= 0.11 to provide sharding.\n", + "_restore_args = checkpoint_utils.construct_restore_args({\"params\": _abstract_params})\n", + "with ocp.CheckpointManager(ppo_path) as _mgr:\n", + " _step = _mgr.latest_step()\n", + " _restored = _mgr.restore(\n", + " _step,\n", + " args=ocp.args.PyTreeRestore(\n", + " item={\"params\": _abstract_params},\n", + " restore_args=_restore_args,\n", + " partial_restore=True,\n", + " ),\n", + " )\n", + "print(f\"Loaded PPO_RNN checkpoint from '{ppo_path}' (step {_step})\")\n", + "\n", + "ppo_agent = PPOAgent(\n", + " network=ppo_net,\n", + " params=_restored[\"params\"],\n", + " model_type=\"ppo_rnn\",\n", + " layer_size=PPO_LAYER_SIZE,\n", + ")\n", + "\n", + "rng = jax.random.PRNGKey(SEED + 100)\n", + "rng, env_rng = jax.random.split(rng)\n", + "obs, state = jax.vmap(ppo_env.reset, in_axes=(0, None))(\n", + " jax.random.split(env_rng, PPO_NUM_ENVS), ppo_params_env,\n", + ")\n", + "hidden0 = ppo_agent.init_hidden(PPO_NUM_ENVS)\n", + "done0 = jnp.zeros(PPO_NUM_ENVS, dtype=bool)\n", + "\n", + "\n", + "@jax.jit\n", + "def ppo_step(carry, _):\n", + " obs, state, hidden, done, rng = carry\n", + " rng, act_rng, env_rng = jax.random.split(rng, 3)\n", + " action, hidden = ppo_agent.act(obs, done, hidden, act_rng, temperature=1.0)\n", + " obs_next, state_next, reward, done_next, _ = jax.vmap(\n", + " ppo_env.step, in_axes=(0, 0, 0, None),\n", + " )(jax.random.split(env_rng, PPO_NUM_ENVS), state, action, ppo_params_env)\n", + " return (obs_next, state_next, hidden, done_next, rng), (reward, done_next, state_next.achievements)\n", + "\n", + "\n", + "print(f\"Running PPO expert: {PPO_NUM_ENVS} envs x {PPO_EVAL_STEPS} steps on {ENV_NAME}...\")\n", + "_, (rewards, dones, achievements) = jax.lax.scan(\n", + " ppo_step, (obs, state, hidden0, done0, rng), jnp.arange(PPO_EVAL_STEPS),\n", + ")\n", + "\n", + "rewards_np = np.array(rewards)\n", + "dones_np = np.array(dones)\n", + "ach_np = np.array(achievements)\n", + "\n", + "# First-life evaluation (matches src/planners/inference.py convention)\n", + "ep_returns = np.zeros(PPO_NUM_ENVS)\n", + "ep_unlocks = np.zeros(PPO_NUM_ENVS, dtype=int)\n", + "for i in range(PPO_NUM_ENVS):\n", + " deaths = np.where(dones_np[:, i])[0]\n", + " end = int(deaths[0]) if len(deaths) > 0 else PPO_EVAL_STEPS - 1\n", + " ep_returns[i] = rewards_np[: end + 1, i].sum()\n", + " ep_unlocks[i] = int(ach_np[: end + 1, i].max(axis=0).sum())\n", + "\n", + "print()\n", + "print(f\"PPO expert mean return : {ep_returns.mean():.2f} (best={ep_returns.max():.2f})\")\n", + "print(f\"PPO expert mean unlocks: {ep_unlocks.mean():.2f} achievements\")\n", + "if \"Classic\" in ENV_NAME:\n", + " print()\n", + " print(\n", + " \"Compare with the DAgger diffusion result printed in Cell 5. On the\\n\"\n", + " \"DAgger checkpoint with EVAL_STEPS=10000 / 32 envs, both methods score\\n\"\n", + " \"\u2248 10.4 / 22 in the paper \u2014 i.e. DAgger has fully imitated the expert.\"\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "059f96f8e951", + "metadata": {}, + "source": [ + "## 2. RL fine-tuning ablation suite (pre-computed)\n", + "\n", + "Once DAgger had reached the expert ceiling, we ran a **25-ablation suite** of\n", + "RL fine-tuning interventions trying to push past it. Every figure and table\n", + "below was generated offline by `experiments/rl_finetuning/run_ablations.py`\n", + "and shipped in the HF repo at\n", + "`experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/`.\n", + "\n", + "The four groups tested in the ablation suite:\n", + "\n", + "| Group | Hypothesis tested |\n", + "|---|---|\n", + "| **A** Regularisation | KL penalty, EWC, LLRD, LoRA, mixed replay, hard trust region |\n", + "| **B** Optimisation | t-curriculum, entropy bonus, PCGrad, advantage clip, normalised adv, BC-on-wins, low-t |\n", + "| **C** Capacity | head-only / FFN-only / attention-only / frozen backbone / top-k layer ablations |\n", + "| **D** Data | reward filtering, action diversity, running stats, learned reward model |\n", + "\n", + "**Headline finding.** No ablation meaningfully improves on the DAgger\n", + "checkpoint (best \u0394-score over baseline RL is +0.18 \u2014 within noise). Multiple\n", + "ablations *collapse* (\u0394 \u2248 \u221210) \u2014 see the figures below.\n" + ] + }, + { + "cell_type": "code", + "id": "522024988f95", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 10 \u2014 Pre-computed ablation figures\n", + "# =============================================================================\n", + "\n", + "import os\n", + "from IPython.display import Image, display, Markdown\n", + "\n", + "ABLATION_DIR = \"experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures\"\n", + "\n", + "KEY_FIGURES = [\n", + " (\n", + " \"group_comparison.png\",\n", + " \"**Group comparison** \u2014 final score by ablation group. Group A \"\n", + " \"(regularisation) clusters tightly around the baseline; Group C \"\n", + " \"(capacity ablation) collapses.\",\n", + " ),\n", + " (\n", + " \"score_delta_over_baseline_rl.png\",\n", + " \"**\u0394-score over baseline RL** \u2014 every ablation, sorted. The best \"\n", + " \"improvement is < +0.2; many collapse to \u221210.\",\n", + " ),\n", + " (\n", + " \"gradient_alignment.png\",\n", + " \"**Gradient alignment** \u2014 cosine similarity between the BC and RL \"\n", + " \"loss gradients during fine-tuning. Frequently negative, hinting at \"\n", + " \"the underlying conflict.\",\n", + " ),\n", + " (\n", + " \"gradient_conflict_map.png\",\n", + " \"**Per-layer gradient conflict** \u2014 where in the network the BC and \"\n", + " \"RL losses pull in opposite directions.\",\n", + " ),\n", + "]\n", + "\n", + "for fname, caption in KEY_FIGURES:\n", + " path = os.path.join(ABLATION_DIR, fname)\n", + " if os.path.exists(path):\n", + " display(Markdown(caption))\n", + " display(Image(filename=path))\n", + " else:\n", + " display(Markdown(f\"_Missing figure: `{fname}`_\"))\n" + ] + }, + { + "cell_type": "code", + "id": "a2a77a2dd8b1", + "metadata": {}, + "execution_count": null, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# CELL 11 \u2014 Ablation results tables\n", + "# =============================================================================\n", + "\n", + "import os\n", + "import polars as pl\n", + "\n", + "TABLES_DIR = \"experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables\"\n", + "\n", + "main_df = pl.read_csv(os.path.join(TABLES_DIR, \"main_results.csv\")).sort(\n", + " \"Final_Score\", descending=True,\n", + ")\n", + "print(\"Main ablation results (sorted by Final_Score, ceiling = 22):\")\n", + "print(main_df)\n", + "\n", + "verdict_df = pl.read_csv(os.path.join(TABLES_DIR, \"hypothesis_verdict.csv\"))\n", + "print()\n", + "print(\"Hypothesis verdicts:\")\n", + "print(verdict_df.select([\"Ablation\", \"Group\", \"Result\", \"Conclusion\"]))\n" + ] + }, + { + "cell_type": "markdown", + "id": "d2dbe7e3e2ab", + "metadata": {}, + "source": [ + "## 3. Conclusions\n", + "\n", + "1. **DAgger works.** The online DAgger checkpoint imitates the PPO expert on\n", + " Craftax-Classic to within noise (Cells 5 vs 8).\n", + "2. **RL fine-tuning does not help.** Across 25 ablations spanning\n", + " regularisation, optimisation, capacity and data interventions, no method\n", + " meaningfully improves on the DAgger checkpoint. The best \u0394-score over\n", + " baseline RL is +0.18, well within seed-to-seed variance.\n", + "3. **Several ablations actively collapse** (Group C: capacity restrictions,\n", + " plus a handful of regularisers). The collapse pattern correlates with deep\n", + " gradient flow into the backbone \u2014 see `gradient_conflict_map.png`.\n", + "4. **Framework independence.** The same finding holds in our parallel MiniHack\n", + " PyTorch codebase, suggesting the obstruction is fundamental to RL\n", + " fine-tuning of masked discrete diffusion planners rather than a Craftax-\n", + " or JAX-specific quirk.\n", + "\n", + "The full project documentation lives in the HF repo's `README.md`.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + }, + "colab": { + "provenance": [], + "toc_visible": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png new file mode 100644 index 0000000000000000000000000000000000000000..ae25bc6f337351b600ecda1e1a5b14ae84136294 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fd0272fca80a41f963407732ae253f29dfc5854ff9b5b523c211b8a8ac66331 +size 237515 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..d85fc59485d8d3203b39382031135bfef0107aef --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6eb2e3b57cef0f52ca8fa9d11dc5e519ecde4384b90d905b4718e1d05eadd45 +size 135933 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..5eecfdbbc611b0e0a968c706cbf343962d47940f --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0553ad5e2ed113a0600cf73553275fdf01638c2dd0b6cd665c57a1f2581e601c +size 136022 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..ef13baa8ae189aa3695de18d3fc24f53b4237ba9 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee83c2fe557294320c78ee8bd8d92933122f995dfce1f46d34f498475ace334a +size 133347 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..38d08b7f0af5b65d493be0b07e04f2ae097a307f --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:010af9f97e110e07ed29784d80969f2597a9f6081db620f98184417c747fcd85 +size 134511 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..54ce17a38510fe2603fcc78dfc2be1efb6b0e832 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cd85f3e810b8cdec773f371386c4397fbe634b44bb63f8ad16443cad14b3f0d +size 134632 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..d188974a2c9c3f0e2ec14514d6bf0157ab1d4bbc --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba4b34ba6d7ccb2c2583e6f35ead7e466e62e4bffe9e91dc8867111fc67dfc9e +size 135210 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..1dea7616447eb3e181068bcc926ae0f5be56b5e7 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:045be833a526d9877e900bc3de3485d00fc61584be64714c4b622a068c2ad98a +size 133720 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..0e8c71e560be90d75b5ef267da6079e2bde0c26c --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b46b767bf0d0bd138521dc1ee3ee06dfacab3ed2598249cecd02b400fc794a2 +size 134017 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..f4b9e1aa61e878b91f6b8549ec38d08cb136bf20 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0037a7256ba08b31ca81611a36e03f7db615f407268a5c9bda66206d0790fc4c +size 133997 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..f01365700450f3bcc1506d5eb398be8ba836840e --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95226a3364a619c349d37befd6d610bac5cfd6c98a921b3c6a18c254f4062c62 +size 136173 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..6e4d81bae783d01720d6c7b612215b4f878db6ac --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21166bb529dcc724b805f446c624b46cb70791e3431e4df3a04dd34ff465e93f +size 132643 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..177ab427c9e63a8f1596ba9607134a40cb15f4af --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4217cf6a98602c5ba3fb7f9f468d59a931f3943f861796569cfa8561e0e59f0e +size 134951 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..14e1eea3b3a6488ac6f243f517b45012094583c3 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65bf5071ea5c277181cfc6f9c6494ceda57ed254a1f5745ccded973efe560a40 +size 136013 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..af5d3ecb91d23766c51ab17c067f10cdda253f66 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:228cbd1d905792fbbd2f44f9c8d02a7f055a3f8e87fee5b98663e9e2a4c7064c +size 136454 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..fc7c5c586be53a47dca999c646cee7419a292bcb --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34e2cf8ba9dcddb86bc358846d268a4ae32918c67afb9b377f5b8dd04a6a97c9 +size 136433 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..f599187d2fe3a4cdf58744abdf4c884b17cce5be --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:607afe4b1edd9232abb71f5ac2b13ccf755dfda876fa079bcf38a416e22a700d +size 133481 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..1942fc5c35db229d2ff22780b0abca18048ac957 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f393911fcbebdf0975cf856f5f89bfe29a910aa8810b5b0b05dee6bd36b61402 +size 131535 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..f9568e099c123d55278f1e5314fe2ef24b7f1061 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b12fecc1bb6a52be0eead06d376d03181a2d28d0948bf4e0e4d13d4b2e9b29f6 +size 134018 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..f6235b007ada5c9f9069eea0e00181b2755fbf8f --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ea22bc4ba454bf3f4afc12bd2ae2ff24a1f6c282cbdfe7ff101e2709cd04d5a +size 135081 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..295022834e613ef957c6dc99bf5e21e4728b475d --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23ab90a189e70f55f1326cc7063994fbe79a32bee0be5b5639cf53df529848eb +size 135355 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..f058eb12d16734126f207a8567d1b1770e6a0ac6 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40c1b9b3e09a33b9ed9b263c4c41618dca9eebe30ec932da9e4ec7b1ecca130e +size 135850 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..0127bb0678d56c85c721d01808c79bbe6accd0dc --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95a05f55fca14efed58d57a100adb007668040ad953763d13907116759e40ee5 +size 135375 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..e5a9b983bb43a723e6e69aae5a08e63ed3095636 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08fc0df1830dfe974d5bbc8f57e059a3cc4a84098f753485a187ffa1ff35a88b +size 135119 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..d209d50a2d598096f8c81f6c62a960f0c2f658a8 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8bb7bf6adca273d06d8588f1100ba8f84d3e5a2248593491a9ee1a301458495 +size 134524 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..0954e1490afcf90d22e2055057e8a84dcba8e03b --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:626f6de40a928e34b3ad2f0cd08ae5b77ba21b1197407e585a0f5c324a245930 +size 134955 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..ee708ca6476ff21de43ddab4abc77785f0310777 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_action_diversity.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..129ff30f5cf3b46db5cca159ce52d84f8e9800e5 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_advantage_clip.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..36068856843e3857eb4975dfd17fd06a5dfe1a6c Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_attention_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..19a40d01dbf279ed94893f1e626a49fd460c9ca7 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_baseline_rl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..88de1b2c885f882babf55bd87397f8e78cf45c89 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_bc_wins.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..bd4b2b7a1123b9287bf7238329980bd2775c6f7f Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_entropy_bonus.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..7efa39242bfc08a395e1831c00568f382d0ca3fe Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ewc.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..acba80473c0b6d4881c1391a75f2f550c2d34c9e Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_ffn_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..c51249a3d2c7c916850fbe056e6b365c38a73387 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_frozen_backbone.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..1b4cb141d1636ff4fc02f9138546950d8c6332a2 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_gradient_surgery.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..4d2adaa28336334cba06b8f4365c14d28cea529a Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_head_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..fc4769ee998d9edac22f40227cc61457ab3cbbfb Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_kl_penalty.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..d22be76bdd684170decbd1f3dafe3c13502f1681 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top1.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..181da76c1e9ecb4c9e0c03f9541e4d1a88525541 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top2.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..110b4a7725dd336559050d6f4cc87fa5bb637181 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_layer_ablation_top3.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..9caabed2ea78dd42e31f3df05004355f487753c7 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_llrd.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..5f999818a0cb9be3c9474e49f0f318ff815fe53d Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_lora.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..c016e18c2e344eb2bf7776cd751ca529510d7639 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_low_t.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..71057088f8d96954a4fc812f97e02d567a797541 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_mixed_replay.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..ea5e0361f672dac5663ea03df1e48edddb6b0c2f Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_normalized_adv.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..1809c8b70f0fd8e26d4a1250cf317e0d2d81a2cf Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_filtering.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..bdd0ee3503ebd9e6c59395c419f5e84d0c339837 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_reward_model.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..b4fbcbeb8ca00342f2a3ab6f246ece8551da36ad Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_running_stats.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..4c57d92d70c79cf53e83e9629440de601bd75f49 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_t_curriculum.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..ff0393dae46f4cc7e29597d51a45eaf533ea4cbe Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_freq_trust_region_kl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..6d8696b9e44f1db2b977a26cbb5bfe38ea5c81e3 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_action_diversity.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..39d10ea2e59878ee247530135185ca6de6220c4a Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_advantage_clip.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..b57a609a5459e7a886505f2b654dde658ebc13ae Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_attention_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..ca35ed0b03d7884c41359cd4a8d9906c99bde7f1 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_baseline_rl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..c807f2ec06c7711025cbfb2288d4f48a3b61b70f Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_bc_wins.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..ef44ae7d1ef7b1417bd8311e97b7ccd03a649609 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_entropy_bonus.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c0b581e8e7db09c02f307e8d8cabd4c7d776dc1 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ewc.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..287e7381fe8b29a2e3b635a7702ff4cf7f686b30 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_ffn_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..7b277ce28e77355f0de5d4b5bd2ed9161eb526bf Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_frozen_backbone.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..e450b59c96a6e4365aee4e92f265b7fd23e5d87e Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_gradient_surgery.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..24fd8a10dd306101f58efce8dee764fd3d1c56f5 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_head_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..6b485a1452785663af7b00f62b525288dd6cf15c Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_kl_penalty.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..5a16c5c3f2eb18257562530e115b58986cbcffd4 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top1.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..2df327df133d32935fd907a7d02b55e337524834 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top2.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..a927f431403e939da2e30a8c8c800db6d76ab118 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_layer_ablation_top3.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..55d9b11216a8679683633aa3dd313931f2bd7aa7 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_llrd.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..1174438b7bef57230fd9b7cf57f8275c2f6b6d61 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_lora.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..349e754da6aedbde1fc2e53171c87ff88482005c Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_low_t.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..74745cf99ad778ec4cb70f4bc570cc563001e1dd Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_mixed_replay.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..530add18e2e869cb2ccf971fc551eb5882076dea Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_normalized_adv.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..21761b876fe67bdaa19b75bf1564270d0050b0db Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_filtering.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..9c66a14c73934017ab40ffd470a6fd227fe50e95 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_reward_model.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..ad874a6efce1e124fcfd0b9d0114d6c2c0f263be Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_running_stats.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..be37abd7469a25d2c2f5fe479453f65c34f77801 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_t_curriculum.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..22e33a613e0cc9395e5ecee7d5e790a803cd543f Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/action_metrics_trust_region_kl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..34184d1b5fb7bbc48f3d74a4dc2ef03ee2870aa4 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51c1fddc8f4a7784df2e35debaa257c9d55ac0a90c4a6c5d71f64ff7651b863f +size 132243 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..6bd7baf5d304e9c1a5b0807bebcee5cb5cb1e6be Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_action_diversity.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..8bf58cfd3883ba9b7b97de40ea20199a2cac2e33 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_advantage_clip.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..40c4120101996395b502e3e1cf096c083e1729d0 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_attention_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..54781c1097b8436e9eff006bac97f9ca4644a2f8 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_baseline_rl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..bf6f528d3d3a02453f4095fcbebbde3086e63cb5 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_bc_wins.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..e21cc53573af9aceb3c7a96beeeb88f005108f59 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_entropy_bonus.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..ba72945ed3281fbb9d7829fd506de35c62be6209 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ewc.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..98d546858d2bfe866846c1844ed0353a41ddd3a9 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_ffn_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..edde7351c85dc57c9b8b52a263a0223000bf97f6 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_frozen_backbone.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..e9c1ef9a286401d66076b0194bd13e93fa2c0941 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_gradient_surgery.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..ab70b7e7298233ba645e7e52486539cad8b61731 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_head_only.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..7dbbbf441d9c11ebbbf93a7e11431f20051e53d0 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_kl_penalty.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..ce0f8bd73a3f4f5072dbcfb74c01db3706be91c9 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top1.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..c4c599671379907d4b2e45ff605ca84a7310f826 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top2.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..a2c3cb311f2903aff489c407b0f50ca8f286b845 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_layer_ablation_top3.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..57b26ca326283b92e16cf92badfbc963eb014e46 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_llrd.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..3d2d95a97d6ddb73c44d4349a7bc75e38aeea0e3 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_lora.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..339b001947a9f72934b56a76802764f991c00600 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_low_t.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..a24c85f65bb593824ff3055ca3af1920559c72e4 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_mixed_replay.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..f70f3dfd0fca60b6ee3fc3a149952d94a04cc439 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_normalized_adv.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..9b9aa5ef1b944a8fce413073533b661943015025 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_filtering.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..45938e25c038d8e9b602348a48e250555de0720c Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_reward_model.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..4af89311f43c6ed5cdf4407e014609fcb5f5dece Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_running_stats.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..33a5a938dc96392309c40ee7b99bab4534675373 Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_t_curriculum.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..e5ef83d76ef24422728311a0fca94e447d24bf9d Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/transition_matrix_trust_region_kl.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png new file mode 100644 index 0000000000000000000000000000000000000000..c310bc6a1f385c626d8ce2be6fc205a0b5d39720 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44593d43a80c26b41e478d39e5c8236623d7dc299ac5c5e283ce40b2b812c5c8 +size 290679 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..1a1c5a6a594192dfb3a0aa60e3d7fa6053a6c16d --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:692f105d51077ce63ff6ea982247e27efdb98ce427da2228f19b8f0b577ffec7 +size 355986 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..b5229ea5c5aed3d86ae71cad29a24caa247fe302 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41ddfc2337c144d5dec6162cbf570ecfb817c54793fbadf746ef0291f644cc80 +size 343397 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..8ec172260b948d454e965867c0efeabedab915a6 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bf083bb7799ea94d910229a66c05088bdfb7ef77d7dd1e83c5cdb9afebc597e +size 262074 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..6f29c6913b605cfe0a1983382a437bcf9f9476a8 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68a6a41afcb4bf2c0ae0eb887359949bba8c68437fd46944c580864e8e28557f +size 331413 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..b1425c9617b77c55160a0efc3a31cd9a38f13d28 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c0061a7fcc6c7fb1cda0200c5c24395683010bc46bcb9dae8ac53a29d8a9549 +size 341009 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..905e694ab7c2369e0400d6ae1240bbf630535108 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5728b730211256bf431ff6a72fcf75379b1903d67b73e34644ef291a4de0c58 +size 330705 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..32bde7e5357f6366bea382a99b90c1419cd2490c --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cefbf2bce56057bf07b2fb5d6cf206f317b352fc0fc1db9d157c208799c11046 +size 321337 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..89d9d84b2885991a2556055894330fa4f3d821da --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:46eeab5469caa72bf4ba8ca960fc6c9ca9da98d1423a554bfe869e25521980f2 +size 305749 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..63527eed5b03bb9aa7a52ae7a97e917f537a57b4 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9760f15a8db02091cbaad275466a743788a3adbf61eaf0ec417924e9c456ec41 +size 257882 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..b070484d580a3fb8b253fb73be50ca35a47aa5ff --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44bf4ae5051aef0ab769ea7b5a072be9cb28b47303351514dd9015c8392a966c +size 345006 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..aebe0a3c7acddf4bafd93bd63698619644f4fa03 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84eefc42a64d0c8a513b2db52f0bd5f662d2ae4f77cae5a0d29cc7bca7013f86 +size 257202 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..c39046c1e5d4398b5ba125ecee85c56b17aee799 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a034184498328b065391d527931e23cab380ea96ad1fe011fd9be5a67c0c98c +size 333918 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..3e68662069c4c25c534b53e04fa0d99a4ccee073 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99269fe8832261c43e0c4271ebf670f4b48f0aff72b76f9c6496bc45b3f14f16 +size 286476 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..192ca2a6df38b875f6939b94ec42fad247799e42 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a18329b9f876bc06de9193263269bef3a2eeed8cea78c171cf2b66a505145f0c +size 291612 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..e6a516ce85db660480522aefa6420a28e42b3150 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:706c7c8edc372b72ea63584d2f7d96aa82d2d996fe294baccb037af9c1fc154c +size 306494 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..0678dbe5a6b3e5ee6cd734d43da0761fa8085211 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c6ba8854e904f9d5662453c2ba45d1ee3efc9f4b3e9b83b1c34ba78f6c06c80 +size 340141 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..84afff3a206a6803eef3497c2c7a599f8c2420df --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:369bd3153e2e3cfe5c65c20d2b7ff4c72dcaa625d653cb53d33b6ee0e6b4a7fc +size 236798 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..a3b5a967b24d0b0e0fb5adf1a1c8943b44b7fefb --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df3fe897122d26fe2595a737d3b3107142ee29a5a8483cb7c93e4b26f834bb6d +size 327422 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..c47a2172737ec58b169525605d0bbe87de150ddd --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9d2653bb727995f48273b9a349ad5a05503328ba4ca0ddbaf64a436e57b859d +size 344053 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..f45ddba21ee52a51e240d2f8e9948b15ff28684c --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:08f0d4bb3e5caa39badf401cafc58cd72589a0087bd6965a852b6e92e20bb0f5 +size 252437 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..74622513ef82da68cf053695ebb5fc6c2958bcc6 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b25a6398d0202ea5f1ed44231aef0682651e975a4e4fdc5eaf5c7e81d19eb502 +size 327555 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..94abe0730dd5cbdef3ebf19807c014ca22954067 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bdfe4260d665dd71af6abdeae55f7f7992e5ac4ec98be8ea2c3f425a3e27b11 +size 337185 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..4d87cb3cedf53c0e09b9b56f76423e98b3a866da --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:132fa87c7abd362e297b623817465c36ab85c48ae5e965fd58958c78c586c5f7 +size 329614 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..9df625eaf9d00624d14874833d0c4d0e17e11c19 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf9f3044a96384420ada27a04b487f6d7c3b58be37f3227f8e38495aac3fb320 +size 340798 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..8f533edfee5119bb4ac87fb8d4abe04b820f5fc1 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d993bc2e7a32b696c365cc2f8eda086809a863af0ee97b5b74f657658787cca3 +size 334205 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png new file mode 100644 index 0000000000000000000000000000000000000000..50d1f59f3260eeaae9f65cc2690e4a498469bd01 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:919183e2b4d733aaeedade61df932f0b7b3d7ee808d3af20b6b3f81320ac6fcb +size 104578 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png new file mode 100644 index 0000000000000000000000000000000000000000..e0eb1e13fec7b25e939ba8f8daf20fdbe0f3848d --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0122400add272c44b2aa9b3659edf931d63192ccb1cb9c338171b5fba388ac3 +size 347896 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..6e9a675069881584e8098abed67eaa51068dbd8b --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec17c47b7cb8a29e846b7d403c0894686f6a1d3c4ebd308c69df453b9bd21eab +size 150907 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png new file mode 100644 index 0000000000000000000000000000000000000000..8e3f489294cffc5704593f29d745ce8905806ccc --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfee54a47ebe64124c994e6a10db054b9e6d61a2c9f726c454470b3971c3dea6 +size 248912 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png new file mode 100644 index 0000000000000000000000000000000000000000..0896e519f1badfa414f4faa2556ce8e51b08cd84 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbdf5c36f4710f63813faedfde96e5ce001685d3dddcf1ab11b8fd0be1abce5d +size 146081 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..7423487b00d8b996f7da2a901a58250c7be58b2d Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/group_comparison.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..7e1f3907020d6f8f67a89127449ffcf5e6512ad1 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9117294129e04b011b00ebdab11715c4afff988968c24d29df87389e8ed124b3 +size 635770 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..63bb9cee32becbf2c376c38fdad85931d306ef13 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb3e78245c5051ff9596ca5a8ee020a82cf21c75df8fd42df4f90f0bd387d443 +size 648859 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..e1c982a658c8106fb0adb338e8dbd394f18c8aa0 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:965213853ede94e3ffd82c132fcf28ea485d55e956cec0560e378b2bcb8943fb +size 523522 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..82ddfc80320b77265fa688ad37c57eb1b36f7a35 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e250b22110869a3f2437ce0da897419ce06eeeb4e3fac43137822562dc8f7532 +size 633145 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..2f6107cefbae7c8d4e513090ea5d62f960752db8 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fe0cc2ba771668987f1c894f0b6cb97f5a80d6085bfcc1bf4445f4cda296e2a +size 657126 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..bba2a91b4082a3df508c692d262f73323db272dc --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b90855ca01ab188a82497a2300f645a27b8afce196e6e9cde8a231af31c8c98e +size 650095 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..f4a03ad1fe9e41c93a8c4808028f245e28a37fe3 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28972e9618898cf93f2faf9056c0e59007ddd024a3c9bc7b46f303f064884b97 +size 620015 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..ffe05afa37d06dbe41b1cbe23754c2378688ca99 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f71e241620c23f0d0c11d31a688600a9887d6345320e46f8011e1070302f2d90 +size 579344 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..16cf861138b847d7f61d8b21e54f723ac5ca9d0a --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90adc39ac2eab073cb04d3901a3f1e82d6d4ab0ea1a28b4cc2ff2af19287b0bd +size 470394 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..0eb919762fafc34fec20eedb7678942df7fbf7e1 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:547c6bca9c6059378318db5a3772a9a59b79a2235da093acfa3dc01bba5f10ac +size 634449 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..6e70c4ffc50dd44c25a3bb3ef60afcf46440d6c4 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a135f63a1e1e6c6fa527802fdb7dc072dda40cb1c681464dfc6bc02259d7820 +size 471173 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..7b1b5d98f8dbd1dda63ffe9ab2c894e4f7e05468 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e1788e4720fbeac598ba06d70e5bc6239cd5e206350dcac1805d4faf8180675 +size 648697 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..6090145bf08ee37e0afaf15bec06ea382db210aa --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85a743053213740f4288e2cf715165d5e1c36d30c35ff540bc3a47c1ac1f7c6a +size 571169 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..76b011b88b9d376792c7adc3f195b3833c8d2cca --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9833dc299acbff783db8ba4ccdd6328c9e5cad6c236b9d53e68933244fff619 +size 595591 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..5d3697e719cc2d3c0aaa79b52a97aa6bbf7656a3 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab72de8a763a49b97e3ef0c3f5ac87e4c37559275f6fd716bc125fc37786913e +size 634144 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..3e36b27446b60f78ec4ea77212c85f875888b481 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5adfbb0c7e74d4609063ef1f842345eae97d59f435b4cb37dc179fdd9a4701d +size 632657 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..925c840f0e8f6389aa8d5e6c50465166375007b2 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b2c0d771f4f849a647a4083f774de47480a388d8c830bfb34db521f2049c1e +size 562059 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..57a1cee5da24e9344af5fc142f4c1409ce42290a --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efb3cd405fd950a3f19991b173178082a5cb75c0af0c32c597af3b92df8b821f +size 655970 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..04a0bd9dc44547098146d66d7c2007518b9a0d94 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:869638d9115e97dfd808143d7dd406ee4ed43adb002f2e0a56fd950f9c5fde46 +size 619249 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..289e87f22f6d071a9723592b7c9ec9e5ee289639 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee3e4be59cfe42e837d62fa50d507f919afeefd3cba62a4f2650706d1927dbd5 +size 604201 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..4c2321c56378015ec65b8f58bdb89f9370a56cbd --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de6e1001eae7d7664c90d9f429f0fe292b78186e8415f677f545f4b3ca544db1 +size 647356 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..fa13e6cbcdcf0707cca32f8a0ff79c40d880aa02 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9daa06890a2dab01bfe729b493b5bbee6ef393ca02138cfe6acc883039c619ce +size 630140 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..2c8b46e1fcd2762a88e83d97b2bf9ee1014a9ebb --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fcb8698ed91abb5c11c2ebac01fa8f02ffd371e400ee7b020bcbf551e406ba9f +size 663296 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..66a29584c966ef833178c0a927f99685a871eee7 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8239835c15562278c31f7adce44a595cdbd73493b9892ff191d36d41228cbe4b +size 664034 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..47fc6a6d10f2b24399e28c63c988eb62f25da9f0 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8723ec21354878e4104db606715319b8f2f086f75920ee56a8f37fb7b22b79be +size 608690 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png new file mode 100644 index 0000000000000000000000000000000000000000..bb3f1feb52b2c2c4840aa8aaf1d6c6c0fb28cd79 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6547c1be5f274793b131c9df4ffab650237e78591f480dc53b79f87bb50e6370 +size 209309 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..9b2d2653aa18516117c2967b49387c12d790a05f --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a71e3d5e3a6a7d24fe06e763676821a23716d070bf5f6ad791d98f2e8eafb26 +size 134007 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png new file mode 100644 index 0000000000000000000000000000000000000000..d0f5ae6df6ceaac228fd35a3ef30c10bf47043cf --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:007ace3dad75964a45c593569f5f8f4ee58f49093358f701c5075f2828cb90ec +size 187940 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png new file mode 100644 index 0000000000000000000000000000000000000000..5727e6c1575d9f327faca7af66a545b70fd41c48 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f055d32819cd7592d2c168bc93c121fb0b4f262bfe303d2902570cb2db905f0a +size 192830 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png new file mode 100644 index 0000000000000000000000000000000000000000..9847688ad3e88ae123cdc4df9f2bd6742c468ade --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53f167b6fb25034bc2e6ee9848e7a0bef6a22c24d9efa19d3f735c210b9fa24e +size 195455 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png new file mode 100644 index 0000000000000000000000000000000000000000..fa467a45ded4481c919bc3b3817590ea69e623b2 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:532e3f8761399ea366804661f148b65e1f415c272d909c5365bc2694e5670275 +size 186133 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png new file mode 100644 index 0000000000000000000000000000000000000000..0ed4fad94c30b3f74b66408975ae423a1bf1e1e6 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0616e3abc75b57a5bd75d6c403849f4a8f16fc24e6cea08d85a499496a883bf5 +size 153627 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png new file mode 100644 index 0000000000000000000000000000000000000000..683df4c4b281171b0fc77e79528663cd5944eba0 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:803f74d13538b199b36a734774543922fef2ddf88699474e98a093c397e60118 +size 183155 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png new file mode 100644 index 0000000000000000000000000000000000000000..d1a2f3cd3b1f4ad5467f7bbb25618e2f9578006e --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05ca0467d3717181c746f3e2ed1aa3f71bdf347d7162d44ddaacce4527d8bb92 +size 175387 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png new file mode 100644 index 0000000000000000000000000000000000000000..9f430954dae038ff0f5ad79d4563fcecd6161869 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13e519ebf3e30e8e9b55db8d30280af3a5a4e783d11c438825941c133b189f28 +size 215131 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png new file mode 100644 index 0000000000000000000000000000000000000000..edd37da56856ab400728b58655b653f916b0b9c3 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:785fd8bf211f93e1f07289b60d29d5f4fd253063eda5338bfb673da3823807eb +size 187307 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png new file mode 100644 index 0000000000000000000000000000000000000000..ff21f64911b019a38c5c0f8dcfa87363a3787ef7 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fc39fddd3b6cd3e39768f7241f1ebc22252dd17d96f28e5aeb87f37c1dd68f6 +size 187654 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png new file mode 100644 index 0000000000000000000000000000000000000000..6ebc5fe6706a3818a0b2b06451d58e8ed9f14bda --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f988502a78935585b3b04c9604b65ee0727ccf7f9bf3d4819e8e2612972d16dd +size 187162 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png new file mode 100644 index 0000000000000000000000000000000000000000..23c979b88cbd4fe6ae80c5fd3e7f41554fdabb52 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8dc72f23b90d98ce82a4411dba2c4affa3c0f06046d127c3e3098dbccb60ca69 +size 184990 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png new file mode 100644 index 0000000000000000000000000000000000000000..7f258e7d84f162effbc1738fc062f68036b2d534 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ead65ae8815f6025348d58f8f70f94ebca866cad91fd11a3be0f24646037d83e +size 218911 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png new file mode 100644 index 0000000000000000000000000000000000000000..b1b3d3c1643d969bd97bbb51a16e2966c17470fc --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17ce988ee085b339c0960df2818428e15ad81abbc2216b118fcc27700ad6c738 +size 227174 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png new file mode 100644 index 0000000000000000000000000000000000000000..f0d33d31afcfeb8b0a31dc4ab180fe1e55a2d63a --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24aafd1030168a62bda314878cb45e5b71fa10b7b1b52af789c34f0d7990b9bb +size 219898 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png new file mode 100644 index 0000000000000000000000000000000000000000..c756680f5978d6888cbff9393c13ee2257a9d3ab --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e632294988ca1b860a032a109b88dabe01fd355a16731166f911304e6c6f077 +size 186632 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..268634acbbd05cec6a7520fd5891ed2cc682247c Binary files /dev/null and b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_lora.png differ diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png new file mode 100644 index 0000000000000000000000000000000000000000..4ab0c363518ecff97a596985a4fd9458b0fe3e10 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c000b09dd84f6cdd71d828d251ddb96317da7efda6f180fb5249275eaffd5cde +size 178303 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png new file mode 100644 index 0000000000000000000000000000000000000000..d95cf109345620af88462185399e0877049f8771 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85d8d0c6bcee0597c5e1c559f1be45c8170759e48ef34833c5b3bcad4f8458fe +size 176911 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png new file mode 100644 index 0000000000000000000000000000000000000000..ea168c31e9cda0a51270ea1237f7ff014dbc9693 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa9cf827f4bb35145bb6bcbfc0f0eca38e157db3895df881a4b0f6a1876f21fe +size 260526 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png new file mode 100644 index 0000000000000000000000000000000000000000..8c7c68a7d33bdbd9100a6e3ef0cb0a7e70750386 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea40744f368d168d2127c62d695714e2621866b759787b5a91957030387bf477 +size 183886 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png new file mode 100644 index 0000000000000000000000000000000000000000..5072c9556f449461a823792b1cdf0e8c2fe1786c --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09b8a6bcadac61b4a27ae93ecadfb9c35433bbffcf4dc549f98c331cf5613850 +size 165716 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png new file mode 100644 index 0000000000000000000000000000000000000000..2144b4779e22535956ffe16c3492d6f3f288eb80 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92c36eb53fdde01f3f59e39eeb265b97d4d86a4d7b7a64f5c6157a2891c7e4d6 +size 173606 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png new file mode 100644 index 0000000000000000000000000000000000000000..60b8af4bd413b0c7bbfa937904b3204c0803cd01 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:515261a5f13dcf5753e8fa8e37bd50f19adae673c949ea9f55eb17cd8362b715 +size 173067 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png new file mode 100644 index 0000000000000000000000000000000000000000..05e623cce9702f9e7a1b2d7b20d52de1039572e1 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a023d1d969bf7a88930905249dd82839069fbfe943161984401f857f79bfb746 +size 182990 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png new file mode 100644 index 0000000000000000000000000000000000000000..32674032b759041a6df618043b79c21613f8d829 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9645f05d86a8a02d063f550eb79c3bbd153b31d615e579b120f824fcac8fb20 +size 136164 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png new file mode 100644 index 0000000000000000000000000000000000000000..8c76f8bf3993df88823eea3902c94def755b8166 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:374b51477ed71e52770a6c9b991033808cae04960e3a92b5168e085eed31499a +size 347499 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png new file mode 100644 index 0000000000000000000000000000000000000000..abee4a1459e12ef75bff95b0d1032cc3d6c5e744 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:403950780d101e572ae75a05a5a71abc25070078ab3d36b565486b8649e42642 +size 527666 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..d86c9ef11d5974d3a7548d4bfd6dfc144bb60ecf --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.csv @@ -0,0 +1,23 @@ +Achievement,Pretrained,action_diversity,delta_action_diversity,advantage_clip,delta_advantage_clip,attention_only,delta_attention_only,baseline_rl,delta_baseline_rl,bc_wins,delta_bc_wins,entropy_bonus,delta_entropy_bonus,ewc,delta_ewc,ffn_only,delta_ffn_only,frozen_backbone,delta_frozen_backbone,gradient_surgery,delta_gradient_surgery,head_only,delta_head_only,kl_penalty,delta_kl_penalty,layer_ablation_top1,delta_layer_ablation_top1,layer_ablation_top2,delta_layer_ablation_top2,layer_ablation_top3,delta_layer_ablation_top3,llrd,delta_llrd,lora,delta_lora,low_t,delta_low_t,mixed_replay,delta_mixed_replay,normalized_adv,delta_normalized_adv,reward_filtering,delta_reward_filtering,reward_model,delta_reward_model,running_stats,delta_running_stats,t_curriculum,delta_t_curriculum,trust_region_kl,delta_trust_region_kl +Achievements/collect_coal,0.4857,0.5079,0.0222,0.4338,-0.0519,0.0,-0.4857,0.4429,-0.0429,0.4643,-0.0214,0.4255,-0.0602,0.5407,0.055,0.0303,-0.4554,0.0,-0.4857,0.4118,-0.0739,0.0,-0.4857,0.4815,-0.0042,0.1702,-0.3155,0.2806,-0.2051,0.1812,-0.3046,0.4468,-0.0389,0.0,-0.4857,0.4085,-0.0773,0.5391,0.0533,0.084,-0.4017,0.4317,-0.0541,0.4615,-0.0242,0.5074,0.0216,0.4397,-0.046,0.4074,-0.0783 +Achievements/collect_diamond,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0074,0.0074,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +Achievements/collect_drink,0.3857,0.4286,0.0429,0.4044,0.0187,0.0,-0.3857,0.3857,0.0,0.3714,-0.0143,0.3333,-0.0524,0.437,0.0513,0.2848,-0.1009,0.0,-0.3857,0.4191,0.0334,0.0,-0.3857,0.3259,-0.0598,0.3333,-0.0524,0.4245,0.0387,0.4203,0.0346,0.4043,0.0185,0.0,-0.3857,0.3944,0.0087,0.4141,0.0283,0.1849,-0.2008,0.3453,-0.0404,0.3427,-0.0431,0.3676,-0.0181,0.3972,0.0114,0.4,0.0143 +Achievements/collect_iron,0.3357,0.3095,-0.0262,0.3309,-0.0048,0.0,-0.3357,0.3786,0.0429,0.3571,0.0214,0.2979,-0.0378,0.3556,0.0198,0.0,-0.3357,0.0,-0.3357,0.3088,-0.0269,0.0,-0.3357,0.3704,0.0347,0.0213,-0.3144,0.1151,-0.2206,0.0652,-0.2705,0.2979,-0.0378,0.0,-0.3357,0.2676,-0.0681,0.4062,0.0705,0.0084,-0.3273,0.4029,0.0672,0.3427,0.0069,0.4118,0.0761,0.3546,0.0189,0.2815,-0.0542 +Achievements/collect_sapling,0.9071,0.9127,0.0056,0.9265,0.0193,0.0,-0.9071,0.8643,-0.0429,0.8571,-0.05,0.9362,0.029,0.9037,-0.0034,0.5576,-0.3496,0.0,-0.9071,0.9191,0.012,0.0,-0.9071,0.9333,0.0262,0.7801,-0.127,0.8993,-0.0079,0.8623,-0.0448,0.8936,-0.0135,0.0,-0.9071,0.9296,0.0224,0.9141,0.0069,1.0,0.0929,0.9065,-0.0007,0.9021,-0.005,0.8897,-0.0174,0.9149,0.0078,0.8889,-0.0183 +Achievements/collect_stone,0.9,0.8651,-0.0349,0.8824,-0.0176,0.0,-0.9,0.8929,-0.0071,0.8929,-0.0071,0.8936,-0.0064,0.9037,0.0037,0.1576,-0.7424,0.0,-0.9,0.8603,-0.0397,0.0,-0.9,0.8815,-0.0185,0.5106,-0.3894,0.5899,-0.3101,0.6522,-0.2478,0.8652,-0.0348,0.0,-0.9,0.8873,-0.0127,0.9141,0.0141,0.3025,-0.5975,0.9065,0.0065,0.8951,-0.0049,0.9044,0.0044,0.8865,-0.0135,0.9333,0.0333 +Achievements/collect_wood,1.0,0.9921,-0.0079,0.9926,-0.0074,0.0,-1.0,0.9929,-0.0071,0.9929,-0.0071,0.9858,-0.0142,1.0,0.0,0.9212,-0.0788,0.0,-1.0,1.0,0.0,0.0,-1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.9928,-0.0072,0.9929,-0.0071,0.0,-1.0,0.9859,-0.0141,1.0,0.0,0.958,-0.042,0.9928,-0.0072,1.0,0.0,0.9926,-0.0074,0.9929,-0.0071,1.0,0.0 +Achievements/defeat_skeleton,0.1571,0.254,0.0968,0.2279,0.0708,0.0,-0.1571,0.2143,0.0571,0.2071,0.05,0.1844,0.0273,0.2,0.0429,0.0242,-0.1329,0.0,-0.1571,0.2132,0.0561,0.0,-0.1571,0.2296,0.0725,0.078,-0.0791,0.0791,-0.078,0.1087,-0.0484,0.0993,-0.0579,0.0,-0.1571,0.1972,0.04,0.2266,0.0694,0.0924,-0.0647,0.2446,0.0875,0.2028,0.0457,0.1838,0.0267,0.2553,0.0982,0.2296,0.0725 +Achievements/defeat_zombie,0.6214,0.6508,0.0294,0.5956,-0.0258,0.0,-0.6214,0.7214,0.1,0.5714,-0.05,0.695,0.0736,0.6815,0.0601,0.2606,-0.3608,0.0,-0.6214,0.6912,0.0697,0.0,-0.6214,0.6519,0.0304,0.5887,-0.0328,0.5971,-0.0243,0.6884,0.067,0.6879,0.0665,0.0,-0.6214,0.6761,0.0546,0.6484,0.027,0.6134,-0.008,0.6906,0.0692,0.6643,0.0429,0.6397,0.0183,0.6454,0.024,0.6815,0.0601 +Achievements/eat_cow,0.3357,0.4603,0.1246,0.3603,0.0246,0.0,-0.3357,0.3429,0.0071,0.3786,0.0429,0.383,0.0473,0.3333,-0.0024,0.1152,-0.2206,0.0,-0.3357,0.375,0.0393,0.0,-0.3357,0.3926,0.0569,0.2057,-0.13,0.2446,-0.0911,0.2971,-0.0386,0.2908,-0.0449,0.0,-0.3357,0.3732,0.0375,0.4219,0.0862,0.2437,-0.092,0.3022,-0.0336,0.3427,0.0069,0.375,0.0393,0.305,-0.0307,0.363,0.0272 +Achievements/eat_plant,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +Achievements/make_iron_pickaxe,0.0214,0.0238,0.0024,0.0221,0.0006,0.0,-0.0214,0.0214,0.0,0.0,-0.0214,0.0284,0.0069,0.037,0.0156,0.0,-0.0214,0.0,-0.0214,0.0074,-0.0141,0.0,-0.0214,0.0222,0.0008,0.0,-0.0214,0.0144,-0.007,0.0072,-0.0142,0.0213,-0.0002,0.0,-0.0214,0.0141,-0.0073,0.0078,-0.0136,0.0,-0.0214,0.0504,0.0289,0.014,-0.0074,0.0368,0.0153,0.0213,-0.0002,0.0,-0.0214 +Achievements/make_iron_sword,0.1214,0.1349,0.0135,0.125,0.0036,0.0,-0.1214,0.1571,0.0357,0.1571,0.0357,0.1206,-0.0009,0.1259,0.0045,0.0,-0.1214,0.0,-0.1214,0.0956,-0.0258,0.0,-0.1214,0.1778,0.0563,0.0,-0.1214,0.0288,-0.0927,0.0,-0.1214,0.0922,-0.0292,0.0,-0.1214,0.0845,-0.0369,0.1875,0.0661,0.0,-0.1214,0.1367,0.0153,0.1538,0.0324,0.1471,0.0256,0.1277,0.0062,0.0815,-0.0399 +Achievements/make_stone_pickaxe,0.65,0.6746,0.0246,0.6029,-0.0471,0.0,-0.65,0.6714,0.0214,0.6929,0.0429,0.6525,0.0025,0.7407,0.0907,0.0242,-0.6258,0.0,-0.65,0.6838,0.0338,0.0,-0.65,0.6815,0.0315,0.1418,-0.5082,0.3525,-0.2975,0.2971,-0.3529,0.6525,0.0025,0.0,-0.65,0.6408,-0.0092,0.6875,0.0375,0.0252,-0.6248,0.7122,0.0622,0.6993,0.0493,0.6838,0.0338,0.6454,-0.0046,0.7111,0.0611 +Achievements/make_stone_sword,0.7286,0.7302,0.0016,0.7353,0.0067,0.0,-0.7286,0.7214,-0.0071,0.75,0.0214,0.766,0.0374,0.7259,-0.0026,0.0545,-0.674,0.0,-0.7286,0.6765,-0.0521,0.0,-0.7286,0.7185,-0.0101,0.2482,-0.4803,0.3669,-0.3617,0.3696,-0.359,0.7447,0.0161,0.0,-0.7286,0.7465,0.0179,0.7656,0.0371,0.0672,-0.6613,0.7986,0.07,0.6923,-0.0363,0.7279,-0.0006,0.7021,-0.0264,0.763,0.0344 +Achievements/make_wood_pickaxe,0.9357,0.9444,0.0087,0.9265,-0.0092,0.0,-0.9357,0.9286,-0.0071,0.9571,0.0214,0.9362,0.0005,0.9556,0.0198,0.3636,-0.5721,0.0,-0.9357,0.9412,0.0055,0.0,-0.9357,0.9481,0.0124,0.695,-0.2407,0.7554,-0.1803,0.7971,-0.1386,0.922,-0.0137,0.0,-0.9357,0.9155,-0.0202,0.9297,-0.006,0.3445,-0.5912,0.9353,-0.0005,0.9301,-0.0056,0.9412,0.0055,0.922,-0.0137,0.9481,0.0124 +Achievements/make_wood_sword,0.6714,0.6587,-0.0127,0.6765,0.005,0.0,-0.6714,0.6786,0.0071,0.7357,0.0643,0.6879,0.0165,0.6667,-0.0048,0.3091,-0.3623,0.0,-0.6714,0.7132,0.0418,0.0,-0.6714,0.6889,0.0175,0.6596,-0.0119,0.7554,0.084,0.7971,0.1257,0.6738,0.0023,0.0,-0.6714,0.7254,0.0539,0.6875,0.0161,0.1176,-0.5538,0.6978,0.0264,0.6154,-0.056,0.75,0.0786,0.6738,0.0023,0.7481,0.0767 +Achievements/place_furnace,0.8,0.7778,-0.0222,0.8309,0.0309,0.0,-0.8,0.8214,0.0214,0.8214,0.0214,0.8156,0.0156,0.8148,0.0148,0.0303,-0.7697,0.0,-0.8,0.7574,-0.0426,0.0,-0.8,0.8,-0.0,0.1986,-0.6014,0.4388,-0.3612,0.4565,-0.3435,0.7872,-0.0128,0.0,-0.8,0.7958,-0.0042,0.7969,-0.0031,0.2017,-0.5983,0.8417,0.0417,0.7902,-0.0098,0.8088,0.0088,0.8298,0.0298,0.8,-0.0 +Achievements/place_plant,0.5929,0.6508,0.0579,0.7059,0.113,0.0,-0.5929,0.5786,-0.0143,0.6286,0.0357,0.695,0.1022,0.6519,0.059,0.0242,-0.5686,0.0,-0.5929,0.6324,0.0395,0.0,-0.5929,0.6593,0.0664,0.0426,-0.5503,0.0791,-0.5137,0.1014,-0.4914,0.6879,0.0951,0.0,-0.5929,0.6901,0.0973,0.6719,0.079,0.9412,0.3483,0.6906,0.0978,0.7063,0.1134,0.625,0.0321,0.6667,0.0738,0.6519,0.059 +Achievements/place_stone,0.6429,0.6429,0.0,0.7279,0.0851,0.0,-0.6429,0.7143,0.0714,0.7071,0.0643,0.695,0.0522,0.7185,0.0757,0.0485,-0.5944,0.0,-0.6429,0.6397,-0.0032,0.0,-0.6429,0.6444,0.0016,0.227,-0.4159,0.2806,-0.3623,0.2681,-0.3747,0.6879,0.0451,0.0,-0.6429,0.7254,0.0825,0.7187,0.0759,0.1261,-0.5168,0.6259,-0.017,0.6643,0.0215,0.6103,-0.0326,0.695,0.0522,0.7333,0.0905 +Achievements/place_table,0.9786,0.9683,-0.0103,0.9853,0.0067,0.0,-0.9786,0.9571,-0.0214,0.9643,-0.0143,0.9433,-0.0353,0.9778,-0.0008,0.4788,-0.4998,0.0,-0.9786,0.9853,0.0067,0.0,-0.9786,0.9704,-0.0082,0.8014,-0.1772,0.8561,-0.1225,0.8623,-0.1163,0.9504,-0.0282,0.0,-0.9786,0.9507,-0.0279,0.9922,0.0136,0.4538,-0.5248,0.9712,-0.0073,0.958,-0.0205,0.9779,-0.0006,0.9645,-0.014,1.0,0.0214 +Achievements/wake_up,0.0929,0.119,0.0262,0.0735,-0.0193,0.0,-0.0929,0.0786,-0.0143,0.1,0.0071,0.0709,-0.0219,0.0963,0.0034,0.0061,-0.0868,0.0,-0.0929,0.1324,0.0395,0.0,-0.0929,0.1481,0.0553,0.2199,0.127,0.1583,0.0654,0.2391,0.1463,0.1064,0.0135,0.0,-0.0929,0.1197,0.0269,0.0547,-0.0382,0.0,-0.0929,0.0863,-0.0065,0.0699,-0.0229,0.0809,-0.012,0.0496,-0.0432,0.1185,0.0257 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex new file mode 100644 index 0000000000000000000000000000000000000000..4e2eee67b80df22d3ee934921d6f8c62ee4f57d7 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/achievement_summary.tex @@ -0,0 +1,33 @@ +\begin{table}[htbp] +\centering +\caption{Per-achievement final unlock rates and delta vs pretrained.} +\label{tab:achievements} +\begin{tabular}{lrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr} +\toprule +\textbf{Achievement} & \textbf{Pretrained} & \textbf{action\_diversity} & \textbf{delta\_action\_diversity} & \textbf{advantage\_clip} & \textbf{delta\_advantage\_clip} & \textbf{attention\_only} & \textbf{delta\_attention\_only} & \textbf{baseline\_rl} & \textbf{delta\_baseline\_rl} & \textbf{bc\_wins} & \textbf{delta\_bc\_wins} & \textbf{entropy\_bonus} & \textbf{delta\_entropy\_bonus} & \textbf{ewc} & \textbf{delta\_ewc} & \textbf{ffn\_only} & \textbf{delta\_ffn\_only} & \textbf{frozen\_backbone} & \textbf{delta\_frozen\_backbone} & \textbf{gradient\_surgery} & \textbf{delta\_gradient\_surgery} & \textbf{head\_only} & \textbf{delta\_head\_only} & \textbf{kl\_penalty} & \textbf{delta\_kl\_penalty} & \textbf{layer\_ablation\_top1} & \textbf{delta\_layer\_ablation\_top1} & \textbf{layer\_ablation\_top2} & \textbf{delta\_layer\_ablation\_top2} & \textbf{layer\_ablation\_top3} & \textbf{delta\_layer\_ablation\_top3} & \textbf{llrd} & \textbf{delta\_llrd} & \textbf{lora} & \textbf{delta\_lora} & \textbf{low\_t} & \textbf{delta\_low\_t} & \textbf{mixed\_replay} & \textbf{delta\_mixed\_replay} & \textbf{normalized\_adv} & \textbf{delta\_normalized\_adv} & \textbf{reward\_filtering} & \textbf{delta\_reward\_filtering} & \textbf{reward\_model} & \textbf{delta\_reward\_model} & \textbf{running\_stats} & \textbf{delta\_running\_stats} & \textbf{t\_curriculum} & \textbf{delta\_t\_curriculum} & \textbf{trust\_region\_kl} & \textbf{delta\_trust\_region\_kl} \\ +\midrule +Achievements/collect\_coal & 0.4857 & 0.5079 & 0.0222 & 0.4338 & -0.0519 & 0.0000 & -0.4857 & 0.4429 & -0.0429 & 0.4643 & -0.0214 & 0.4255 & -0.0602 & 0.5407 & 0.0550 & 0.0303 & -0.4554 & 0.0000 & -0.4857 & 0.4118 & -0.0739 & 0.0000 & -0.4857 & 0.4815 & -0.0042 & 0.1702 & -0.3155 & 0.2806 & -0.2051 & 0.1812 & -0.3046 & 0.4468 & -0.0389 & 0.0000 & -0.4857 & 0.4085 & -0.0773 & 0.5391 & 0.0533 & 0.0840 & -0.4017 & 0.4317 & -0.0541 & 0.4615 & -0.0242 & 0.5074 & 0.0216 & 0.4397 & -0.0460 & 0.4074 & -0.0783 \\ +Achievements/collect\_diamond & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0074 & 0.0074 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +Achievements/collect\_drink & 0.3857 & 0.4286 & 0.0429 & 0.4044 & 0.0187 & 0.0000 & -0.3857 & 0.3857 & 0.0000 & 0.3714 & -0.0143 & 0.3333 & -0.0524 & 0.4370 & 0.0513 & 0.2848 & -0.1009 & 0.0000 & -0.3857 & 0.4191 & 0.0334 & 0.0000 & -0.3857 & 0.3259 & -0.0598 & 0.3333 & -0.0524 & 0.4245 & 0.0387 & 0.4203 & 0.0346 & 0.4043 & 0.0185 & 0.0000 & -0.3857 & 0.3944 & 0.0087 & 0.4141 & 0.0283 & 0.1849 & -0.2008 & 0.3453 & -0.0404 & 0.3427 & -0.0431 & 0.3676 & -0.0181 & 0.3972 & 0.0114 & 0.4000 & 0.0143 \\ +Achievements/collect\_iron & 0.3357 & 0.3095 & -0.0262 & 0.3309 & -0.0048 & 0.0000 & -0.3357 & 0.3786 & 0.0429 & 0.3571 & 0.0214 & 0.2979 & -0.0378 & 0.3556 & 0.0198 & 0.0000 & -0.3357 & 0.0000 & -0.3357 & 0.3088 & -0.0269 & 0.0000 & -0.3357 & 0.3704 & 0.0347 & 0.0213 & -0.3144 & 0.1151 & -0.2206 & 0.0652 & -0.2705 & 0.2979 & -0.0378 & 0.0000 & -0.3357 & 0.2676 & -0.0681 & 0.4062 & 0.0705 & 0.0084 & -0.3273 & 0.4029 & 0.0672 & 0.3427 & 0.0069 & 0.4118 & 0.0761 & 0.3546 & 0.0189 & 0.2815 & -0.0542 \\ +Achievements/collect\_sapling & 0.9071 & 0.9127 & 0.0056 & 0.9265 & 0.0193 & 0.0000 & -0.9071 & 0.8643 & -0.0429 & 0.8571 & -0.0500 & 0.9362 & 0.0290 & 0.9037 & -0.0034 & 0.5576 & -0.3496 & 0.0000 & -0.9071 & 0.9191 & 0.0120 & 0.0000 & -0.9071 & 0.9333 & 0.0262 & 0.7801 & -0.1270 & 0.8993 & -0.0079 & 0.8623 & -0.0448 & 0.8936 & -0.0135 & 0.0000 & -0.9071 & 0.9296 & 0.0224 & 0.9141 & 0.0069 & 1.0000 & 0.0929 & 0.9065 & -0.0007 & 0.9021 & -0.0050 & 0.8897 & -0.0174 & 0.9149 & 0.0078 & 0.8889 & -0.0183 \\ +Achievements/collect\_stone & 0.9000 & 0.8651 & -0.0349 & 0.8824 & -0.0176 & 0.0000 & -0.9000 & 0.8929 & -0.0071 & 0.8929 & -0.0071 & 0.8936 & -0.0064 & 0.9037 & 0.0037 & 0.1576 & -0.7424 & 0.0000 & -0.9000 & 0.8603 & -0.0397 & 0.0000 & -0.9000 & 0.8815 & -0.0185 & 0.5106 & -0.3894 & 0.5899 & -0.3101 & 0.6522 & -0.2478 & 0.8652 & -0.0348 & 0.0000 & -0.9000 & 0.8873 & -0.0127 & 0.9141 & 0.0141 & 0.3025 & -0.5975 & 0.9065 & 0.0065 & 0.8951 & -0.0049 & 0.9044 & 0.0044 & 0.8865 & -0.0135 & 0.9333 & 0.0333 \\ +Achievements/collect\_wood & 1.0000 & 0.9921 & -0.0079 & 0.9926 & -0.0074 & 0.0000 & -1.0000 & 0.9929 & -0.0071 & 0.9929 & -0.0071 & 0.9858 & -0.0142 & 1.0000 & 0.0000 & 0.9212 & -0.0788 & 0.0000 & -1.0000 & 1.0000 & 0.0000 & 0.0000 & -1.0000 & 1.0000 & 0.0000 & 1.0000 & 0.0000 & 1.0000 & 0.0000 & 0.9928 & -0.0072 & 0.9929 & -0.0071 & 0.0000 & -1.0000 & 0.9859 & -0.0141 & 1.0000 & 0.0000 & 0.9580 & -0.0420 & 0.9928 & -0.0072 & 1.0000 & 0.0000 & 0.9926 & -0.0074 & 0.9929 & -0.0071 & 1.0000 & 0.0000 \\ +Achievements/defeat\_skeleton & 0.1571 & 0.2540 & 0.0968 & 0.2279 & 0.0708 & 0.0000 & -0.1571 & 0.2143 & 0.0571 & 0.2071 & 0.0500 & 0.1844 & 0.0273 & 0.2000 & 0.0429 & 0.0242 & -0.1329 & 0.0000 & -0.1571 & 0.2132 & 0.0561 & 0.0000 & -0.1571 & 0.2296 & 0.0725 & 0.0780 & -0.0791 & 0.0791 & -0.0780 & 0.1087 & -0.0484 & 0.0993 & -0.0579 & 0.0000 & -0.1571 & 0.1972 & 0.0400 & 0.2266 & 0.0694 & 0.0924 & -0.0647 & 0.2446 & 0.0875 & 0.2028 & 0.0457 & 0.1838 & 0.0267 & 0.2553 & 0.0982 & 0.2296 & 0.0725 \\ +Achievements/defeat\_zombie & 0.6214 & 0.6508 & 0.0294 & 0.5956 & -0.0258 & 0.0000 & -0.6214 & 0.7214 & 0.1000 & 0.5714 & -0.0500 & 0.6950 & 0.0736 & 0.6815 & 0.0601 & 0.2606 & -0.3608 & 0.0000 & -0.6214 & 0.6912 & 0.0697 & 0.0000 & -0.6214 & 0.6519 & 0.0304 & 0.5887 & -0.0328 & 0.5971 & -0.0243 & 0.6884 & 0.0670 & 0.6879 & 0.0665 & 0.0000 & -0.6214 & 0.6761 & 0.0546 & 0.6484 & 0.0270 & 0.6134 & -0.0080 & 0.6906 & 0.0692 & 0.6643 & 0.0429 & 0.6397 & 0.0183 & 0.6454 & 0.0240 & 0.6815 & 0.0601 \\ +Achievements/eat\_cow & 0.3357 & 0.4603 & 0.1246 & 0.3603 & 0.0246 & 0.0000 & -0.3357 & 0.3429 & 0.0071 & 0.3786 & 0.0429 & 0.3830 & 0.0473 & 0.3333 & -0.0024 & 0.1152 & -0.2206 & 0.0000 & -0.3357 & 0.3750 & 0.0393 & 0.0000 & -0.3357 & 0.3926 & 0.0569 & 0.2057 & -0.1300 & 0.2446 & -0.0911 & 0.2971 & -0.0386 & 0.2908 & -0.0449 & 0.0000 & -0.3357 & 0.3732 & 0.0375 & 0.4219 & 0.0862 & 0.2437 & -0.0920 & 0.3022 & -0.0336 & 0.3427 & 0.0069 & 0.3750 & 0.0393 & 0.3050 & -0.0307 & 0.3630 & 0.0272 \\ +Achievements/eat\_plant & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +Achievements/make\_iron\_pickaxe & 0.0214 & 0.0238 & 0.0024 & 0.0221 & 0.0006 & 0.0000 & -0.0214 & 0.0214 & 0.0000 & 0.0000 & -0.0214 & 0.0284 & 0.0069 & 0.0370 & 0.0156 & 0.0000 & -0.0214 & 0.0000 & -0.0214 & 0.0074 & -0.0141 & 0.0000 & -0.0214 & 0.0222 & 0.0008 & 0.0000 & -0.0214 & 0.0144 & -0.0070 & 0.0072 & -0.0142 & 0.0213 & -0.0002 & 0.0000 & -0.0214 & 0.0141 & -0.0073 & 0.0078 & -0.0136 & 0.0000 & -0.0214 & 0.0504 & 0.0289 & 0.0140 & -0.0074 & 0.0368 & 0.0153 & 0.0213 & -0.0002 & 0.0000 & -0.0214 \\ +Achievements/make\_iron\_sword & 0.1214 & 0.1349 & 0.0135 & 0.1250 & 0.0036 & 0.0000 & -0.1214 & 0.1571 & 0.0357 & 0.1571 & 0.0357 & 0.1206 & -0.0009 & 0.1259 & 0.0045 & 0.0000 & -0.1214 & 0.0000 & -0.1214 & 0.0956 & -0.0258 & 0.0000 & -0.1214 & 0.1778 & 0.0563 & 0.0000 & -0.1214 & 0.0288 & -0.0927 & 0.0000 & -0.1214 & 0.0922 & -0.0292 & 0.0000 & -0.1214 & 0.0845 & -0.0369 & 0.1875 & 0.0661 & 0.0000 & -0.1214 & 0.1367 & 0.0153 & 0.1538 & 0.0324 & 0.1471 & 0.0256 & 0.1277 & 0.0062 & 0.0815 & -0.0399 \\ +Achievements/make\_stone\_pickaxe & 0.6500 & 0.6746 & 0.0246 & 0.6029 & -0.0471 & 0.0000 & -0.6500 & 0.6714 & 0.0214 & 0.6929 & 0.0429 & 0.6525 & 0.0025 & 0.7407 & 0.0907 & 0.0242 & -0.6258 & 0.0000 & -0.6500 & 0.6838 & 0.0338 & 0.0000 & -0.6500 & 0.6815 & 0.0315 & 0.1418 & -0.5082 & 0.3525 & -0.2975 & 0.2971 & -0.3529 & 0.6525 & 0.0025 & 0.0000 & -0.6500 & 0.6408 & -0.0092 & 0.6875 & 0.0375 & 0.0252 & -0.6248 & 0.7122 & 0.0622 & 0.6993 & 0.0493 & 0.6838 & 0.0338 & 0.6454 & -0.0046 & 0.7111 & 0.0611 \\ +Achievements/make\_stone\_sword & 0.7286 & 0.7302 & 0.0016 & 0.7353 & 0.0067 & 0.0000 & -0.7286 & 0.7214 & -0.0071 & 0.7500 & 0.0214 & 0.7660 & 0.0374 & 0.7259 & -0.0026 & 0.0545 & -0.6740 & 0.0000 & -0.7286 & 0.6765 & -0.0521 & 0.0000 & -0.7286 & 0.7185 & -0.0101 & 0.2482 & -0.4803 & 0.3669 & -0.3617 & 0.3696 & -0.3590 & 0.7447 & 0.0161 & 0.0000 & -0.7286 & 0.7465 & 0.0179 & 0.7656 & 0.0371 & 0.0672 & -0.6613 & 0.7986 & 0.0700 & 0.6923 & -0.0363 & 0.7279 & -0.0006 & 0.7021 & -0.0264 & 0.7630 & 0.0344 \\ +Achievements/make\_wood\_pickaxe & 0.9357 & 0.9444 & 0.0087 & 0.9265 & -0.0092 & 0.0000 & -0.9357 & 0.9286 & -0.0071 & 0.9571 & 0.0214 & 0.9362 & 0.0005 & 0.9556 & 0.0198 & 0.3636 & -0.5721 & 0.0000 & -0.9357 & 0.9412 & 0.0055 & 0.0000 & -0.9357 & 0.9481 & 0.0124 & 0.6950 & -0.2407 & 0.7554 & -0.1803 & 0.7971 & -0.1386 & 0.9220 & -0.0137 & 0.0000 & -0.9357 & 0.9155 & -0.0202 & 0.9297 & -0.0060 & 0.3445 & -0.5912 & 0.9353 & -0.0005 & 0.9301 & -0.0056 & 0.9412 & 0.0055 & 0.9220 & -0.0137 & 0.9481 & 0.0124 \\ +Achievements/make\_wood\_sword & 0.6714 & 0.6587 & -0.0127 & 0.6765 & 0.0050 & 0.0000 & -0.6714 & 0.6786 & 0.0071 & 0.7357 & 0.0643 & 0.6879 & 0.0165 & 0.6667 & -0.0048 & 0.3091 & -0.3623 & 0.0000 & -0.6714 & 0.7132 & 0.0418 & 0.0000 & -0.6714 & 0.6889 & 0.0175 & 0.6596 & -0.0119 & 0.7554 & 0.0840 & 0.7971 & 0.1257 & 0.6738 & 0.0023 & 0.0000 & -0.6714 & 0.7254 & 0.0539 & 0.6875 & 0.0161 & 0.1176 & -0.5538 & 0.6978 & 0.0264 & 0.6154 & -0.0560 & 0.7500 & 0.0786 & 0.6738 & 0.0023 & 0.7481 & 0.0767 \\ +Achievements/place\_furnace & 0.8000 & 0.7778 & -0.0222 & 0.8309 & 0.0309 & 0.0000 & -0.8000 & 0.8214 & 0.0214 & 0.8214 & 0.0214 & 0.8156 & 0.0156 & 0.8148 & 0.0148 & 0.0303 & -0.7697 & 0.0000 & -0.8000 & 0.7574 & -0.0426 & 0.0000 & -0.8000 & 0.8000 & -0.0000 & 0.1986 & -0.6014 & 0.4388 & -0.3612 & 0.4565 & -0.3435 & 0.7872 & -0.0128 & 0.0000 & -0.8000 & 0.7958 & -0.0042 & 0.7969 & -0.0031 & 0.2017 & -0.5983 & 0.8417 & 0.0417 & 0.7902 & -0.0098 & 0.8088 & 0.0088 & 0.8298 & 0.0298 & 0.8000 & -0.0000 \\ +Achievements/place\_plant & 0.5929 & 0.6508 & 0.0579 & 0.7059 & 0.1130 & 0.0000 & -0.5929 & 0.5786 & -0.0143 & 0.6286 & 0.0357 & 0.6950 & 0.1022 & 0.6519 & 0.0590 & 0.0242 & -0.5686 & 0.0000 & -0.5929 & 0.6324 & 0.0395 & 0.0000 & -0.5929 & 0.6593 & 0.0664 & 0.0426 & -0.5503 & 0.0791 & -0.5137 & 0.1014 & -0.4914 & 0.6879 & 0.0951 & 0.0000 & -0.5929 & 0.6901 & 0.0973 & 0.6719 & 0.0790 & 0.9412 & 0.3483 & 0.6906 & 0.0978 & 0.7063 & 0.1134 & 0.6250 & 0.0321 & 0.6667 & 0.0738 & 0.6519 & 0.0590 \\ +Achievements/place\_stone & 0.6429 & 0.6429 & 0.0000 & 0.7279 & 0.0851 & 0.0000 & -0.6429 & 0.7143 & 0.0714 & 0.7071 & 0.0643 & 0.6950 & 0.0522 & 0.7185 & 0.0757 & 0.0485 & -0.5944 & 0.0000 & -0.6429 & 0.6397 & -0.0032 & 0.0000 & -0.6429 & 0.6444 & 0.0016 & 0.2270 & -0.4159 & 0.2806 & -0.3623 & 0.2681 & -0.3747 & 0.6879 & 0.0451 & 0.0000 & -0.6429 & 0.7254 & 0.0825 & 0.7187 & 0.0759 & 0.1261 & -0.5168 & 0.6259 & -0.0170 & 0.6643 & 0.0215 & 0.6103 & -0.0326 & 0.6950 & 0.0522 & 0.7333 & 0.0905 \\ +Achievements/place\_table & 0.9786 & 0.9683 & -0.0103 & 0.9853 & 0.0067 & 0.0000 & -0.9786 & 0.9571 & -0.0214 & 0.9643 & -0.0143 & 0.9433 & -0.0353 & 0.9778 & -0.0008 & 0.4788 & -0.4998 & 0.0000 & -0.9786 & 0.9853 & 0.0067 & 0.0000 & -0.9786 & 0.9704 & -0.0082 & 0.8014 & -0.1772 & 0.8561 & -0.1225 & 0.8623 & -0.1163 & 0.9504 & -0.0282 & 0.0000 & -0.9786 & 0.9507 & -0.0279 & 0.9922 & 0.0136 & 0.4538 & -0.5248 & 0.9712 & -0.0073 & 0.9580 & -0.0205 & 0.9779 & -0.0006 & 0.9645 & -0.0140 & 1.0000 & 0.0214 \\ +Achievements/wake\_up & 0.0929 & 0.1190 & 0.0262 & 0.0735 & -0.0193 & 0.0000 & -0.0929 & 0.0786 & -0.0143 & 0.1000 & 0.0071 & 0.0709 & -0.0219 & 0.0963 & 0.0034 & 0.0061 & -0.0868 & 0.0000 & -0.0929 & 0.1324 & 0.0395 & 0.0000 & -0.0929 & 0.1481 & 0.0553 & 0.2199 & 0.1270 & 0.1583 & 0.0654 & 0.2391 & 0.1463 & 0.1064 & 0.0135 & 0.0000 & -0.0929 & 0.1197 & 0.0269 & 0.0547 & -0.0382 & 0.0000 & -0.0929 & 0.0863 & -0.0065 & 0.0699 & -0.0229 & 0.0809 & -0.0120 & 0.0496 & -0.0432 & 0.1185 & 0.0257 \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv new file mode 100644 index 0000000000000000000000000000000000000000..dad9791fe403844255f18bdb5b5b7fe38e1bf881 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.csv @@ -0,0 +1,26 @@ +Method,First_Collapse_Iter,Min_Score,Recovery_Score,Recovered +baseline_rl,never,9.7303,10.8216,N/A +kl_penalty,never,9.8106,10.988,N/A +ewc,never,9.5767,10.7862,N/A +llrd,never,10.0296,10.7272,N/A +lora,25,-0.9283,-0.9103,N +mixed_replay,never,10.0754,10.9151,N/A +trust_region_kl,never,9.8786,10.6206,N/A +t_curriculum,never,10.1262,10.9168,N/A +entropy_bonus,never,9.8287,10.919,N/A +gradient_surgery,never,9.8682,10.8241,N/A +advantage_clip,never,9.737,10.7793,N/A +normalized_adv,250,5.3672,5.0739,N +bc_wins,never,9.8543,10.8676,N/A +low_t,never,9.9204,10.6598,N/A +frozen_backbone,75,-0.9247,0.2878,N +head_only,75,-0.9247,0.2743,N +attention_only,75,-0.9247,0.5447,N +ffn_only,200,3.0321,1.8791,N +layer_ablation_top1,175,6.2978,4.0468,N +layer_ablation_top2,225,7.2379,6.3875,N +layer_ablation_top3,175,7.6156,6.0164,Y +reward_filtering,never,9.8508,11.0502,N/A +running_stats,never,10.0158,10.9616,N/A +action_diversity,never,9.9356,11.0773,N/A +reward_model,never,9.9416,10.8205,N/A diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex new file mode 100644 index 0000000000000000000000000000000000000000..e5efff8ed7de725f9451c5f3386ff45801fc9907 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/forgetting_analysis.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Catastrophic forgetting timeline.} +\label{tab:forgetting} +\begin{tabular}{lrrrr} +\toprule +\textbf{Method} & \textbf{First\_Collapse\_Iter} & \textbf{Min\_Score} & \textbf{Recovery\_Score} & \textbf{Recovered} \\ +\midrule +baseline\_rl & never & 9.7303 & 10.8216 & N/A \\ +kl\_penalty & never & 9.8106 & 10.9880 & N/A \\ +ewc & never & 9.5767 & 10.7862 & N/A \\ +llrd & never & 10.0296 & 10.7272 & N/A \\ +lora & 25 & -0.9283 & -0.9103 & N \\ +mixed\_replay & never & 10.0754 & 10.9151 & N/A \\ +trust\_region\_kl & never & 9.8786 & 10.6206 & N/A \\ +t\_curriculum & never & 10.1262 & 10.9168 & N/A \\ +entropy\_bonus & never & 9.8287 & 10.9190 & N/A \\ +gradient\_surgery & never & 9.8682 & 10.8241 & N/A \\ +advantage\_clip & never & 9.7370 & 10.7793 & N/A \\ +normalized\_adv & 250 & 5.3672 & 5.0739 & N \\ +bc\_wins & never & 9.8543 & 10.8676 & N/A \\ +low\_t & never & 9.9204 & 10.6598 & N/A \\ +frozen\_backbone & 75 & -0.9247 & 0.2878 & N \\ +head\_only & 75 & -0.9247 & 0.2743 & N \\ +attention\_only & 75 & -0.9247 & 0.5447 & N \\ +ffn\_only & 200 & 3.0321 & 1.8791 & N \\ +layer\_ablation\_top1 & 175 & 6.2978 & 4.0468 & N \\ +layer\_ablation\_top2 & 225 & 7.2379 & 6.3875 & N \\ +layer\_ablation\_top3 & 175 & 7.6156 & 6.0164 & Y \\ +reward\_filtering & never & 9.8508 & 11.0502 & N/A \\ +running\_stats & never & 10.0158 & 10.9616 & N/A \\ +action\_diversity & never & 9.9356 & 11.0773 & N/A \\ +reward\_model & never & 9.9416 & 10.8205 & N/A \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv new file mode 100644 index 0000000000000000000000000000000000000000..ad36c2fe857541bcd1eafa2f712b994ced69751a --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.csv @@ -0,0 +1,26 @@ +Method,Mean_Grad_Align,Final_Grad_Align,Trend,Mean_KL_Drift,Final_KL_Drift +trust_region_kl,0.0555,0.1171,up,0.059783,0.064425 +kl_penalty,0.0206,0.0969,up,0.14607,0.167528 +mixed_replay,-0.0176,0.0927,up,0.158821,0.173193 +entropy_bonus,0.0139,0.0891,up,0.16366,0.189308 +running_stats,0.023,0.0808,down,0.180406,0.21801 +reward_filtering,0.0208,0.0775,up,0.114046,0.151835 +lora,0.0827,0.077,down,594365.340527,1679591.0 +bc_wins,0.0092,0.0702,down,0.19145,0.219727 +llrd,0.0125,0.068,up,0.137441,0.188783 +advantage_clip,-0.0014,0.0667,up,0.189775,0.219086 +gradient_surgery,0.0153,0.065,up,0.159825,0.209463 +baseline_rl,0.0153,0.0645,up,0.159842,0.2096 +action_diversity,0.0153,0.0644,up,0.159893,0.209855 +t_curriculum,0.0152,0.0495,up,0.197279,0.225474 +reward_model,0.0169,0.0449,down,0.171694,0.222742 +normalized_adv,0.0111,0.035,up,33.724305,63.2407 +layer_ablation_top3,0.0474,0.0302,down,12.019569,16.208744 +low_t,0.0246,0.0278,down,0.13292,0.187926 +attention_only,0.0171,0.0192,up,110.250885,90.191254 +frozen_backbone,0.0152,0.0109,down,125.046066,115.5261 +head_only,0.0151,0.0109,down,125.37802,115.826111 +ffn_only,0.0466,0.0048,down,15.351683,23.891375 +layer_ablation_top1,0.0475,-0.0028,down,11.511555,13.324247 +layer_ablation_top2,0.0461,-0.004,down,11.360045,9.672604 +ewc,0.0251,-0.0246,down,0.153398,0.194954 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex new file mode 100644 index 0000000000000000000000000000000000000000..df7e4d325bf6336985d9ac4e235d24e039c23549 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/gradient_analysis.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Gradient alignment and representation drift analysis.} +\label{tab:gradient} +\begin{tabular}{lrrrrr} +\toprule +\textbf{Method} & \textbf{Mean\_Grad\_Align} & \textbf{Final\_Grad\_Align} & \textbf{Trend} & \textbf{Mean\_KL\_Drift} & \textbf{Final\_KL\_Drift} \\ +\midrule +trust\_region\_kl & 0.0555 & 0.1171 & up & 0.0598 & 0.0644 \\ +kl\_penalty & 0.0206 & 0.0969 & up & 0.1461 & 0.1675 \\ +mixed\_replay & -0.0176 & 0.0927 & up & 0.1588 & 0.1732 \\ +entropy\_bonus & 0.0139 & 0.0891 & up & 0.1637 & 0.1893 \\ +running\_stats & 0.0230 & 0.0808 & down & 0.1804 & 0.2180 \\ +reward\_filtering & 0.0208 & 0.0775 & up & 0.1140 & 0.1518 \\ +lora & 0.0827 & 0.0770 & down & 594365.3405 & 1679591.0000 \\ +bc\_wins & 0.0092 & 0.0702 & down & 0.1915 & 0.2197 \\ +llrd & 0.0125 & 0.0680 & up & 0.1374 & 0.1888 \\ +advantage\_clip & -0.0014 & 0.0667 & up & 0.1898 & 0.2191 \\ +gradient\_surgery & 0.0153 & 0.0650 & up & 0.1598 & 0.2095 \\ +baseline\_rl & 0.0153 & 0.0645 & up & 0.1598 & 0.2096 \\ +action\_diversity & 0.0153 & 0.0644 & up & 0.1599 & 0.2099 \\ +t\_curriculum & 0.0152 & 0.0495 & up & 0.1973 & 0.2255 \\ +reward\_model & 0.0169 & 0.0449 & down & 0.1717 & 0.2227 \\ +normalized\_adv & 0.0111 & 0.0350 & up & 33.7243 & 63.2407 \\ +layer\_ablation\_top3 & 0.0474 & 0.0302 & down & 12.0196 & 16.2087 \\ +low\_t & 0.0246 & 0.0278 & down & 0.1329 & 0.1879 \\ +attention\_only & 0.0171 & 0.0192 & up & 110.2509 & 90.1913 \\ +frozen\_backbone & 0.0152 & 0.0109 & down & 125.0461 & 115.5261 \\ +head\_only & 0.0151 & 0.0109 & down & 125.3780 & 115.8261 \\ +ffn\_only & 0.0466 & 0.0048 & down & 15.3517 & 23.8914 \\ +layer\_ablation\_top1 & 0.0475 & -0.0028 & down & 11.5116 & 13.3242 \\ +layer\_ablation\_top2 & 0.0461 & -0.0040 & down & 11.3600 & 9.6726 \\ +ewc & 0.0251 & -0.0246 & down & 0.1534 & 0.1950 \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..8cdde77ea4294628a48d3d8d73054c32d6e33782 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.csv @@ -0,0 +1,6 @@ +Group,N,Mean,Best,Worst,StdDev +Baseline,1,10.8216,10.8216,10.8216,0.0 +A,6,8.8545,10.988,-0.9103,4.3686 +B,7,10.0058,10.919,5.0739,2.0151 +C,7,2.7767,6.3875,0.2743,2.4897 +D,4,10.9774,11.0773,10.8205,0.1002 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex new file mode 100644 index 0000000000000000000000000000000000000000..0f34449eb7426fb30c3a1dc4dad7a40c6e004903 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/group_summary.tex @@ -0,0 +1,16 @@ +\begin{table}[htbp] +\centering +\caption{Group summary statistics.} +\label{tab:group_summary} +\begin{tabular}{lrrrrr} +\toprule +\textbf{Group} & \textbf{N} & \textbf{Mean} & \textbf{Best} & \textbf{Worst} & \textbf{StdDev} \\ +\midrule +Baseline & 1 & 10.8216 & 10.8216 & 10.8216 & 0.0000 \\ +A & 6 & 8.8545 & 10.9880 & -0.9103 & 4.3686 \\ +B & 7 & 10.0058 & 10.9190 & 5.0739 & 2.0151 \\ +C & 7 & 2.7767 & 6.3875 & 0.2743 & 2.4897 \\ +D & 4 & 10.9774 & 11.0773 & 10.8205 & 0.1002 \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv new file mode 100644 index 0000000000000000000000000000000000000000..da483f3622b7d1eeaac2269a8a0cda97175c0403 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.csv @@ -0,0 +1,26 @@ +Ablation,Group,Hypothesis,Result,Conclusion +baseline_rl,Baseline,Diagnoses whether the RL signal alone causes collapse,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +kl_penalty,A,If this helps: catastrophic forgetting is the primary cause; soft regularisation...,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +ewc,A,If EWC helps: forgetting pretrained representations is the proximate cause,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +llrd,A,If LLRD helps: deep gradient flow into early layers corrupts representations,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +lora,A,If LoRA works: too many unconstrained degrees of freedom cause collapse,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +mixed_replay,A,If mixed replay helps: online data distribution alone is too corrupted,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +trust_region_kl,A,If hard constraint helps: soft KL is insufficient — a hard boundary is needed,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +t_curriculum,B,If curriculum helps: ordering of learning signals matters,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +entropy_bonus,B,If entropy bonus helps: collapse is mode-collapse; not a gradient problem,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +gradient_surgery,B,If PCGrad helps: gradients are conflicting and resolvable by projection,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +advantage_clip,B,If clipping helps: large advantage magnitudes destabilise training,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +normalized_adv,B,If std normalisation helps: simple mean normalisation is too loose,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +bc_wins,B,If BC on wins helps: the return weighting is the specific cause,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +low_t,B,If low-t helps: high-t (coarse-structure) gradients are biased,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +frozen_backbone,C,If frozen backbone helps: deep gradient flow into backbone causes collapse,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +head_only,C,If head-only works: backbone representations are fine; only decision boundary ne...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +attention_only,C,"If attention-only works: model needs routing updates, not feature updates",COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +ffn_only,C,If FFN-only works: stored knowledge (FFN as memory) needs updating; not attentio...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +layer_ablation_top1,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +layer_ablation_top2,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +layer_ablation_top3,C,Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept...,COLLAPSE,Hypothesis REFUTED — intervention did not prevent collapse +reward_filtering,D,If filtering helps: noisy/low-return data poisons gradients,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +running_stats,D,If running stats help: batch normalisation is too noisy for small batches,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +action_diversity,D,If diversity filtering helps: degenerate PPO plans corrupt training,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps +reward_model,D,If reward model helps: raw returns are too sparse; learned model smooths signal,IMPROVEMENT,Hypothesis SUPPORTED — this intervention helps diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex new file mode 100644 index 0000000000000000000000000000000000000000..006475690d5e58594869b4f5d012224683034199 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/hypothesis_verdict.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Hypothesis verdict per ablation.} +\label{tab:hypothesis} +\begin{tabular}{lrrrr} +\toprule +\textbf{Ablation} & \textbf{Group} & \textbf{Hypothesis} & \textbf{Result} & \textbf{Conclusion} \\ +\midrule +baseline\_rl & Baseline & Diagnoses whether the RL signal alone causes collapse & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +kl\_penalty & A & If this helps: catastrophic forgetting is the primary cause; soft regularisation... & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +ewc & A & If EWC helps: forgetting pretrained representations is the proximate cause & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +llrd & A & If LLRD helps: deep gradient flow into early layers corrupts representations & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +lora & A & If LoRA works: too many unconstrained degrees of freedom cause collapse & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +mixed\_replay & A & If mixed replay helps: online data distribution alone is too corrupted & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +trust\_region\_kl & A & If hard constraint helps: soft KL is insufficient — a hard boundary is needed & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +t\_curriculum & B & If curriculum helps: ordering of learning signals matters & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +entropy\_bonus & B & If entropy bonus helps: collapse is mode-collapse; not a gradient problem & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +gradient\_surgery & B & If PCGrad helps: gradients are conflicting and resolvable by projection & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +advantage\_clip & B & If clipping helps: large advantage magnitudes destabilise training & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +normalized\_adv & B & If std normalisation helps: simple mean normalisation is too loose & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +bc\_wins & B & If BC on wins helps: the return weighting is the specific cause & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +low\_t & B & If low-t helps: high-t (coarse-structure) gradients are biased & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +frozen\_backbone & C & If frozen backbone helps: deep gradient flow into backbone causes collapse & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +head\_only & C & If head-only works: backbone representations are fine; only decision boundary ne... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +attention\_only & C & If attention-only works: model needs routing updates, not feature updates & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +ffn\_only & C & If FFN-only works: stored knowledge (FFN as memory) needs updating; not attentio... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +layer\_ablation\_top1 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +layer\_ablation\_top2 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +layer\_ablation\_top3 & C & Minimal unfrozen depth needed; collapse depth correlates with gradient flow dept... & COLLAPSE & Hypothesis REFUTED — intervention did not prevent collapse \\ +reward\_filtering & D & If filtering helps: noisy/low-return data poisons gradients & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +running\_stats & D & If running stats help: batch normalisation is too noisy for small batches & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +action\_diversity & D & If diversity filtering helps: degenerate PPO plans corrupt training & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +reward\_model & D & If reward model helps: raw returns are too sparse; learned model smooths signal & IMPROVEMENT & Hypothesis SUPPORTED — this intervention helps \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv new file mode 100644 index 0000000000000000000000000000000000000000..721939fa2b87eef390cc81f7499fdb61204b85e3 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.csv @@ -0,0 +1,26 @@ +Method,Group,Final_Score,Delta_vs_Pretrained,Delta_vs_Baseline_RL,Verdict +action_diversity,D,11.0773,0.5487,0.2557,IMPROVEMENT +reward_filtering,D,11.0502,0.5216,0.2286,IMPROVEMENT +kl_penalty,A,10.988,0.4594,0.1664,IMPROVEMENT +running_stats,D,10.9616,0.433,0.14,IMPROVEMENT +entropy_bonus,B,10.919,0.3904,0.0974,IMPROVEMENT +t_curriculum,B,10.9168,0.3882,0.0952,IMPROVEMENT +mixed_replay,A,10.9151,0.3865,0.0935,IMPROVEMENT +bc_wins,B,10.8676,0.339,0.046,IMPROVEMENT +gradient_surgery,B,10.8241,0.2955,0.0025,IMPROVEMENT +baseline_rl,Baseline,10.8216,0.293,0.0,IMPROVEMENT +reward_model,D,10.8205,0.2919,-0.0011,IMPROVEMENT +ewc,A,10.7862,0.2577,-0.0353,IMPROVEMENT +advantage_clip,B,10.7793,0.2507,-0.0423,IMPROVEMENT +llrd,A,10.7272,0.1986,-0.0944,IMPROVEMENT +low_t,B,10.6598,0.1312,-0.1618,IMPROVEMENT +trust_region_kl,A,10.6206,0.092,-0.201,IMPROVEMENT +layer_ablation_top2,C,6.3875,-4.1411,-4.4341,COLLAPSE +layer_ablation_top3,C,6.0164,-4.5122,-4.8052,COLLAPSE +normalized_adv,B,5.0739,-5.4547,-5.7477,COLLAPSE +layer_ablation_top1,C,4.0468,-6.4818,-6.7748,COLLAPSE +ffn_only,C,1.8791,-8.6494,-8.9424,COLLAPSE +attention_only,C,0.5447,-9.9838,-10.2768,COLLAPSE +frozen_backbone,C,0.2878,-10.2408,-10.5338,COLLAPSE +head_only,C,0.2743,-10.2543,-10.5473,COLLAPSE +lora,A,-0.9103,-11.4389,-11.7319,COLLAPSE diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex new file mode 100644 index 0000000000000000000000000000000000000000..d80b09daed497b209dbb19e541a009300337b4e7 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/main_results.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Main ablation results.} +\label{tab:main_results} +\begin{tabular}{lrrrrr} +\toprule +\textbf{Method} & \textbf{Group} & \textbf{Final\_Score} & \textbf{Delta\_vs\_Pretrained} & \textbf{Delta\_vs\_Baseline\_RL} & \textbf{Verdict} \\ +\midrule +action\_diversity & D & 11.0773 & 0.5487 & 0.2557 & IMPROVEMENT \\ +reward\_filtering & D & 11.0502 & 0.5216 & 0.2286 & IMPROVEMENT \\ +kl\_penalty & A & 10.9880 & 0.4594 & 0.1664 & IMPROVEMENT \\ +running\_stats & D & 10.9616 & 0.4330 & 0.1400 & IMPROVEMENT \\ +entropy\_bonus & B & 10.9190 & 0.3904 & 0.0974 & IMPROVEMENT \\ +t\_curriculum & B & 10.9168 & 0.3882 & 0.0952 & IMPROVEMENT \\ +mixed\_replay & A & 10.9151 & 0.3865 & 0.0935 & IMPROVEMENT \\ +bc\_wins & B & 10.8676 & 0.3390 & 0.0460 & IMPROVEMENT \\ +gradient\_surgery & B & 10.8241 & 0.2955 & 0.0025 & IMPROVEMENT \\ +baseline\_rl & Baseline & 10.8216 & 0.2930 & 0.0000 & IMPROVEMENT \\ +reward\_model & D & 10.8205 & 0.2919 & -0.0011 & IMPROVEMENT \\ +ewc & A & 10.7862 & 0.2577 & -0.0353 & IMPROVEMENT \\ +advantage\_clip & B & 10.7793 & 0.2507 & -0.0423 & IMPROVEMENT \\ +llrd & A & 10.7272 & 0.1986 & -0.0944 & IMPROVEMENT \\ +low\_t & B & 10.6598 & 0.1312 & -0.1618 & IMPROVEMENT \\ +trust\_region\_kl & A & 10.6206 & 0.0920 & -0.2010 & IMPROVEMENT \\ +layer\_ablation\_top2 & C & 6.3875 & -4.1411 & -4.4341 & COLLAPSE \\ +layer\_ablation\_top3 & C & 6.0164 & -4.5122 & -4.8052 & COLLAPSE \\ +normalized\_adv & B & 5.0739 & -5.4547 & -5.7477 & COLLAPSE \\ +layer\_ablation\_top1 & C & 4.0468 & -6.4818 & -6.7748 & COLLAPSE \\ +ffn\_only & C & 1.8791 & -8.6494 & -8.9424 & COLLAPSE \\ +attention\_only & C & 0.5447 & -9.9838 & -10.2768 & COLLAPSE \\ +frozen\_backbone & C & 0.2878 & -10.2408 & -10.5338 & COLLAPSE \\ +head\_only & C & 0.2743 & -10.2543 & -10.5473 & COLLAPSE \\ +lora & A & -0.9103 & -11.4389 & -11.7319 & COLLAPSE \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv new file mode 100644 index 0000000000000000000000000000000000000000..2e2027d5d3ce9b158c5a0f050b710caf5b6a662f --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.csv @@ -0,0 +1,27 @@ +Method,Achievements/collect_coal,Achievements/collect_diamond,Achievements/collect_drink,Achievements/collect_iron,Achievements/collect_sapling,Achievements/collect_stone,Achievements/collect_wood,Achievements/defeat_skeleton,Achievements/defeat_zombie,Achievements/eat_cow,Achievements/eat_plant,Achievements/make_iron_pickaxe,Achievements/make_iron_sword,Achievements/make_stone_pickaxe,Achievements/make_stone_sword,Achievements/make_wood_pickaxe,Achievements/make_wood_sword,Achievements/place_furnace,Achievements/place_plant,Achievements/place_stone,Achievements/place_table,Achievements/wake_up +pretrained,0.4857,0.0,0.3857,0.3357,0.9071,0.9,1.0,0.1571,0.6214,0.3357,0.0,0.0214,0.1214,0.65,0.7286,0.9357,0.6714,0.8,0.5929,0.6429,0.9786,0.0929 +action_diversity,0.5079,0.0,0.4286,0.3095,0.9127,0.8651,0.9921,0.254,0.6508,0.4603,0.0,0.0238,0.1349,0.6746,0.7302,0.9444,0.6587,0.7778,0.6508,0.6429,0.9683,0.119 +advantage_clip,0.4338,0.0,0.4044,0.3309,0.9265,0.8824,0.9926,0.2279,0.5956,0.3603,0.0,0.0221,0.125,0.6029,0.7353,0.9265,0.6765,0.8309,0.7059,0.7279,0.9853,0.0735 +attention_only,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +baseline_rl,0.4429,0.0,0.3857,0.3786,0.8643,0.8929,0.9929,0.2143,0.7214,0.3429,0.0,0.0214,0.1571,0.6714,0.7214,0.9286,0.6786,0.8214,0.5786,0.7143,0.9571,0.0786 +bc_wins,0.4643,0.0,0.3714,0.3571,0.8571,0.8929,0.9929,0.2071,0.5714,0.3786,0.0,0.0,0.1571,0.6929,0.75,0.9571,0.7357,0.8214,0.6286,0.7071,0.9643,0.1 +entropy_bonus,0.4255,0.0,0.3333,0.2979,0.9362,0.8936,0.9858,0.1844,0.695,0.383,0.0,0.0284,0.1206,0.6525,0.766,0.9362,0.6879,0.8156,0.695,0.695,0.9433,0.0709 +ewc,0.5407,0.0074,0.437,0.3556,0.9037,0.9037,1.0,0.2,0.6815,0.3333,0.0,0.037,0.1259,0.7407,0.7259,0.9556,0.6667,0.8148,0.6519,0.7185,0.9778,0.0963 +ffn_only,0.0303,0.0,0.2848,0.0,0.5576,0.1576,0.9212,0.0242,0.2606,0.1152,0.0,0.0,0.0,0.0242,0.0545,0.3636,0.3091,0.0303,0.0242,0.0485,0.4788,0.0061 +frozen_backbone,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +gradient_surgery,0.4118,0.0,0.4191,0.3088,0.9191,0.8603,1.0,0.2132,0.6912,0.375,0.0,0.0074,0.0956,0.6838,0.6765,0.9412,0.7132,0.7574,0.6324,0.6397,0.9853,0.1324 +head_only,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +kl_penalty,0.4815,0.0,0.3259,0.3704,0.9333,0.8815,1.0,0.2296,0.6519,0.3926,0.0,0.0222,0.1778,0.6815,0.7185,0.9481,0.6889,0.8,0.6593,0.6444,0.9704,0.1481 +layer_ablation_top1,0.1702,0.0,0.3333,0.0213,0.7801,0.5106,1.0,0.078,0.5887,0.2057,0.0,0.0,0.0,0.1418,0.2482,0.695,0.6596,0.1986,0.0426,0.227,0.8014,0.2199 +layer_ablation_top2,0.2806,0.0,0.4245,0.1151,0.8993,0.5899,1.0,0.0791,0.5971,0.2446,0.0,0.0144,0.0288,0.3525,0.3669,0.7554,0.7554,0.4388,0.0791,0.2806,0.8561,0.1583 +layer_ablation_top3,0.1812,0.0,0.4203,0.0652,0.8623,0.6522,0.9928,0.1087,0.6884,0.2971,0.0,0.0072,0.0,0.2971,0.3696,0.7971,0.7971,0.4565,0.1014,0.2681,0.8623,0.2391 +llrd,0.4468,0.0,0.4043,0.2979,0.8936,0.8652,0.9929,0.0993,0.6879,0.2908,0.0,0.0213,0.0922,0.6525,0.7447,0.922,0.6738,0.7872,0.6879,0.6879,0.9504,0.1064 +lora,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 +low_t,0.4085,0.0,0.3944,0.2676,0.9296,0.8873,0.9859,0.1972,0.6761,0.3732,0.0,0.0141,0.0845,0.6408,0.7465,0.9155,0.7254,0.7958,0.6901,0.7254,0.9507,0.1197 +mixed_replay,0.5391,0.0,0.4141,0.4062,0.9141,0.9141,1.0,0.2266,0.6484,0.4219,0.0,0.0078,0.1875,0.6875,0.7656,0.9297,0.6875,0.7969,0.6719,0.7187,0.9922,0.0547 +normalized_adv,0.084,0.0,0.1849,0.0084,1.0,0.3025,0.958,0.0924,0.6134,0.2437,0.0,0.0,0.0,0.0252,0.0672,0.3445,0.1176,0.2017,0.9412,0.1261,0.4538,0.0 +reward_filtering,0.4317,0.0,0.3453,0.4029,0.9065,0.9065,0.9928,0.2446,0.6906,0.3022,0.0,0.0504,0.1367,0.7122,0.7986,0.9353,0.6978,0.8417,0.6906,0.6259,0.9712,0.0863 +reward_model,0.4615,0.0,0.3427,0.3427,0.9021,0.8951,1.0,0.2028,0.6643,0.3427,0.0,0.014,0.1538,0.6993,0.6923,0.9301,0.6154,0.7902,0.7063,0.6643,0.958,0.0699 +running_stats,0.5074,0.0,0.3676,0.4118,0.8897,0.9044,0.9926,0.1838,0.6397,0.375,0.0,0.0368,0.1471,0.6838,0.7279,0.9412,0.75,0.8088,0.625,0.6103,0.9779,0.0809 +t_curriculum,0.4397,0.0,0.3972,0.3546,0.9149,0.8865,0.9929,0.2553,0.6454,0.305,0.0,0.0213,0.1277,0.6454,0.7021,0.922,0.6738,0.8298,0.6667,0.695,0.9645,0.0496 +trust_region_kl,0.4074,0.0,0.4,0.2815,0.8889,0.9333,1.0,0.2296,0.6815,0.363,0.0,0.0,0.0815,0.7111,0.763,0.9481,0.7481,0.8,0.6519,0.7333,1.0,0.1185 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex new file mode 100644 index 0000000000000000000000000000000000000000..602c028fd2909525d26c154b7a1327dd780d38e2 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/per_env.tex @@ -0,0 +1,37 @@ +\begin{table}[htbp] +\centering +\caption{Per-environment (per-achievement) win rates at final eval.} +\label{tab:per_env} +\begin{tabular}{lrrrrrrrrrrrrrrrrrrrrrr} +\toprule +\textbf{Method} & \textbf{Achievements/collect\_coal} & \textbf{Achievements/collect\_diamond} & \textbf{Achievements/collect\_drink} & \textbf{Achievements/collect\_iron} & \textbf{Achievements/collect\_sapling} & \textbf{Achievements/collect\_stone} & \textbf{Achievements/collect\_wood} & \textbf{Achievements/defeat\_skeleton} & \textbf{Achievements/defeat\_zombie} & \textbf{Achievements/eat\_cow} & \textbf{Achievements/eat\_plant} & \textbf{Achievements/make\_iron\_pickaxe} & \textbf{Achievements/make\_iron\_sword} & \textbf{Achievements/make\_stone\_pickaxe} & \textbf{Achievements/make\_stone\_sword} & \textbf{Achievements/make\_wood\_pickaxe} & \textbf{Achievements/make\_wood\_sword} & \textbf{Achievements/place\_furnace} & \textbf{Achievements/place\_plant} & \textbf{Achievements/place\_stone} & \textbf{Achievements/place\_table} & \textbf{Achievements/wake\_up} \\ +\midrule +pretrained & 0.4857 & 0.0000 & 0.3857 & 0.3357 & 0.9071 & 0.9000 & 1.0000 & 0.1571 & 0.6214 & 0.3357 & 0.0000 & 0.0214 & 0.1214 & 0.6500 & 0.7286 & 0.9357 & 0.6714 & 0.8000 & 0.5929 & 0.6429 & 0.9786 & 0.0929 \\ +action\_diversity & 0.5079 & 0.0000 & 0.4286 & 0.3095 & 0.9127 & 0.8651 & 0.9921 & 0.2540 & 0.6508 & 0.4603 & 0.0000 & 0.0238 & 0.1349 & 0.6746 & 0.7302 & 0.9444 & 0.6587 & 0.7778 & 0.6508 & 0.6429 & 0.9683 & 0.1190 \\ +advantage\_clip & 0.4338 & 0.0000 & 0.4044 & 0.3309 & 0.9265 & 0.8824 & 0.9926 & 0.2279 & 0.5956 & 0.3603 & 0.0000 & 0.0221 & 0.1250 & 0.6029 & 0.7353 & 0.9265 & 0.6765 & 0.8309 & 0.7059 & 0.7279 & 0.9853 & 0.0735 \\ +attention\_only & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +baseline\_rl & 0.4429 & 0.0000 & 0.3857 & 0.3786 & 0.8643 & 0.8929 & 0.9929 & 0.2143 & 0.7214 & 0.3429 & 0.0000 & 0.0214 & 0.1571 & 0.6714 & 0.7214 & 0.9286 & 0.6786 & 0.8214 & 0.5786 & 0.7143 & 0.9571 & 0.0786 \\ +bc\_wins & 0.4643 & 0.0000 & 0.3714 & 0.3571 & 0.8571 & 0.8929 & 0.9929 & 0.2071 & 0.5714 & 0.3786 & 0.0000 & 0.0000 & 0.1571 & 0.6929 & 0.7500 & 0.9571 & 0.7357 & 0.8214 & 0.6286 & 0.7071 & 0.9643 & 0.1000 \\ +entropy\_bonus & 0.4255 & 0.0000 & 0.3333 & 0.2979 & 0.9362 & 0.8936 & 0.9858 & 0.1844 & 0.6950 & 0.3830 & 0.0000 & 0.0284 & 0.1206 & 0.6525 & 0.7660 & 0.9362 & 0.6879 & 0.8156 & 0.6950 & 0.6950 & 0.9433 & 0.0709 \\ +ewc & 0.5407 & 0.0074 & 0.4370 & 0.3556 & 0.9037 & 0.9037 & 1.0000 & 0.2000 & 0.6815 & 0.3333 & 0.0000 & 0.0370 & 0.1259 & 0.7407 & 0.7259 & 0.9556 & 0.6667 & 0.8148 & 0.6519 & 0.7185 & 0.9778 & 0.0963 \\ +ffn\_only & 0.0303 & 0.0000 & 0.2848 & 0.0000 & 0.5576 & 0.1576 & 0.9212 & 0.0242 & 0.2606 & 0.1152 & 0.0000 & 0.0000 & 0.0000 & 0.0242 & 0.0545 & 0.3636 & 0.3091 & 0.0303 & 0.0242 & 0.0485 & 0.4788 & 0.0061 \\ +frozen\_backbone & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +gradient\_surgery & 0.4118 & 0.0000 & 0.4191 & 0.3088 & 0.9191 & 0.8603 & 1.0000 & 0.2132 & 0.6912 & 0.3750 & 0.0000 & 0.0074 & 0.0956 & 0.6838 & 0.6765 & 0.9412 & 0.7132 & 0.7574 & 0.6324 & 0.6397 & 0.9853 & 0.1324 \\ +head\_only & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +kl\_penalty & 0.4815 & 0.0000 & 0.3259 & 0.3704 & 0.9333 & 0.8815 & 1.0000 & 0.2296 & 0.6519 & 0.3926 & 0.0000 & 0.0222 & 0.1778 & 0.6815 & 0.7185 & 0.9481 & 0.6889 & 0.8000 & 0.6593 & 0.6444 & 0.9704 & 0.1481 \\ +layer\_ablation\_top1 & 0.1702 & 0.0000 & 0.3333 & 0.0213 & 0.7801 & 0.5106 & 1.0000 & 0.0780 & 0.5887 & 0.2057 & 0.0000 & 0.0000 & 0.0000 & 0.1418 & 0.2482 & 0.6950 & 0.6596 & 0.1986 & 0.0426 & 0.2270 & 0.8014 & 0.2199 \\ +layer\_ablation\_top2 & 0.2806 & 0.0000 & 0.4245 & 0.1151 & 0.8993 & 0.5899 & 1.0000 & 0.0791 & 0.5971 & 0.2446 & 0.0000 & 0.0144 & 0.0288 & 0.3525 & 0.3669 & 0.7554 & 0.7554 & 0.4388 & 0.0791 & 0.2806 & 0.8561 & 0.1583 \\ +layer\_ablation\_top3 & 0.1812 & 0.0000 & 0.4203 & 0.0652 & 0.8623 & 0.6522 & 0.9928 & 0.1087 & 0.6884 & 0.2971 & 0.0000 & 0.0072 & 0.0000 & 0.2971 & 0.3696 & 0.7971 & 0.7971 & 0.4565 & 0.1014 & 0.2681 & 0.8623 & 0.2391 \\ +llrd & 0.4468 & 0.0000 & 0.4043 & 0.2979 & 0.8936 & 0.8652 & 0.9929 & 0.0993 & 0.6879 & 0.2908 & 0.0000 & 0.0213 & 0.0922 & 0.6525 & 0.7447 & 0.9220 & 0.6738 & 0.7872 & 0.6879 & 0.6879 & 0.9504 & 0.1064 \\ +lora & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 & 0.0000 \\ +low\_t & 0.4085 & 0.0000 & 0.3944 & 0.2676 & 0.9296 & 0.8873 & 0.9859 & 0.1972 & 0.6761 & 0.3732 & 0.0000 & 0.0141 & 0.0845 & 0.6408 & 0.7465 & 0.9155 & 0.7254 & 0.7958 & 0.6901 & 0.7254 & 0.9507 & 0.1197 \\ +mixed\_replay & 0.5391 & 0.0000 & 0.4141 & 0.4062 & 0.9141 & 0.9141 & 1.0000 & 0.2266 & 0.6484 & 0.4219 & 0.0000 & 0.0078 & 0.1875 & 0.6875 & 0.7656 & 0.9297 & 0.6875 & 0.7969 & 0.6719 & 0.7187 & 0.9922 & 0.0547 \\ +normalized\_adv & 0.0840 & 0.0000 & 0.1849 & 0.0084 & 1.0000 & 0.3025 & 0.9580 & 0.0924 & 0.6134 & 0.2437 & 0.0000 & 0.0000 & 0.0000 & 0.0252 & 0.0672 & 0.3445 & 0.1176 & 0.2017 & 0.9412 & 0.1261 & 0.4538 & 0.0000 \\ +reward\_filtering & 0.4317 & 0.0000 & 0.3453 & 0.4029 & 0.9065 & 0.9065 & 0.9928 & 0.2446 & 0.6906 & 0.3022 & 0.0000 & 0.0504 & 0.1367 & 0.7122 & 0.7986 & 0.9353 & 0.6978 & 0.8417 & 0.6906 & 0.6259 & 0.9712 & 0.0863 \\ +reward\_model & 0.4615 & 0.0000 & 0.3427 & 0.3427 & 0.9021 & 0.8951 & 1.0000 & 0.2028 & 0.6643 & 0.3427 & 0.0000 & 0.0140 & 0.1538 & 0.6993 & 0.6923 & 0.9301 & 0.6154 & 0.7902 & 0.7063 & 0.6643 & 0.9580 & 0.0699 \\ +running\_stats & 0.5074 & 0.0000 & 0.3676 & 0.4118 & 0.8897 & 0.9044 & 0.9926 & 0.1838 & 0.6397 & 0.3750 & 0.0000 & 0.0368 & 0.1471 & 0.6838 & 0.7279 & 0.9412 & 0.7500 & 0.8088 & 0.6250 & 0.6103 & 0.9779 & 0.0809 \\ +t\_curriculum & 0.4397 & 0.0000 & 0.3972 & 0.3546 & 0.9149 & 0.8865 & 0.9929 & 0.2553 & 0.6454 & 0.3050 & 0.0000 & 0.0213 & 0.1277 & 0.6454 & 0.7021 & 0.9220 & 0.6738 & 0.8298 & 0.6667 & 0.6950 & 0.9645 & 0.0496 \\ +trust\_region\_kl & 0.4074 & 0.0000 & 0.4000 & 0.2815 & 0.8889 & 0.9333 & 1.0000 & 0.2296 & 0.6815 & 0.3630 & 0.0000 & 0.0000 & 0.0815 & 0.7111 & 0.7630 & 0.9481 & 0.7481 & 0.8000 & 0.6519 & 0.7333 & 1.0000 & 0.1185 \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv new file mode 100644 index 0000000000000000000000000000000000000000..728d521c522d389f76e48aac0b5d306c9d1c2998 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.csv @@ -0,0 +1,26 @@ +Method,KL_mean,KL_low_t,KL_mid_t,KL_high_t +trust_region_kl,0.064425,0.079172,0.064176,0.046738 +reward_filtering,0.151835,0.164997,0.151723,0.143443 +kl_penalty,0.167528,0.182938,0.169007,0.164755 +mixed_replay,0.173193,0.17947,0.169166,0.165786 +low_t,0.187926,0.196443,0.180127,0.193411 +llrd,0.188783,0.194401,0.183651,0.181313 +entropy_bonus,0.189308,0.198052,0.186509,0.196232 +ewc,0.194954,0.214022,0.190065,0.183393 +gradient_surgery,0.209463,0.216535,0.204606,0.200818 +baseline_rl,0.2096,0.216757,0.204743,0.200695 +action_diversity,0.209855,0.217063,0.205008,0.201086 +running_stats,0.21801,0.223517,0.214218,0.215939 +advantage_clip,0.219086,0.226308,0.216805,0.217097 +bc_wins,0.219727,0.227315,0.217366,0.216954 +reward_model,0.222742,0.231557,0.217179,0.216113 +t_curriculum,0.225474,0.223155,0.21883,0.244676 +layer_ablation_top2,9.672604,8.604239,9.558239,10.714105 +layer_ablation_top1,13.324247,13.311382,13.21699,13.340172 +layer_ablation_top3,16.208744,16.479584,15.960442,16.609873 +ffn_only,23.891375,22.415035,23.823528,25.39563 +normalized_adv,63.2407,63.106873,64.224091,61.898434 +attention_only,90.191254,90.413528,90.065102,90.144936 +frozen_backbone,115.5261,115.571564,115.430725,115.766418 +head_only,115.826111,115.743759,115.716293,116.180679 +lora,1679591.0,1673381.25,1672479.625,1676199.75 diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex new file mode 100644 index 0000000000000000000000000000000000000000..2415f42618eb3625e9c19f8a925806e9d3961984 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/repr_drift.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Representation drift (KL divergence) at final iteration.} +\label{tab:repr_drift} +\begin{tabular}{lrrrr} +\toprule +\textbf{Method} & \textbf{KL\_mean} & \textbf{KL\_low\_t} & \textbf{KL\_mid\_t} & \textbf{KL\_high\_t} \\ +\midrule +trust\_region\_kl & 0.0644 & 0.0792 & 0.0642 & 0.0467 \\ +reward\_filtering & 0.1518 & 0.1650 & 0.1517 & 0.1434 \\ +kl\_penalty & 0.1675 & 0.1829 & 0.1690 & 0.1648 \\ +mixed\_replay & 0.1732 & 0.1795 & 0.1692 & 0.1658 \\ +low\_t & 0.1879 & 0.1964 & 0.1801 & 0.1934 \\ +llrd & 0.1888 & 0.1944 & 0.1837 & 0.1813 \\ +entropy\_bonus & 0.1893 & 0.1981 & 0.1865 & 0.1962 \\ +ewc & 0.1950 & 0.2140 & 0.1901 & 0.1834 \\ +gradient\_surgery & 0.2095 & 0.2165 & 0.2046 & 0.2008 \\ +baseline\_rl & 0.2096 & 0.2168 & 0.2047 & 0.2007 \\ +action\_diversity & 0.2099 & 0.2171 & 0.2050 & 0.2011 \\ +running\_stats & 0.2180 & 0.2235 & 0.2142 & 0.2159 \\ +advantage\_clip & 0.2191 & 0.2263 & 0.2168 & 0.2171 \\ +bc\_wins & 0.2197 & 0.2273 & 0.2174 & 0.2170 \\ +reward\_model & 0.2227 & 0.2316 & 0.2172 & 0.2161 \\ +t\_curriculum & 0.2255 & 0.2232 & 0.2188 & 0.2447 \\ +layer\_ablation\_top2 & 9.6726 & 8.6042 & 9.5582 & 10.7141 \\ +layer\_ablation\_top1 & 13.3242 & 13.3114 & 13.2170 & 13.3402 \\ +layer\_ablation\_top3 & 16.2087 & 16.4796 & 15.9604 & 16.6099 \\ +ffn\_only & 23.8914 & 22.4150 & 23.8235 & 25.3956 \\ +normalized\_adv & 63.2407 & 63.1069 & 64.2241 & 61.8984 \\ +attention\_only & 90.1913 & 90.4135 & 90.0651 & 90.1449 \\ +frozen\_backbone & 115.5261 & 115.5716 & 115.4307 & 115.7664 \\ +head\_only & 115.8261 & 115.7438 & 115.7163 & 116.1807 \\ +lora & 1679591.0000 & 1673381.2500 & 1672479.6250 & 1676199.7500 \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv new file mode 100644 index 0000000000000000000000000000000000000000..6d6dcf3dc7f5e371e9ef2d2ca54105b2ac60d7b6 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.csv @@ -0,0 +1,26 @@ +Method,HighLow_Ratio,LowHigh_Cos_Sim,Dominant_Regime +baseline_rl,0.08,0.0888,low-t +kl_penalty,0.092,0.0543,low-t +ewc,0.08,0.0666,low-t +llrd,0.076,0.0818,low-t +lora,0.328,0.9819,low-t +mixed_replay,0.081,-0.0058,low-t +trust_region_kl,0.099,0.06,low-t +t_curriculum,0.098,0.0614,low-t +entropy_bonus,0.092,0.0534,low-t +gradient_surgery,0.08,0.0881,low-t +advantage_clip,0.078,0.0912,low-t +normalized_adv,0.303,0.6614,low-t +bc_wins,0.081,0.0876,low-t +low_t,0.103,0.0652,low-t +frozen_backbone,0.313,0.9871,low-t +head_only,0.314,0.9874,low-t +attention_only,0.307,0.9822,low-t +ffn_only,0.305,0.9569,low-t +layer_ablation_top1,0.302,0.9718,low-t +layer_ablation_top2,0.326,0.9374,low-t +layer_ablation_top3,0.325,0.8966,low-t +reward_filtering,0.075,0.0609,low-t +running_stats,0.083,0.0936,low-t +action_diversity,0.08,0.0895,low-t +reward_model,0.085,0.0771,low-t diff --git a/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex new file mode 100644 index 0000000000000000000000000000000000000000..77c802804bfc029bae98d328bb5862fc77b0f208 --- /dev/null +++ b/experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/tables/t_distribution.tex @@ -0,0 +1,36 @@ +\begin{table}[htbp] +\centering +\caption{Timestep distribution analysis.} +\label{tab:t_dist} +\begin{tabular}{lrrr} +\toprule +\textbf{Method} & \textbf{HighLow\_Ratio} & \textbf{LowHigh\_Cos\_Sim} & \textbf{Dominant\_Regime} \\ +\midrule +baseline\_rl & 0.0800 & 0.0888 & low-t \\ +kl\_penalty & 0.0920 & 0.0543 & low-t \\ +ewc & 0.0800 & 0.0666 & low-t \\ +llrd & 0.0760 & 0.0818 & low-t \\ +lora & 0.3280 & 0.9819 & low-t \\ +mixed\_replay & 0.0810 & -0.0058 & low-t \\ +trust\_region\_kl & 0.0990 & 0.0600 & low-t \\ +t\_curriculum & 0.0980 & 0.0614 & low-t \\ +entropy\_bonus & 0.0920 & 0.0534 & low-t \\ +gradient\_surgery & 0.0800 & 0.0881 & low-t \\ +advantage\_clip & 0.0780 & 0.0912 & low-t \\ +normalized\_adv & 0.3030 & 0.6614 & low-t \\ +bc\_wins & 0.0810 & 0.0876 & low-t \\ +low\_t & 0.1030 & 0.0652 & low-t \\ +frozen\_backbone & 0.3130 & 0.9871 & low-t \\ +head\_only & 0.3140 & 0.9874 & low-t \\ +attention\_only & 0.3070 & 0.9822 & low-t \\ +ffn\_only & 0.3050 & 0.9569 & low-t \\ +layer\_ablation\_top1 & 0.3020 & 0.9718 & low-t \\ +layer\_ablation\_top2 & 0.3260 & 0.9374 & low-t \\ +layer\_ablation\_top3 & 0.3250 & 0.8966 & low-t \\ +reward\_filtering & 0.0750 & 0.0609 & low-t \\ +running\_stats & 0.0830 & 0.0936 & low-t \\ +action\_diversity & 0.0800 & 0.0895 & low-t \\ +reward\_model & 0.0850 & 0.0771 & low-t \\ +\bottomrule +\end{tabular} +\end{table} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..caa86d14bc3ca77c7861e17ffd67e016452e5942 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "craftax-remdm-planner" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "chex>=0.1.91", + "craftax>=1.5.0", + "distrax>=0.1.7", + "flax>=0.12.6", + "huggingface-hub>=1.9.1", + "jax>=0.9.2", + "matplotlib>=3.10.8", + "numpy>=2.4.4", + "optax>=0.2.8", + "orbax>=0.1.9", + "orjson>=3.11.8", + "polars>=1.39.3", + "pyyaml>=6.0.3", + "wandb>=0.25.1", +] + +[project.optional-dependencies] +cuda = ["jax[cuda13]>=0.9.2"] diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/diffusion/forward.py b/src/diffusion/forward.py new file mode 100644 index 0000000000000000000000000000000000000000..ba91b733c9162d2261dc6adef6fd790159c288be --- /dev/null +++ b/src/diffusion/forward.py @@ -0,0 +1,28 @@ +"""Forward process: q(z_t | x_0) by independent per-token masking.""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp + + +def forward_process( + rng: jax.Array, + x_0: jnp.ndarray, + alpha_t: jnp.ndarray, + mask_id: int, +) -> jnp.ndarray: + """Sample z_t ~ q(z_t | x_0). Each token stays with prob alpha_t, else MASK. + + Args: + rng: PRNG key. + x_0: [B, H] int32, clean actions. + alpha_t: [B] or scalar, retention probability. + mask_id: MASK token index (= num_actions). + + Returns: + z_t: [B, H] int32. + """ + keep = jax.random.uniform(rng, shape=x_0.shape) + alpha_t = jnp.reshape(alpha_t, (-1, 1)) + return jnp.where(keep < alpha_t, x_0, jnp.array(mask_id, dtype=x_0.dtype)) diff --git a/src/diffusion/loss.py b/src/diffusion/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfad20c465eee300694fbffed5a58e9df43a63a --- /dev/null +++ b/src/diffusion/loss.py @@ -0,0 +1,122 @@ +"""MDLM ELBO loss for masked discrete diffusion training.""" + +from __future__ import annotations +from typing import Any, Callable, Optional + +import jax +import jax.numpy as jnp + +from .forward import forward_process +from .schedules import ScheduleFn + +ModelApplyFn = Callable[ + [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]], jnp.ndarray +] + +_MAX_WEIGHT: float = 1000.0 +_EPS: float = 1e-5 + + +def compute_loss( + model_apply: ModelApplyFn, + params: Any, + rng: jax.Array, + x_0: jnp.ndarray, + obs: jnp.ndarray, + valid: jnp.ndarray, + num_actions: int, + schedule_fn: ScheduleFn, + schedule_deriv_fn: ScheduleFn, + sigma_t: float = 0.0, + label_smoothing: float = 0.0, + advantages: Optional[jnp.ndarray] = None, + t_min: float | jax.Array = _EPS, + t_max: float | jax.Array = 1.0, +) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: + """Continuous-time ELBO loss on masked positions only. + + Args: + model_apply: fn(params, obs, z_t, t, rng) -> logits [B, H, V]. + params: Model parameters. + rng: PRNG key. + x_0: [B, H] int32, ground-truth actions. + obs: [B, obs_dim] float32, observations. + valid: [B] bool/float, whether each sample is valid. + num_actions: Size of real action vocabulary. + schedule_fn: alpha(t). + schedule_deriv_fn: d(alpha)/dt (analytic). + sigma_t: Remasking correction for ELBO weight (0 = standard MDLM). + label_smoothing: Smoothing epsilon (0 = exact ELBO targets). + advantages: Optional [B] per-sample weights. + t_min: Lower bound for uniform t sampling (default: _EPS). + t_max: Upper bound for uniform t sampling (default: 1.0). + + Returns: + (loss, info_dict). + """ + B = x_0.shape[0] + mask_id = num_actions + rng, t_rng, mask_rng, drop_rng = jax.random.split(rng, 4) + + # Sample t ~ U(t_min, t_max). Defaults give full ELBO; narrow range for ablations. + t = jax.random.uniform(t_rng, (B,), minval=t_min, maxval=t_max) + alpha_t = schedule_fn(t) + + # Analytic loss weight: w(t) = (1 - sigma) * (-d(alpha)/dt) / (1 - alpha(t)) + neg_alpha_dot = -schedule_deriv_fn(t) # positive quantity + weight = (1.0 - sigma_t) * neg_alpha_dot / jnp.maximum(1.0 - alpha_t, _EPS) + weight = jnp.minimum(weight, _MAX_WEIGHT) + + # Forward noise + z_t = forward_process(mask_rng, x_0, alpha_t, mask_id) + + # Model prediction + logits = model_apply(params, obs, z_t, t, drop_rng) # [B, H, V] + + # Cross-entropy on valid masked positions + is_masked = (z_t == mask_id).astype(jnp.float32) # [B, H] + valid_masked = is_masked * valid[:, None].astype(jnp.float32) # [B, H] + + targets = jax.nn.one_hot(x_0, num_actions) + if label_smoothing > 0: + targets = (1.0 - label_smoothing) * targets + label_smoothing / num_actions + + log_probs = jax.nn.log_softmax(logits, axis=-1) + ce = -jnp.sum(targets * log_probs, axis=-1) # [B, H] + + n_masked = jnp.maximum(valid_masked.sum(axis=-1), 1.0) # [B] + per_sample = weight * (ce * valid_masked).sum(axis=-1) / n_masked + + if advantages is not None: + per_sample = per_sample * jax.lax.stop_gradient(advantages) + + loss = jnp.mean(per_sample) + + # Diagnostics + preds = jnp.argmax(logits, axis=-1) + correct = (preds == x_0).astype(jnp.float32) + acc = jnp.sum(correct * valid_masked) / jnp.maximum(valid_masked.sum(), 1.0) + + t_bins = jnp.array([0.33, 0.66]) + t_lo = (t < t_bins[0])[:, None] + t_mi = ((t >= t_bins[0]) & (t <= t_bins[1]))[:, None] + t_hi = (t > t_bins[1])[:, None] + + def _binned_acc(mask): + m = valid_masked * mask + return jnp.sum(correct * m) / jnp.maximum(m.sum(), 1.0) + + info = { + "loss": loss, + "unweighted_loss": jnp.mean((ce * valid_masked).sum(axis=-1) / n_masked), + "mean_t": jnp.mean(t), + "frac_masked": jnp.mean(is_masked), + "accuracy": acc, + "acc_t_low": _binned_acc(t_lo), + "acc_t_mid": _binned_acc(t_mi), + "acc_t_high": _binned_acc(t_hi), + } + if advantages is not None: + info["adv_mean"] = jnp.mean(advantages) + info["adv_std"] = jnp.std(advantages) + return loss, info diff --git a/src/diffusion/sampling.py b/src/diffusion/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d1fb3532041f93bd4595bc16d19f764b568651df --- /dev/null +++ b/src/diffusion/sampling.py @@ -0,0 +1,312 @@ +"""Reverse diffusion sampling with ReMDM remasking (Wang et al.). + +Strategies (Section 4.1): + rescale: sigma = eta * sigma_max + cap: sigma = min(eta, sigma_max) + conf: per-token confidence-based remasking + +Loop mode (Section 4.2, Algorithm 3): + Phase 1: standard MDLM decode, t in [1, t_on] + Phase 2: constant alpha(t_on), remasking active + Phase 3: standard MDLM decode, t in [t_off, 0] +""" + +from __future__ import annotations +from typing import Any, Callable, Optional + +import jax +import jax.numpy as jnp + +from .schedules import ScheduleFn + +ModelApplyFn = Callable[ + [Any, jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[Any]], jnp.ndarray +] + + +# --------------------------------------------------------------------------- +# Remasking sigma computation +# --------------------------------------------------------------------------- + +def _sigma_max(alpha_t: jnp.ndarray, alpha_s: jnp.ndarray) -> jnp.ndarray: + """sigma_max = min(1, (1 - alpha_s) / alpha_t). [Eq. 7]""" + return jnp.minimum(1.0, (1.0 - alpha_s) / jnp.maximum(alpha_t, 1e-8)) + + +def sigma_rescale(alpha_t, alpha_s, eta): + return eta * _sigma_max(alpha_t, alpha_s) + + +def sigma_cap(alpha_t, alpha_s, eta): + return jnp.minimum(eta, _sigma_max(alpha_t, alpha_s)) + + +def sigma_conf(alpha_t, alpha_s, eta, psi, is_unmasked): + """Per-token confidence remasking. Safe against all-masked rows.""" + base = eta * _sigma_max(alpha_t, alpha_s) + any_unmasked = jnp.any(is_unmasked, axis=-1, keepdims=True) + neg_psi = jnp.where(is_unmasked, -psi, -jnp.inf) + safe_neg_psi = jnp.where(any_unmasked, neg_psi, 0.0) + eta_conf = jax.nn.softmax(safe_neg_psi, axis=-1) + return jnp.where(is_unmasked, eta_conf * base, 0.0) + + +_SIGMA_FNS = { + "rescale": lambda a_t, a_s, eta, *_: sigma_rescale(a_t, a_s, eta), + "cap": lambda a_t, a_s, eta, *_: sigma_cap(a_t, a_s, eta), + "conf": lambda a_t, a_s, eta, psi, unm: sigma_conf(a_t, a_s, eta, psi, unm), +} + + +# --------------------------------------------------------------------------- +# Decoding helpers +# --------------------------------------------------------------------------- + +def _nucleus_sample(rng, logits, top_p): + """Top-p sampling from [B, H, V] logits -> [B, H] int32.""" + probs = jax.nn.softmax(logits, axis=-1) + idx = jnp.argsort(-probs, axis=-1) + sorted_p = jnp.take_along_axis(probs, idx, axis=-1) + cum = jnp.cumsum(sorted_p, axis=-1) + + cutoff = cum - sorted_p + sorted_p = jnp.where(cutoff >= top_p, 0.0, sorted_p) + sorted_p = sorted_p / jnp.maximum(sorted_p.sum(axis=-1, keepdims=True), 1e-12) + + B, H, V = logits.shape + flat = sorted_p.reshape(B * H, V) + tokens = jax.random.categorical(rng, jnp.log(flat + 1e-12)).reshape(B, H) + return jnp.take_along_axis(idx, tokens[..., None], axis=-1).squeeze(-1) + + +def _decode(rng, logits, temperature, top_p): + """Sample tokens from logits. Argmax only when temperature <= 0.""" + if top_p is not None: + return _nucleus_sample(rng, logits / jnp.maximum(temperature, 1e-8), top_p) + if temperature > 1e-8: + B, H, V = logits.shape + scaled = logits / temperature + return jax.random.categorical(rng, scaled.reshape(-1, V)).reshape(B, H) + return jnp.argmax(logits, axis=-1) + + +# --------------------------------------------------------------------------- +# Reverse sampling +# --------------------------------------------------------------------------- + +def sample_plan( + model_apply: ModelApplyFn, + params: Any, + rng: jax.Array, + obs: jnp.ndarray, + num_actions: int, + plan_horizon: int, + num_steps: int, + schedule_fn: ScheduleFn, + remask_strategy: str = "cap", + eta: float = 0.5, + use_loop: bool = False, + t_on: float = 0.55, + t_off: float = 0.05, + temperature: float = 1.0, + top_p: Optional[float] = None, +) -> jnp.ndarray: + """Generate an action plan via reverse diffusion with ReMDM remasking. + + Returns: + actions: [B, H] int32. + """ + B = obs.shape[0] + mask_id = num_actions + mask_val = jnp.array(mask_id, dtype=jnp.int32) + + if remask_strategy not in _SIGMA_FNS: + raise ValueError(f"Unknown strategy {remask_strategy!r}. Options: {list(_SIGMA_FNS)}") + get_sigma = _SIGMA_FNS[remask_strategy] + + # Phase allocation for loop mode + if use_loop: + f1, f3 = 1.0 - t_on, t_off + denom = f1 + f3 + (t_on - t_off) + n1 = max(int(round(num_steps * f1 / denom)), 1) + n3 = max(int(round(num_steps * f3 / denom)), 1) + n2 = max(num_steps - n1 - n3, 1) + else: + n1, n2, n3 = num_steps, 0, 0 + + alpha_loop = schedule_fn(jnp.array(t_on)) + + z_init = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32) + psi_init = jnp.full((B, plan_horizon), jnp.inf) + + # ------------------------------------------------------------------ + # Core denoising step (ReMDM Eq. 6) + # ------------------------------------------------------------------ + def _step(carry, _unused, t_val, alpha_t, alpha_s, sigma_on): + z, rng, psi = carry + rng, s_rng, u_rng, r_rng = jax.random.split(rng, 4) + + t_inp = jnp.full((B,), t_val) + logits = model_apply(params, obs, z, t_inp, None) + x_hat = _decode(s_rng, logits, temperature, top_p) + + is_masked = z == mask_id + is_unmasked = ~is_masked + + sigma = get_sigma(alpha_t, alpha_s, eta, psi, is_unmasked) + sigma = jnp.broadcast_to(sigma, z.shape) + sigma = jnp.where(sigma_on, sigma, 0.0) + + # Masked -> unmask probability + denom = jnp.maximum(1.0 - alpha_t, 1e-8) + p_unmask = jnp.clip((alpha_s - (1.0 - sigma) * alpha_t) / denom, 0.0, 1.0) + + do_unmask = is_masked & (jax.random.uniform(u_rng, z.shape) < p_unmask) + do_remask = is_unmasked & (jax.random.uniform(r_rng, z.shape) < sigma) + + z_new = jnp.where(do_unmask, x_hat, z) + z_new = jnp.where(do_remask, mask_val, z_new) + + # Update confidence history + probs = jax.nn.softmax(logits, axis=-1) + decode_prob = jnp.take_along_axis(probs, x_hat[..., None], axis=-1).squeeze(-1) + psi_new = jnp.where(do_unmask, decode_prob, psi) + psi_new = jnp.where(do_remask, jnp.inf, psi_new) + + return (z_new, rng, psi_new), None + + # ------------------------------------------------------------------ + # Phase functions + # ------------------------------------------------------------------ + def _phase1_step(carry, idx): + t = 1.0 - idx * (1.0 - t_on) / n1 + s = jnp.maximum(1.0 - (idx + 1) * (1.0 - t_on) / n1, t_on) + return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False) + + def _phase2_step(carry, idx): + return _step(carry, idx, t_on, alpha_loop, alpha_loop, True) + + def _phase3_step(carry, idx): + t = t_off - idx * t_off / n3 + s = jnp.maximum(t_off - (idx + 1) * t_off / n3, 0.0) + return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), False) + + def _simple_step(carry, idx): + t = (num_steps - idx) / num_steps + s = jnp.maximum((num_steps - idx - 1) / num_steps, 0.0) + return _step(carry, idx, t, schedule_fn(t), schedule_fn(s), True) + + # ------------------------------------------------------------------ + # Run + # ------------------------------------------------------------------ + carry = (z_init, rng, psi_init) + + if use_loop: + carry, _ = jax.lax.scan(_phase1_step, carry, jnp.arange(n1)) + carry, _ = jax.lax.scan(_phase2_step, carry, jnp.arange(n2)) + if n3 > 0: + carry, _ = jax.lax.scan(_phase3_step, carry, jnp.arange(n3)) + else: + carry, _ = jax.lax.scan(_simple_step, carry, jnp.arange(num_steps)) + + z_final = carry[0] + + # Final greedy cleanup for any remaining masks + final_logits = model_apply(params, obs, z_final, jnp.zeros((B,)), None) + fallback = jnp.argmax(final_logits, axis=-1) + return jnp.where(z_final == mask_id, fallback, z_final) + + +# --------------------------------------------------------------------------- +# Inpainting sampler (MPC / historical prefix) +# --------------------------------------------------------------------------- + +def sample_plan_inpainting( + apply_fn: ModelApplyFn, + params: Any, + rng: jax.Array, + obs: jnp.ndarray, + history: jnp.ndarray, + hist_len: jnp.ndarray, + num_actions: int, + plan_horizon: int, + diffusion_steps: int, + temperature: float, + top_p: Optional[float], +) -> jnp.ndarray: + """Diffusion sampling with a locked historical prefix (inpainting). + + Positions ``0 .. hist_len[b] - 1`` are fixed to the values in ``history`` + for each batch element ``b``; the remainder are diffused freely. + + Args: + apply_fn: Model apply closure (eval mode, no dropout). + params: Model parameter pytree. + rng: PRNG key. + obs: ``[B, obs_dim]`` conditioning observations. + history: ``[B, plan_horizon]`` int32 prefix of executed actions. + hist_len: ``[B]`` int32 number of valid prefix tokens per element. + num_actions: Size of the real action vocabulary (mask token = ``num_actions``). + plan_horizon: Total sequence length. + diffusion_steps: Number of denoising iterations. + temperature: Softmax temperature for token sampling. + top_p: Nucleus-sampling threshold; ``None`` disables nucleus filtering. + + Returns: + ``[B, plan_horizon]`` int32 completed action plan. + """ + B = obs.shape[0] + mask_id = num_actions + + def _step(carry, step): + seq, rng = carry + rng, model_rng, sample_rng, remask_rng = jax.random.split(rng, 4) + + ratio = step / diffusion_steps + t_tensor = jnp.full((B,), 1.0 - ratio) + logits = apply_fn(params, obs, seq, t_tensor, model_rng) / jnp.maximum(temperature, 1e-8) + + # Optional nucleus filtering + if top_p is not None: + probs = jax.nn.softmax(logits, axis=-1) + sorted_idx = jnp.argsort(-probs, axis=-1) + sorted_p = jnp.take_along_axis(probs, sorted_idx, axis=-1) + cutoff = jnp.cumsum(sorted_p, axis=-1) - sorted_p + inv_idx = jnp.argsort(sorted_idx, axis=-1) + nucleus_mask = jnp.take_along_axis(cutoff >= top_p, inv_idx, axis=-1) + logits = jnp.where(nucleus_mask, -jnp.inf, logits) + + preds = jax.random.categorical(sample_rng, logits, axis=-1) + conf = jnp.take_along_axis( + jax.nn.softmax(logits, axis=-1), preds[..., None], axis=-1, + ).squeeze(-1) + + # Keep top-(ratio * H) most confident predictions unmasked + num_unmask = jnp.maximum(1, (plan_horizon * ratio).astype(jnp.int32)) + sorted_conf = jnp.sort(conf, axis=-1)[..., ::-1] + thresh = sorted_conf[jnp.arange(B), num_unmask - 1] + seq_new = jnp.where(conf < thresh[:, None], mask_id, preds) + + # Light ReMDM-style remasking + remask_prob = 0.15 * (1.0 - ratio) + do_remask = ( + (jax.random.uniform(remask_rng, seq_new.shape) < remask_prob) + & (seq_new != mask_id) + ) + seq_new = jnp.where(do_remask, mask_id, seq_new) + + # Lock historical prefix + pos = jnp.broadcast_to(jnp.arange(plan_horizon)[None, :], (B, plan_horizon)) + seq_new = jnp.where(pos < hist_len[:, None], history, seq_new) + + return (seq_new, rng), None + + # Initialise: history locked, remainder fully masked + init_seq = jnp.full((B, plan_horizon), mask_id, dtype=jnp.int32) + pos = jnp.broadcast_to(jnp.arange(plan_horizon)[None, :], (B, plan_horizon)) + init_seq = jnp.where(pos < hist_len[:, None], history, init_seq) + + (final_seq, _), _ = jax.lax.scan( + _step, (init_seq, rng), jnp.arange(1, diffusion_steps + 1), + ) + return final_seq diff --git a/src/diffusion/schedules.py b/src/diffusion/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e815322cf1d55428d0981b959963c106706ef2 --- /dev/null +++ b/src/diffusion/schedules.py @@ -0,0 +1,35 @@ +"""Noise schedules for masked discrete diffusion. + +alpha(t) is the retention probability: alpha(0)=1 (clean), alpha(1)=0 (fully masked). +""" + +from __future__ import annotations +from typing import Callable + +import jax.numpy as jnp + +ScheduleFn = Callable[[jnp.ndarray], jnp.ndarray] + + +def linear_schedule(t: jnp.ndarray) -> jnp.ndarray: + """alpha(t) = 1 - t. Default in MDLM / ReMDM.""" + return 1.0 - t + + +def linear_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray: + return jnp.full_like(t, -1.0) + + +def cosine_schedule(t: jnp.ndarray) -> jnp.ndarray: + """alpha(t) = cos(pi * t / 2).""" + return jnp.cos(t * jnp.pi / 2.0) + + +def cosine_schedule_deriv(t: jnp.ndarray) -> jnp.ndarray: + return -(jnp.pi / 2.0) * jnp.sin(t * jnp.pi / 2.0) + + +SCHEDULE_MAP: dict[str, tuple[ScheduleFn, ScheduleFn]] = { + "linear": (linear_schedule, linear_schedule_deriv), + "cosine": (cosine_schedule, cosine_schedule_deriv), +} diff --git a/src/envs/__init__.py b/src/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7835ea3776d9f016570b56fc9723a55b0f9920 --- /dev/null +++ b/src/envs/__init__.py @@ -0,0 +1,15 @@ +"""Environment wrappers for the ReMDM planner.""" + +from src.envs.wrappers import ( + DiscreteTokenizationWrapper, + OfflineTrajectoryWrapper, + PlannerWrapper, + SequenceHistoryWrapper, +) + +__all__ = [ + "DiscreteTokenizationWrapper", + "OfflineTrajectoryWrapper", + "PlannerWrapper", + "SequenceHistoryWrapper", +] diff --git a/src/envs/wrappers.py b/src/envs/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..f61102c268d864a9af60d3a8024cce92dcf83b78 --- /dev/null +++ b/src/envs/wrappers.py @@ -0,0 +1,403 @@ +"""Project-specific Gymnax wrappers for the ReMDM planner.""" + +from __future__ import annotations + +from functools import partial +from typing import Any, Callable, Tuple, Union + +import chex +import jax +import jax.numpy as jnp +from flax import struct + +from Craftax_Baselines.wrappers import GymnaxWrapper + + +# ============================================================================= +# SequenceHistoryWrapper +# ============================================================================= + + +@struct.dataclass +class SequenceHistoryState: + env_state: Any + obs_history: chex.Array # [history_len, *obs_shape] + act_history: chex.Array # [history_len] int32 + + +class SequenceHistoryWrapper(GymnaxWrapper): + """Augments env state with a sliding window of past observations and actions. + + After each step the histories satisfy:: + + obs_history[-1] = current observation + act_history[i] = action taken from obs_history[i] to reach obs_history[i+1] + + The wrapper returns the current observation unchanged; the sequence context is + accessed via ``state.obs_history`` and ``state.act_history`` in the training loop. + + Place this as the **innermost** wrapper (before AutoReset / LogWrapper) so that + episode boundaries trigger a proper history reset via the auto-reset mechanism. + + Args: + env: Single Gymnax environment. + history_len: Number of past timesteps to keep (including current). + obs_shape: Shape of a single observation, e.g. ``(obs_dim,)``. + """ + + def __init__(self, env: Any, history_len: int, obs_shape: Tuple[int, ...]) -> None: + super().__init__(env) + self.history_len = history_len + self.obs_shape = obs_shape + + @partial(jax.jit, static_argnums=(0, 2)) + def reset( + self, key: chex.PRNGKey, params: Any = None + ) -> Tuple[chex.Array, SequenceHistoryState]: + obs, env_state = self._env.reset(key, params) + obs_history = jnp.tile( + obs[None], [self.history_len] + [1] * len(self.obs_shape) + ) + act_history = jnp.zeros(self.history_len, dtype=jnp.int32) + state = SequenceHistoryState( + env_state=env_state, + obs_history=obs_history, + act_history=act_history, + ) + return obs, state + + @partial(jax.jit, static_argnums=(0, 4)) + def step( + self, + key: chex.PRNGKey, + state: SequenceHistoryState, + action: Union[int, float], + params: Any = None, + ) -> Tuple[chex.Array, SequenceHistoryState, chex.Array, chex.Array, Any]: + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + act_history = jnp.roll(state.act_history, -1, axis=0).at[-1].set(action) + obs_history = jnp.roll(state.obs_history, -1, axis=0).at[-1].set(obs) + new_state = SequenceHistoryState( + env_state=env_state, + obs_history=obs_history, + act_history=act_history, + ) + return obs, new_state, reward, done, info + + +# ============================================================================= +# DiscreteTokenizationWrapper +# ============================================================================= + + +class DiscreteTokenizationWrapper(GymnaxWrapper): + """Quantizes continuous observations into discrete token indices. + + Each observation element is mapped to one of ``n_bins`` integer tokens using + uniform binning between ``obs_min`` and ``obs_max``. + + The returned observation dtype is int32 with values in ``[0, n_bins - 1]``. + + Args: + env: Gymnax environment (or wrapper). + n_bins: Number of discrete bins per observation element. + obs_min: Per-element lower bound, shape matching the observation. + obs_max: Per-element upper bound, shape matching the observation. + """ + + def __init__( + self, + env: Any, + n_bins: int, + obs_min: jnp.ndarray, + obs_max: jnp.ndarray, + ) -> None: + super().__init__(env) + self.n_bins = n_bins + self.obs_min = obs_min + self.obs_max = obs_max + + def _tokenize(self, obs: chex.Array) -> chex.Array: + obs_clipped = jnp.clip(obs, self.obs_min, self.obs_max) + normalized = (obs_clipped - self.obs_min) / ( + self.obs_max - self.obs_min + 1e-8 + ) + tokens = jnp.floor(normalized * self.n_bins).astype(jnp.int32) + return jnp.clip(tokens, 0, self.n_bins - 1) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset( + self, key: chex.PRNGKey, params: Any = None + ) -> Tuple[chex.Array, Any]: + obs, state = self._env.reset(key, params) + return self._tokenize(obs), state + + @partial(jax.jit, static_argnums=(0, 4)) + def step( + self, + key: chex.PRNGKey, + state: Any, + action: Union[int, float], + params: Any = None, + ) -> Tuple[chex.Array, Any, chex.Array, chex.Array, Any]: + obs, state, reward, done, info = self._env.step(key, state, action, params) + return self._tokenize(obs), state, reward, done, info + + +# ============================================================================= +# PlannerWrapper +# ============================================================================= + + +@struct.dataclass +class PlannerState: + env_state: Any + current_plan: chex.Array # [num_envs, plan_horizon] int32 + plan_step: int + + +class PlannerWrapper(GymnaxWrapper): + """Manages the plan / replan cycle for a discrete diffusion planner. + + Expected wrapper stack (inner -> outer):: + + env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper + -> PlannerWrapper + + The ``planner_apply_fn`` must have the signature:: + + fn(rng, model_params, obs) -> jnp.ndarray # [num_envs, plan_horizon] int32 + + Args: + env: Batched Gymnax environment (already handles num_envs). + num_envs: Number of parallel environments. + plan_horizon: Total number of actions the diffusion model outputs. + replan_every: Steps to execute before requesting a new plan (<= plan_horizon). + planner_apply_fn: Callable that invokes the diffusion model. + """ + + def __init__( + self, + env: Any, + num_envs: int, + plan_horizon: int, + replan_every: int, + planner_apply_fn: Callable[..., jnp.ndarray], + ) -> None: + super().__init__(env) + if replan_every > plan_horizon: + raise ValueError( + f"replan_every ({replan_every}) must be <= plan_horizon ({plan_horizon})" + ) + self.num_envs = num_envs + self.plan_horizon = plan_horizon + self.replan_every = replan_every + self.planner_apply_fn = planner_apply_fn + + @partial(jax.jit, static_argnums=(0, 2)) + def reset( + self, key: chex.PRNGKey, params: Any = None + ) -> Tuple[chex.Array, PlannerState]: + obs, env_state = self._env.reset(key, params) + current_plan = jnp.zeros( + (self.num_envs, self.plan_horizon), dtype=jnp.int32 + ) + state = PlannerState( + env_state=env_state, + current_plan=current_plan, + plan_step=0, + ) + return obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: PlannerState, + last_obs: chex.Array, + model_params: Any, + env_params: Any = None, + ) -> Tuple[chex.Array, PlannerState, chex.Array, chex.Array, chex.Array, Any]: + """Step the environment using the diffusion plan. + + Args: + key: PRNG key. + state: Current PlannerState. + last_obs: Most recent batched observation [num_envs, *obs_shape]. + model_params: Parameters passed to planner_apply_fn. + env_params: Optional Gymnax environment params. + + Returns: + (obs, state, action, reward, done, info) + """ + key, plan_key, step_key = jax.random.split(key, 3) + + current_plan = jax.lax.cond( + state.plan_step == 0, + lambda operand: self.planner_apply_fn(*operand), + lambda operand: state.current_plan, + (plan_key, model_params, last_obs), + ) + + action = current_plan[:, state.plan_step] + + obs, env_state, reward, done, info = self._env.step( + step_key, state.env_state, action, env_params + ) + + new_plan_step = (state.plan_step + 1) % self.replan_every + new_state = PlannerState( + env_state=env_state, + current_plan=current_plan, + plan_step=new_plan_step, + ) + return obs, new_state, action, reward, done, info + + +# ============================================================================= +# OfflineTrajectoryWrapper +# ============================================================================= + + +@struct.dataclass +class TrajectoryBufferState: + env_state: Any + last_obs: Any # [*obs_shape] + buf_obs: Any # [max_size, *obs_shape] + buf_act: Any # [max_size] int32 + buf_reward: Any # [max_size] float32 + buf_done: Any # [max_size] bool + buf_next_obs: Any # [max_size, *obs_shape] + write_idx: Any # int32, wraps at max_size + num_valid: Any # int32, capped at max_size + + +class OfflineTrajectoryWrapper(GymnaxWrapper): + """Accumulates transitions into a fixed-size circular replay buffer. + + The buffer overwrites the oldest entries once full. Use ``sample_sequences`` + to draw contiguous subsequences for training a sequence model like ReMDM. + + Designed for a single environment; compose with ``BatchEnvWrapper`` *outside* + this wrapper to collect from multiple envs simultaneously. + + Args: + env: Single Gymnax environment (or wrapper). + max_size: Maximum number of transitions to store. + obs_shape: Shape of a single observation, e.g. ``(obs_dim,)``. + """ + + def __init__( + self, env: Any, max_size: int, obs_shape: Tuple[int, ...] + ) -> None: + super().__init__(env) + self.max_size = max_size + self.obs_shape = obs_shape + + def _empty_buffer( + self, env_state: Any, first_obs: chex.Array + ) -> TrajectoryBufferState: + return TrajectoryBufferState( + env_state=env_state, + last_obs=first_obs, + buf_obs=jnp.zeros( + (self.max_size, *self.obs_shape), dtype=jnp.float32 + ), + buf_act=jnp.zeros(self.max_size, dtype=jnp.int32), + buf_reward=jnp.zeros(self.max_size, dtype=jnp.float32), + buf_done=jnp.zeros(self.max_size, dtype=jnp.bool_), + buf_next_obs=jnp.zeros( + (self.max_size, *self.obs_shape), dtype=jnp.float32 + ), + write_idx=jnp.int32(0), + num_valid=jnp.int32(0), + ) + + @partial(jax.jit, static_argnums=(0, 2)) + def reset( + self, key: chex.PRNGKey, params: Any = None + ) -> Tuple[chex.Array, TrajectoryBufferState]: + obs, env_state = self._env.reset(key, params) + state = self._empty_buffer(env_state, obs) + return obs, state + + @partial(jax.jit, static_argnums=(0, 4)) + def step( + self, + key: chex.PRNGKey, + state: TrajectoryBufferState, + action: Union[int, float], + params: Any = None, + ) -> Tuple[chex.Array, TrajectoryBufferState, chex.Array, chex.Array, Any]: + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + + idx = state.write_idx % self.max_size + buf_obs = state.buf_obs.at[idx].set(state.last_obs) + buf_act = state.buf_act.at[idx].set(action) + buf_reward = state.buf_reward.at[idx].set(reward) + buf_done = state.buf_done.at[idx].set(done) + buf_next_obs = state.buf_next_obs.at[idx].set(obs) + + # Wrap write_idx at max_size to prevent unbounded growth / int32 overflow + new_write_idx = (state.write_idx + 1) % self.max_size + is_full = state.num_valid >= self.max_size + new_num_valid = jnp.where(is_full, self.max_size, state.num_valid + 1) + + new_state = TrajectoryBufferState( + env_state=env_state, + last_obs=obs, + buf_obs=buf_obs, + buf_act=buf_act, + buf_reward=buf_reward, + buf_done=buf_done, + buf_next_obs=buf_next_obs, + write_idx=new_write_idx, + num_valid=new_num_valid, + ) + return obs, new_state, reward, done, info + + @partial(jax.jit, static_argnums=(0, 3, 4)) + def sample_sequences( + self, + rng: chex.PRNGKey, + state: TrajectoryBufferState, + n_samples: int, + seq_len: int, + ) -> Tuple[ + chex.Array, chex.Array, chex.Array, chex.Array, chex.Array + ]: + """Sample ``n_samples`` contiguous subsequences of length ``seq_len``. + + Precondition: ``state.num_valid >= seq_len``. + + Returns: + obs [n_samples, seq_len, *obs_shape] + act [n_samples, seq_len] + reward [n_samples, seq_len] + done [n_samples, seq_len] + next_obs [n_samples, seq_len, *obs_shape] + """ + max_start = jnp.maximum(state.num_valid - seq_len, 1) + start_indices = jax.random.randint( + rng, shape=(n_samples,), minval=0, maxval=max_start + ) + + def gather_seq( + start_idx: jnp.ndarray, + ) -> Tuple[ + chex.Array, chex.Array, chex.Array, chex.Array, chex.Array + ]: + indices = (start_idx + jnp.arange(seq_len)) % self.max_size + return ( + state.buf_obs[indices], + state.buf_act[indices], + state.buf_reward[indices], + state.buf_done[indices], + state.buf_next_obs[indices], + ) + + return jax.vmap(gather_seq)(start_indices) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/denoiser.py b/src/models/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..972e122f5cd95d5317d3fb1c8062ed167a9a7221 --- /dev/null +++ b/src/models/denoiser.py @@ -0,0 +1,125 @@ +"""Denoising Transformer for masked discrete diffusion planning. + +Architecture: obs MLP encoder + sinusoidal time embedding + bidirectional +transformer. Two prefix tokens (obs, time) precede the action sequence. +""" + +from __future__ import annotations + +import numpy as np +import jax.numpy as jnp +import flax.linen as nn +from flax.linen.initializers import constant, orthogonal + +_INIT = orthogonal(np.sqrt(2)) +_INIT_SMALL = orthogonal(0.01) +_BIAS = constant(0.0) + + +class SinusoidalPosEmbed(nn.Module): + """Sinusoidal embedding for continuous timesteps or integer positions.""" + + dim: int + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + half = self.dim // 2 + freqs = jnp.exp(-jnp.log(10_000.0) * jnp.arange(half) / half) + angles = x[..., None] * freqs + emb = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1) + if self.dim % 2 == 1: + emb = jnp.concatenate([emb, jnp.zeros_like(emb[..., :1])], axis=-1) + return emb + + +class TransformerBlock(nn.Module): + """Pre-norm transformer: LN -> MHA -> res -> LN -> FFN -> res.""" + + d_model: int + n_heads: int + d_ff: int + dropout_rate: float = 0.1 + deterministic: bool = True + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + h = nn.LayerNorm()(x) + h = nn.MultiHeadDotProductAttention( + num_heads=self.n_heads, kernel_init=_INIT, deterministic=self.deterministic, + )(h, h) + h = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(h) + x = x + h + + h = nn.LayerNorm()(x) + h = nn.Dense(self.d_ff, kernel_init=_INIT, bias_init=_BIAS)(h) + h = nn.gelu(h) + h = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(h) + h = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(h) + return x + h + + +class DenoisingTransformer(nn.Module): + """Denoising transformer for masked discrete diffusion planning. + + Input: (obs [B, D], noisy_actions [B, H], timestep [B]) + Output: logits [B, H, num_actions] (no MASK logit). + """ + + num_actions: int + plan_horizon: int + d_model: int = 256 + n_heads: int = 4 + n_layers: int = 4 + d_ff: int = 512 + obs_encoder_layers: int = 2 + obs_encoder_width: int = 512 + dropout_rate: float = 0.1 + + @nn.compact + def __call__( + self, + obs: jnp.ndarray, + noisy_actions: jnp.ndarray, + timestep: jnp.ndarray, + deterministic: bool = True, + ) -> jnp.ndarray: + B = obs.shape[0] + vocab = self.num_actions + 1 # +1 for MASK token + + # Observation encoder + h = nn.Dense(self.obs_encoder_width, kernel_init=_INIT, bias_init=_BIAS)(obs) + h = nn.LayerNorm()(h) + h = nn.relu(h) + for _ in range(self.obs_encoder_layers - 1): + h = nn.Dense(self.obs_encoder_width, kernel_init=_INIT, bias_init=_BIAS)(h) + h = nn.relu(h) + obs_tok = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(h)[:, None, :] + + # Time embedding + t = timestep.reshape(B) + t_emb = SinusoidalPosEmbed(self.d_model)(t) + t_emb = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(t_emb) + t_emb = nn.gelu(t_emb) + t_tok = nn.Dense(self.d_model, kernel_init=_INIT, bias_init=_BIAS)(t_emb)[:, None, :] + + # Action token embedding + act_emb = nn.Embed(num_embeddings=vocab, features=self.d_model)(noisy_actions) + + # Assemble sequence: [obs, time, actions] + seq = jnp.concatenate([obs_tok, t_tok, act_emb], axis=1) + seq_len = 2 + self.plan_horizon + pos_emb = SinusoidalPosEmbed(self.d_model)(jnp.arange(seq_len)) + seq = seq + pos_emb[None, :, :] + + # Transformer + for _ in range(self.n_layers): + seq = TransformerBlock( + d_model=self.d_model, n_heads=self.n_heads, d_ff=self.d_ff, + dropout_rate=self.dropout_rate, deterministic=deterministic, + )(seq) + seq = nn.LayerNorm()(seq) + + # Output logits over real actions (skip 2 prefix tokens) + return nn.Dense(self.num_actions, kernel_init=_INIT_SMALL, bias_init=_BIAS)( + seq[:, 2:, :] + ) diff --git a/src/planners/__init__.py b/src/planners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/planners/collect.py b/src/planners/collect.py new file mode 100644 index 0000000000000000000000000000000000000000..7996d982285bca2214ef0d9a3096ff245351aa7d --- /dev/null +++ b/src/planners/collect.py @@ -0,0 +1,75 @@ +"""Collect offline trajectories from a trained PPO agent.""" + +from __future__ import annotations + +import pathlib +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np + +from .ppo import load_ppo_agent +from .env import make_env + + +def collect_offline_data(config: dict[str, Any]) -> None: + """Roll out a PPO agent and save (obs, actions, rewards, dones) to disk. + + Args: + config: Upper-cased hyperparameter dict. Must contain + ``PPO_CHECKPOINT_PATH``, ``OFFLINE_DATA_PATH``, + ``COLLECT_NUM_STEPS``, and ``COLLECT_NUM_ENVS``. + """ + assert config.get("PPO_CHECKPOINT_PATH"), ( + "--ppo_checkpoint_path is required for --mode collect." + ) + + num_envs: int = config["COLLECT_NUM_ENVS"] + num_iters: int = config["COLLECT_NUM_STEPS"] // num_envs + + env_w, env_params = make_env(config, num_envs) + num_actions = env_w.action_space(env_params).n + obs_dim = env_w.observation_space(env_params).shape[0] + + ppo = load_ppo_agent( + config["PPO_CHECKPOINT_PATH"], num_actions, obs_dim, + config.get("LAYER_SIZE", 512), + model_type=config.get("PPO_MODEL_TYPE", "ppo_rnn"), + config=config, num_envs=num_envs, + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rng, env_rng, collect_rng = jax.random.split(rng, 3) + obs, env_state = env_w.reset(env_rng, env_params) + done = jnp.zeros(num_envs, dtype=bool) + hidden = ppo.init_hidden(num_envs) + + def _step(carry, _): + rng, es, obs, done, hs = carry + rng, act_rng, step_rng = jax.random.split(rng, 3) + action, new_hs = ppo.act(obs, done, hs, act_rng, + temperature=config.get("COLLECT_TEMPERATURE", 1.0)) + obs_next, es, reward, done_next, _ = env_w.step(step_rng, es, action, env_params) + return (rng, es, obs_next, done_next, new_hs), (obs, action, reward, done) + + rollout_fn = jax.jit(lambda c: jax.lax.scan(_step, c, None, length=num_iters)) + _, (obs_arr, act_arr, rew_arr, done_arr) = rollout_fn( + (collect_rng, env_state, obs, done, hidden), + ) + + # [steps, envs, ...] -> [envs, steps, ...] + obs_np = np.array(obs_arr).transpose(1, 0, 2) + act_np = np.array(act_arr).transpose(1, 0) + rew_np = np.array(rew_arr).transpose(1, 0) + done_np = np.array(done_arr).transpose(1, 0) + + out_path = config["OFFLINE_DATA_PATH"] + pathlib.Path(out_path).parent.mkdir(parents=True, exist_ok=True) + np.savez(out_path, obs=obs_np, actions=act_np, rewards=rew_np, dones=done_np) + total = obs_np.shape[0] * obs_np.shape[1] + print(f"Saved {obs_np.shape[0]}×{obs_np.shape[1]} transitions ({total:,}) to '{out_path}'") + + +def run_collect(config: dict[str, Any]) -> None: + collect_offline_data(config) diff --git a/src/planners/common.py b/src/planners/common.py new file mode 100644 index 0000000000000000000000000000000000000000..bd52d45925cc654d53e39287e351f7177da6eac6 --- /dev/null +++ b/src/planners/common.py @@ -0,0 +1,458 @@ +"""Shared gradient-step factory, validation rollout, and action diagnostics. + +Both :mod:`src.planners.offline` and :mod:`src.planners.online` use identical +gradient update and validation logic. Centralising it here eliminates +duplication. +""" + +from __future__ import annotations + +from typing import Any, Callable + +import jax +import jax.numpy as jnp +import optax + +from src.diffusion.loss import compute_loss +from src.diffusion.sampling import sample_plan +from src.diffusion.schedules import ScheduleFn + + +def resolve_num_updates(config: dict[str, Any], mode: str) -> None: + """Resolve ``NUM_UPDATES`` from env-frame-denominated config keys. + + Mutates ``config`` in place. After this call the runners can read + ``NUM_UPDATES`` (and ``OFFLINE_TOTAL_TIMESTEPS`` / + ``ONLINE_TOTAL_TIMESTEPS`` depending on mode) without worrying about + whether the user specified the env-frame or update-count form. + + Resolution priority: + + =========== ============================================================ + Mode Priority (highest first) + =========== ============================================================ + ``offline`` ``OFFLINE_TOTAL_TIMESTEPS`` > ``OFFLINE_NUM_UPDATES`` + ``online`` ``ONLINE_TOTAL_TIMESTEPS`` > ``ONLINE_NUM_UPDATES`` + =========== ============================================================ + + Env-frame keys are preferred because they are invariant under + ``num_envs`` changes — the same value yields the same total environment + experience regardless of hardware sizing, which makes cross-hardware + fairness studies (e.g. UCL 4096-env vs QMUL 96-env) trivially fair + without manual scaling. + + The function is idempotent: calling it twice with the same config has + the same effect as calling it once. + + Args: + config: Upper-cased config dict. Must contain ``NUM_STEPS`` and + ``NUM_ENVS``. + mode: Either ``"offline"`` or ``"online"``. + + Raises: + ValueError: If neither the env-frame nor the update-count form is + set for the given mode, or if ``mode`` is unknown. + """ + frames_per_update = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"]) + + if mode == "offline": + ts_key, nu_key = "OFFLINE_TOTAL_TIMESTEPS", "OFFLINE_NUM_UPDATES" + elif mode == "online": + ts_key, nu_key = "ONLINE_TOTAL_TIMESTEPS", "ONLINE_NUM_UPDATES" + else: + raise ValueError( + f"Unknown mode: {mode!r}; expected 'offline' or 'online'." + ) + + ts = config.get(ts_key) + nu = config.get(nu_key) + # float() first to accept YAML scientific notation parsed as string + # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8"). + if ts is not None: + num_updates = max(1, int(float(ts)) // frames_per_update) + elif nu: + num_updates = int(float(nu)) + else: + raise ValueError( + f"{mode.capitalize()} mode requires either " + f"{ts_key.lower()!r} (env frames, preferred) or " + f"{nu_key.lower()!r} to be set." + ) + + config["NUM_UPDATES"] = num_updates + # Re-snap so downstream consumers (run names, SPS, checkpoint IDs) + # see the exact integer multiple actually trained. + config[ts_key] = num_updates * frames_per_update + + +def resolve_scaled_hyperparams(config: dict[str, Any], mode: str) -> None: + """Resolve env-frame-denominated hyperparameters into update-step form. + + Mutates ``config`` in place. PRIMARY (env-frame) keys override LEGACY + (update-step) keys when set, mirroring the + :func:`resolve_num_updates` pattern. When the PRIMARY key is ``None`` + the LEGACY value passes through unchanged, preserving full + backward compatibility with configs that predate this resolver. + + Resolution table + ================ + + +-----------------------------+------------------------+----------+ + | PRIMARY (env-frame) | LEGACY (update-step) | Mode | + +=============================+========================+==========+ + | ``LR_WARMUP_FRAMES`` | ``LR_WARMUP_STEPS`` | both | + +-----------------------------+------------------------+----------+ + | ``VAL_INTERVAL_FRAMES`` | ``VAL_INTERVAL`` | both | + +-----------------------------+------------------------+----------+ + | ``DAGGER_BETA_FINAL`` | ``DAGGER_BETA_DECAY`` | online | + +-----------------------------+------------------------+----------+ + | ``DAGGER_BUFFER_CYCLES`` | ``DAGGER_BUFFER_MAX`` | online | + +-----------------------------+------------------------+----------+ + + Why env-frame units + ------------------- + Env-frame values are invariant under ``num_envs`` changes, so the + same config trains the same effective experiment on any GPU. The + update-step legacy keys had to be hand-derived per hardware tier, + which was both error-prone and obscured the conceptual quantity + (e.g. *final beta*, not *per-update decay constant*). + + The conversion for ``DAGGER_BETA_FINAL`` requires ``NUM_UPDATES``, + so this function MUST be called after :func:`resolve_num_updates` + when resolving online mode. + + Idempotent: calling this twice is equivalent to calling it once. + + Args: + config: Upper-cased config dict. Must contain ``NUM_STEPS`` and + ``NUM_ENVS``. + mode: Either ``"offline"`` or ``"online"``. + + Raises: + ValueError: If ``DAGGER_BETA_FINAL`` is set in online mode but + ``NUM_UPDATES`` has not been resolved yet. + """ + fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"]) + + # float() first to accept YAML scientific notation parsed as string + # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8"). + # ── Mode-agnostic ──────────────────────────────────────────────── + warmup_frames = config.get("LR_WARMUP_FRAMES") + if warmup_frames is not None: + config["LR_WARMUP_STEPS"] = int(float(warmup_frames)) // fpu + + val_frames = config.get("VAL_INTERVAL_FRAMES") + if val_frames is not None: + config["VAL_INTERVAL"] = max(1, int(float(val_frames)) // fpu) + + # ── Online-only ────────────────────────────────────────────────── + if mode != "online": + return + + beta_final = config.get("DAGGER_BETA_FINAL") + if beta_final is not None: + num_updates = config.get("NUM_UPDATES") + if num_updates is None: + raise ValueError( + "DAGGER_BETA_FINAL requires NUM_UPDATES to be resolved " + "first; call resolve_num_updates() before " + "resolve_scaled_hyperparams()." + ) + beta_init = float(config.get("DAGGER_BETA_INIT", 1.0)) + # final = init * decay^N => decay = (final / init) ** (1 / N) + config["DAGGER_BETA_DECAY"] = ( + float(beta_final) / beta_init + ) ** (1.0 / int(num_updates)) + + buffer_cycles = config.get("DAGGER_BUFFER_CYCLES") + if buffer_cycles is not None: + config["DAGGER_BUFFER_MAX"] = max(1, int(round(float(buffer_cycles) * fpu))) + + +def print_config_snapshot(config: dict[str, Any], mode: str) -> None: + """Print a structured banner of training-critical hyperparameters. + + Surfaces fairness-critical, schedule, and architecture parameters at + the start of every offline/online run so cross-hardware comparisons + can be sanity-checked at a glance. Must be called AFTER + :func:`resolve_num_updates` and :func:`resolve_scaled_hyperparams` + so the printed values reflect what training will actually use. + + Args: + config: Upper-cased config dict (post-resolver). + mode: Either ``"offline"`` or ``"online"``. + """ + fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"]) + num_updates = int(config["NUM_UPDATES"]) + minibatch = fpu // int(config["NUM_MINIBATCHES"]) + ts_key = f"{mode.upper()}_TOTAL_TIMESTEPS" + total_frames = int(config[ts_key]) + + bar = "=" * 72 + title = f"{mode.upper()} training — config snapshot" + print(f"\n{bar}\n {title}\n{bar}") + print(f" env_name : {config['ENV_NAME']}") + print(f" seed : {config['SEED']}") + + print(" -- Rollout / hardware --") + print(f" num_envs = {config['NUM_ENVS']}") + print(f" num_steps = {config['NUM_STEPS']}") + print(f" fpu (envs*steps) = {fpu}") + print(f" num_minibatches = {config['NUM_MINIBATCHES']} (minibatch={minibatch})") + print(f" update_epochs = {config['UPDATE_EPOCHS']}") + print(f" num_repeats = {config.get('NUM_REPEATS', 1)}") + + print(" -- Schedule --") + print(f" {ts_key.lower():<24} = {total_frames:,} (~{total_frames/1e6:.1f}M frames)") + print(f" {'num_updates':<24} = {num_updates:,}") + warmup = int(config.get("LR_WARMUP_STEPS", 0)) + print(f" {'lr':<24} = {float(config['LR']):.2e}") + print(f" {'lr_warmup_steps':<24} = {warmup} (~{warmup * fpu / 1e6:.2f}M frames)") + print(f" {'max_grad_norm':<24} = {config.get('MAX_GRAD_NORM', 1.0)}") + + if mode == "online": + beta_init = float(config.get("DAGGER_BETA_INIT", 1.0)) + beta_decay = float(config["DAGGER_BETA_DECAY"]) + final_beta = beta_init * beta_decay ** num_updates + buffer_max = int(config["DAGGER_BUFFER_MAX"]) + cycles = buffer_max / fpu + # Mirrors the n_train_passes default in run_online: drawn fresh per + # update, capped at samples_per_update for memory. + plan_h = int(config["PLAN_HORIZON"]) + samples_per_update = int(config["NUM_ENVS"]) * ( + int(config["NUM_STEPS"]) - plan_h + 1 + ) + n_passes = config.get("DAGGER_TRAIN_PASSES") or max( + 1, buffer_max // max(1, samples_per_update) + ) + expert_det = bool(config.get("DAGGER_EXPERT_DETERMINISTIC", True)) + total_grad_steps = ( + num_updates * int(n_passes) + * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"]) + ) + passes_tag = "auto" if config.get("DAGGER_TRAIN_PASSES") is None else "override" + print(" -- DAgger --") + print(f" {'dagger_beta_init':<24} = {beta_init}") + print(f" {'dagger_beta_decay':<24} = {beta_decay:.10f}") + print(f" {'final beta':<24} = {final_beta:.4f} (init * decay^N)") + print(f" {'dagger_buffer_max':<24} = {buffer_max:,} (~{cycles:.2f} update cycles)") + print(f" {'samples_per_update':<24} = {samples_per_update:,}") + print(f" {'dagger_train_passes':<24} = {n_passes} ({passes_tag})") + print(f" {'dagger_expert_determ':<24} = {expert_det}") + print(f" {'total_grad_steps':<24} = {total_grad_steps:,}") + else: + total_grad_steps = ( + num_updates * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"]) + ) + print(f" {'total_grad_steps':<24} = {total_grad_steps:,}") + + val_int = int(config.get("VAL_INTERVAL", 0)) + print(" -- Validation --") + print(f" val_interval = {val_int} updates (~{val_int * fpu / 1e6:.2f}M frames)") + print(f" val_diffusion_steps = {config.get('VAL_DIFFUSION_STEPS')}") + print(f" val_replan_every = {config.get('VAL_REPLAN_EVERY')}") + print(f" val_steps = {config.get('VAL_STEPS')}") + + print(" -- Diffusion model --") + print( + f" d_model/n_heads/n_layers/d_ff = " + f"{config['D_MODEL']}/{config['N_HEADS']}/{config['N_LAYERS']}/{config['D_FF']}" + ) + print(f" plan_horizon = {config['PLAN_HORIZON']}") + print(f" diffusion_steps = {config['DIFFUSION_STEPS']}") + print(f" remask_strategy = {config.get('REMASK_STRATEGY')} eta={config.get('ETA')}") + print( + f" sampling: temp={config.get('TEMPERATURE')} top_p={config.get('TOP_P')} " + f"loop={config.get('USE_LOOP')} t_on/t_off={config.get('T_ON')}/{config.get('T_OFF')}" + ) + print(f"{bar}\n", flush=True) + + +def _action_stats( + acts: jnp.ndarray, + num_actions: int, + valid: jnp.ndarray, +) -> dict[str, jnp.ndarray]: + """Compute action-distribution entropy and unique-action fraction over valid windows. + + Args: + acts: ``[B, H]`` int32 action sequences. + num_actions: Size of the real action vocabulary. + valid: ``[B]`` bool mask; invalid samples are excluded from counts. + + Returns: + Dict with ``action_entropy`` and ``action_unique_frac``. + """ + mask = jnp.broadcast_to(valid[:, None], acts.shape).reshape(-1) + flat = jnp.where(mask, acts.reshape(-1), num_actions + 1) + counts = jnp.bincount(flat, length=num_actions).astype(jnp.float32) + probs = counts / jnp.maximum(counts.sum(), 1.0) + entropy = -jnp.sum(probs * jnp.log(jnp.where(probs > 0, probs, 1.0))) + return { + "action_entropy": entropy, + "action_unique_frac": jnp.sum(probs > 0).astype(jnp.float32) / num_actions, + } + + +def make_grad_step( + apply_train: Callable, + num_actions: int, + schedule_fn: ScheduleFn, + schedule_deriv_fn: ScheduleFn, + sigma_t: float, + label_smoothing: float, +) -> Callable: + """Return a jittable gradient update function. + + Args: + apply_train: Model apply function with dropout enabled. + num_actions: Size of the action vocabulary. + schedule_fn: alpha(t) noise schedule. + schedule_deriv_fn: d(alpha)/dt analytic derivative. + sigma_t: ReMDM remasking strength during training. + label_smoothing: Cross-entropy label smoothing epsilon. + + Returns: + A ``step(state, acts, obs, valid, rng, advantages) -> (state, metrics)`` + function ready for use inside ``jax.lax.scan``. + """ + + def _loss_fn( + params: Any, + acts: jnp.ndarray, + obs: jnp.ndarray, + valid: jnp.ndarray, + rng: jax.Array, + advantages: jnp.ndarray, + ) -> tuple[jnp.ndarray, dict]: + return compute_loss( + apply_train, params, rng, acts, obs, valid, + num_actions, schedule_fn, schedule_deriv_fn, + sigma_t=sigma_t, label_smoothing=label_smoothing, + advantages=advantages, + ) + + def step( + state: Any, + acts: jnp.ndarray, + obs: jnp.ndarray, + valid: jnp.ndarray, + rng: jax.Array, + advantages: jnp.ndarray, + ) -> tuple[Any, dict]: + """Single gradient update step. + + Args: + state: Current ``TrainState``. + acts: ``[B, H]`` int32 action sequences. + obs: ``[B, obs_dim]`` float32 observations. + valid: ``[B]`` bool validity mask (episode-boundary filter). + rng: PRNG key for dropout and noise sampling. + advantages: ``[B]`` float per-sample weights applied before loss reduction. + + Returns: + Updated ``TrainState`` and a metrics dict. + """ + (_, info), grads = jax.value_and_grad(_loss_fn, has_aux=True)( + state.params, acts, obs, valid, rng, advantages, + ) + state = state.apply_gradients(grads=grads) + info["grad_norm"] = optax.tree.norm(grads) + info.update(_action_stats(acts, num_actions, valid)) + return state, info + + return step + + +def make_validate( + env: Any, + env_params: Any, + apply_eval: Callable, + num_actions: int, + plan_horizon: int, + schedule_fn: ScheduleFn, + config: dict[str, Any], + val_replan_every: int, + n_val_cycles: int, +) -> Callable: + """Return a ``validate(state, rng) -> dict`` closure for periodic eval. + + The closure runs a held-out rollout using the diffusion model's current + parameters and returns metrics under the ``val/`` namespace. + + Args: + env: Batched Gymnax environment. + env_params: Gymnax environment params. + apply_eval: Model apply function (eval mode, no dropout). + num_actions: Size of the action vocabulary. + plan_horizon: Action plan length H. + schedule_fn: alpha(t) noise schedule. + config: Training config dict (read-only). + val_replan_every: Env steps executed per diffusion plan during validation. + n_val_cycles: Number of plan-execute cycles per validation rollout. + + Returns: + A ``validate(state, rng) -> {str: jnp.ndarray}`` closure. + """ + + def validate(state: Any, rng: jax.Array) -> dict[str, jnp.ndarray]: + """Run a validation rollout and return ``val/`` metrics. + + Args: + state: Current ``TrainState`` (only ``.params`` is used). + rng: PRNG key. + + Returns: + Dict with ``val/`` prefixed metric keys. + """ + rng, val_rng = jax.random.split(rng) + val_obs, val_env_state = env.reset(val_rng, env_params) + + def _val_cycle(carry, _): + vs, vo, rng = carry + rng, p_rng = jax.random.split(rng) + plan = sample_plan( + apply_eval, + state.params, + p_rng, + vo, + num_actions, + plan_horizon, + num_steps=config.get("VAL_DIFFUSION_STEPS", 50), + schedule_fn=schedule_fn, + remask_strategy=config.get("REMASK_STRATEGY", "rescale"), + eta=config.get("ETA", 0.5), + use_loop=config.get("USE_LOOP", True), + t_on=config.get("T_ON", 0.7), + t_off=config.get("T_OFF", 0.3), + temperature=config.get("TEMPERATURE", 0.5), + top_p=config.get("TOP_P", 0.95), + ) # [num_envs, plan_horizon] + + def _exec_step(inner_carry, step_i): + vs_i, vo_i, r = inner_carry + r, s_rng = jax.random.split(r) + vo_next, vs_next, _, _, info = env.step( + s_rng, vs_i, plan[:, step_i], env_params, + ) + return (vs_next, vo_next, r), info + + (vs, vo, rng), step_infos = jax.lax.scan( + _exec_step, (vs, vo, rng), jnp.arange(val_replan_every), + ) + return (vs, vo, rng), step_infos + + _, cycle_infos = jax.lax.scan( + _val_cycle, (val_env_state, val_obs, rng), None, n_val_cycles, + ) + infos = jax.tree.map( + lambda x: x.reshape(-1, *x.shape[2:]), cycle_infos, + ) + returned = infos["returned_episode"] + metrics = jax.tree.map( + lambda x: (x * returned).sum() / (returned.sum() + 1e-8), + infos, + ) + return {f"val/{k}": v for k, v in metrics.items()} + + return validate diff --git a/src/planners/env.py b/src/planners/env.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6f69b50038c44bd4b61268a6c43f845da31470 --- /dev/null +++ b/src/planners/env.py @@ -0,0 +1,48 @@ +"""Craftax environment construction and trajectory data structures.""" + +from __future__ import annotations + +import jax.numpy as jnp +from craftax.craftax_env import make_craftax_env_from_name +from typing import NamedTuple + +from Craftax_Baselines.wrappers import ( + AutoResetEnvWrapper, + BatchEnvWrapper, + LogWrapper, + OptimisticResetVecEnvWrapper, +) + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + reward: jnp.ndarray + obs: jnp.ndarray + info: dict + + +def make_env(config: dict, num_envs: int): + """Build a wrapped Craftax environment. + + Args: + config: Upper-cased config dict with ``ENV_NAME``, + ``USE_OPTIMISTIC_RESETS``, ``OPTIMISTIC_RESET_RATIO``. + num_envs: Number of parallel environments. + + Returns: + Tuple of ``(env, env_params)``. + """ + env = make_craftax_env_from_name(config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]) + env_params = env.default_params + env = LogWrapper(env) + if config["USE_OPTIMISTIC_RESETS"]: + env = OptimisticResetVecEnvWrapper( + env, + num_envs=num_envs, + reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], num_envs), + ) + else: + env = AutoResetEnvWrapper(env) + env = BatchEnvWrapper(env, num_envs=num_envs) + return env, env_params diff --git a/src/planners/inference.py b/src/planners/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..06e59e43a4771f38819e1a07b7ba2c30dbc35739 --- /dev/null +++ b/src/planners/inference.py @@ -0,0 +1,142 @@ +"""Evaluation: run a trained diffusion planner with MPC + historical inpainting.""" + +from __future__ import annotations + +import time +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +import wandb +from craftax.craftax_env import make_craftax_env_from_name +from craftax.craftax.constants import Achievement as FullCraftaxAchievements +from craftax.craftax_classic.constants import Achievement as ClassicAchievements + +from src.diffusion.sampling import sample_plan_inpainting +from .model import build_model, load_checkpoint, make_apply_fns + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def run_inference(config: dict[str, Any]) -> None: + env_name = config["ENV_NAME"] + env = make_craftax_env_from_name(env_name, auto_reset=True) + env_params = env.default_params + num_actions = env.action_space(env_params).n + obs_dim = env.observation_space(env_params).shape[0] + config["NUM_ACTIONS"] = num_actions + + num_envs = config.get("EVAL_NUM_ENVS", 32) + plan_horizon = config["PLAN_HORIZON"] + diffusion_steps = config.get("DIFFUSION_STEPS_EVAL", 10) + temperature = config.get("TEMPERATURE", 0.5) + top_p = config.get("TOP_P", 0.95) + eval_steps = int(float(config.get("EVAL_STEPS", 10000))) + + model = build_model(config, num_actions) + apply_eval, _ = make_apply_fns(model) + + rng = jax.random.PRNGKey(config["SEED"]) + rng, ckpt_rng = jax.random.split(rng) + model_params = load_checkpoint(model, ckpt_rng, obs_dim, plan_horizon, config["CHECKPOINT_PATH"]) + env_indices = jnp.arange(num_envs) + + @jax.jit + def mpc_step(carry, _step_idx): + obs, state, rng, history, hist_len = carry + rng, plan_rng, env_rng = jax.random.split(rng, 3) + + # Reset history when plan is exhausted + seq_full = hist_len >= plan_horizon + hist_len = jnp.where(seq_full, 0, hist_len) + history = jnp.where(seq_full[:, None], num_actions, history) + + plan = sample_plan_inpainting( + apply_eval, model_params, plan_rng, obs, + history, hist_len, num_actions, plan_horizon, + diffusion_steps, temperature, top_p, + ) + + action = jnp.take_along_axis(plan, hist_len[:, None], axis=-1).squeeze(-1) + history = history.at[env_indices, hist_len].set(action) + hist_len = hist_len + 1 + + obs_next, state_next, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))( + jax.random.split(env_rng, num_envs), state, action, env_params, + ) + + hist_len = jnp.where(done, 0, hist_len) + history = jnp.where(done[:, None], num_actions, history) + return (obs_next, state_next, rng, history, hist_len), (reward, done, state_next.achievements) + + print(f"\nRunning {num_envs} agents in {env_name} for {eval_steps} steps...") + + rng, env_rng = jax.random.split(rng) + obs, state = jax.vmap(env.reset, in_axes=(0, None))( + jax.random.split(env_rng, num_envs), env_params, + ) + history = jnp.full((num_envs, plan_horizon), num_actions, dtype=jnp.int32) + hist_len = jnp.zeros(num_envs, dtype=jnp.int32) + + t0 = time.time() + _, (rewards, dones, achievements) = jax.lax.scan( + mpc_step, (obs, state, rng, history, hist_len), jnp.arange(eval_steps), + ) + elapsed = time.time() - t0 + + # First-episode extraction (strict single-life evaluation) + rewards_np = np.array(rewards) + dones_np = np.array(dones) + ach_np = np.array(achievements) + + ep_rewards = np.zeros(num_envs) + ep_ach = np.zeros((num_envs, ach_np.shape[2])) + ep_lengths = np.zeros(num_envs, dtype=int) + + for i in range(num_envs): + death = np.where(dones_np[:, i])[0] + end = death[0] if len(death) > 0 else eval_steps - 1 + ep_rewards[i] = rewards_np[:end + 1, i].sum() + ep_ach[i] = ach_np[:end + 1, i].max(axis=0) + ep_lengths[i] = end + 1 + + pct = ep_ach.mean(axis=0) * 100.0 + + # Report + print(f"\n{'=' * 50}") + print(f"EVALUATION COMPLETE ({elapsed:.1f}s)") + print(f"{'=' * 50}") + print(f"Average Score: {ep_rewards.mean():.1f} | Best: {ep_rewards.max():.1f}") + + ach_cls = ClassicAchievements if "Classic" in env_name else FullCraftaxAchievements + ach_names = [(a.name.replace("_", " ").title(), a.name.lower()) for a in ach_cls] + valid = [i for i, p in enumerate(pct) if p > 0] + top_idx = max(valid) if valid else 5 + + for i in range(top_idx + 1): + name, _ = ach_names[i] + count = int(pct[i] / 100.0 * num_envs) + icon = "+" if count > 0 else "-" + print(f" [{icon}] {name}: {count}/{num_envs}") + print(f"{'=' * 50}") + + if config.get("USE_WANDB", True): + wandb.init( + project=config.get("WANDB_PROJECT", "remdm-craftax"), + name=f"Eval-T{temperature}-P{top_p}", + config=config, job_type="evaluation", + ) + summary = {"eval/average_score": float(ep_rewards.mean())} + for i in range(top_idx + 1): + summary[f"eval/achievements/{ach_names[i][1]}"] = pct[i] + wandb.log(summary) + + table = wandb.Table(columns=["Agent", "Score", "Achievements", "Lifespan"]) + unlocked = ep_ach.sum(axis=-1) + for i in range(num_envs): + table.add_data(f"Agent {i + 1}", float(ep_rewards[i]), int(unlocked[i]), int(ep_lengths[i])) + wandb.log({"Individual Results": table}) + wandb.finish() diff --git a/src/planners/logging.py b/src/planners/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c824ce42baf6aa0aa916fdc6d154dcc6e8ea31 --- /dev/null +++ b/src/planners/logging.py @@ -0,0 +1,209 @@ +"""Centralised W&B logging utilities for ReMDM training loops. + +Design principles +----------------- +* ``build_log_dict`` is a pure function — no side effects, fully testable. +* ``make_wandb_callback`` is a factory that returns a closure suitable for + ``jax.debug.callback``. All timing state is local to the closure; there is + no module-level global state. +* Both ``train.py`` and ``online.py`` import the same symbols, keeping all + metric naming and aggregation logic in one place. + +Metric namespacing +------------------ +``diffusion/`` — ELBO loss, accuracy, and noise-level diagnostics from ``compute_loss``. +``train/`` — data quality, action distribution, throughput. +``env/`` — episode returns and per-achievement unlock rates (training envs). +``val/`` — same as ``env/`` but from the held-out validation rollout + (only emitted when ``step_idx % val_interval == 0``). +``dagger/`` — DAgger-specific metrics (online training only). +""" + +from __future__ import annotations + +import time +from typing import Any + +import wandb + + +def init_wandb( + config: dict[str, Any], + name: str, + *, + resume_run_id: str | None = None, +) -> None: + """Initialise a W&B run, optionally resuming an existing one. + + Args: + config: Training config dict (used for ``project``, ``entity``, + and logged as run config). + name: Human-readable run name. + resume_run_id: If provided, attaches to an existing W&B run via + ``wandb.init(id=..., resume="must")``. The run must + already exist. + """ + kwargs: dict[str, Any] = { + "project": config.get("WANDB_PROJECT", "remdm-craftax"), + "entity": config.get("WANDB_ENTITY"), + "config": config, + } + if resume_run_id is not None: + kwargs["id"] = resume_run_id + kwargs["resume"] = "must" + else: + kwargs["name"] = name + + wandb.init(**kwargs) + + +# Keys emitted by ``src.diffusion.loss.compute_loss`` info dict. +_DIFFUSION_KEYS: tuple[str, ...] = ( + "loss", + "unweighted_loss", + "accuracy", + "acc_t_low", + "acc_t_mid", + "acc_t_high", + "frac_masked", + "mean_t", + "grad_norm", +) + +# Keys added locally by training loops. +_TRAIN_KEYS: tuple[str, ...] = ( + "action_entropy", + "action_unique_frac", + "valid_frac", + "mean_return_weight", +) + +# Keys specific to online DAgger training. +_DAGGER_KEYS: tuple[str, ...] = ( + "beta", + "reward_mean", + "buffer_fill", + "valid_frac", + "best_val_return", +) + + +def build_log_dict( + metric: dict[str, Any], + step_idx: int, + val_interval: int, + *, + is_online: bool = False, + sps: float | None = None, +) -> dict[str, float]: + """Build a flat W&B-ready log dict from a merged training metric dict. + + Args: + metric: Merged metric dict from the current update step. + step_idx: Integer update step index. + val_interval: How often (in steps) validation runs occur. + is_online: If ``True``, emit DAgger-specific keys under ``dagger/``. + sps: Pre-computed steps-per-second; omitted when ``None``. + + Returns: + Flat ``{str: float}`` dict suitable for ``wandb.log``. + """ + log: dict[str, float] = {} + is_val_step = (step_idx % val_interval == 0) + + for k in _DIFFUSION_KEYS: + if k in metric: + log[f"diffusion/{k}"] = float(metric[k]) + + for k in _TRAIN_KEYS: + if k in metric: + log[f"train/{k}"] = float(metric[k]) + + if is_online: + for k in _DAGGER_KEYS: + if k in metric: + log[f"dagger/{k}"] = float(metric[k]) + + if "returned_episode_returns" in metric: + log["env/episode_return"] = float(metric["returned_episode_returns"]) + if "returned_episode_lengths" in metric: + log["env/episode_length"] = float(metric["returned_episode_lengths"]) + + # Per-achievement breakdown + aggregate score (Craftax reports as %, divide by 100). + achieve_total = 0.0 + for k, v in metric.items(): + if "achievement" in k.lower() and not k.startswith("val/"): + log[f"env/achieve/{k}"] = float(v) + achieve_total += float(v) / 100.0 + log["env/achievements"] = achieve_total + + # Validation metrics — only emitted on val steps to avoid polluting charts with zeros. + if is_val_step: + val_achieve_total = 0.0 + for k, v in metric.items(): + if not k.startswith("val/"): + continue + inner = k[4:] # strip leading "val/" + if "achievement" in inner.lower(): + log[f"val/achieve/{inner}"] = float(v) + val_achieve_total += float(v) / 100.0 + elif inner == "returned_episode_returns": + log["val/episode_return"] = float(v) + elif inner == "returned_episode_lengths": + log["val/episode_length"] = float(v) + log["val/achievements"] = val_achieve_total + + if sps is not None: + log["train/sps"] = sps + + return log + + +def make_wandb_callback( + config: dict[str, Any], + *, + steps_per_update: int | None, + val_interval: int, + is_online: bool = False, +) -> Any: + """Return a host-side logging closure for ``jax.debug.callback``. + + The closure tracks wall-clock time between successive calls to compute + steps-per-second. All state is local to the closure; there is no + module-level mutable state. + + SPS is not reported on ``step_idx == 0`` (JIT compilation overhead) or + when ``steps_per_update`` is ``None`` (e.g. data-replay mode where no + environment frames are consumed). + + Args: + config: Training config dict (read-only; only consulted for + ``USE_WANDB`` — callers are expected to guard). + steps_per_update: Environment frames consumed per update step. Pass + ``None`` to disable ``train/sps`` logging entirely + (e.g. when training from pre-collected data files). + val_interval: Frequency (in steps) at which validation runs occur. + is_online: If ``True``, emit DAgger keys under ``dagger/``. + + Returns: + A callable ``log_fn(metric, step_idx) -> None`` for + ``jax.debug.callback``. + """ + _t: list[float] = [time.time()] + + def log_fn(metric: dict[str, Any], step_idx: int) -> None: + now = time.time() + dt = now - _t[0] + _t[0] = now + + sps: float | None = ( + steps_per_update / dt + if steps_per_update is not None and int(step_idx) > 0 and dt > 1e-6 + else None + ) + log = build_log_dict( + metric, int(step_idx), val_interval, is_online=is_online, sps=sps, + ) + wandb.log(log, step=int(step_idx)) + + return log_fn diff --git a/src/planners/model.py b/src/planners/model.py new file mode 100644 index 0000000000000000000000000000000000000000..72afa40ab71779d449a827792303cb4eab44e1a6 --- /dev/null +++ b/src/planners/model.py @@ -0,0 +1,325 @@ +"""Diffusion model lifecycle: construction, parameter init, checkpoint I/O, and apply closures.""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any, Callable, Union + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import orbax.checkpoint as ocp +from flax.training.train_state import TrainState + +from src.models.denoiser import DenoisingTransformer + +logger = logging.getLogger(__name__) + +_METADATA_FILENAME = "resume_metadata.json" + + +def build_model(config: dict, num_actions: int) -> DenoisingTransformer: + """Construct a :class:`DenoisingTransformer` from a config dict. + + Args: + config: Upper-cased config dict with architecture hyperparameters. + num_actions: Size of the discrete action vocabulary. + + Returns: + An uninitialised :class:`DenoisingTransformer` instance. + """ + return DenoisingTransformer( + num_actions=num_actions, + plan_horizon=config["PLAN_HORIZON"], + d_model=config.get("D_MODEL", 256), + n_heads=config.get("N_HEADS", 4), + n_layers=config.get("N_LAYERS", 4), + d_ff=config.get("D_FF", 512), + obs_encoder_layers=config.get("OBS_ENCODER_LAYERS", 2), + obs_encoder_width=config.get("OBS_ENCODER_WIDTH", 512), + dropout_rate=config.get("DROPOUT_RATE", 0.1), + ) + + +def init_params( + model: DenoisingTransformer, + rng: jax.Array, + obs_dim: int, + plan_horizon: int, +) -> Any: + """Initialize model parameters with dummy inputs. + + Args: + model: Flax module to initialise. + rng: PRNG key. + obs_dim: Observation dimensionality. + plan_horizon: Number of action steps in a plan. + + Returns: + Initialised parameter pytree. + """ + return model.init( + rng, + jnp.zeros((1, obs_dim)), + jnp.zeros((1, plan_horizon), dtype=jnp.int32), + jnp.zeros((1,)), + ) + + +def resolve_checkpoint_path( + path: str, + download_dir: str | None = None, +) -> str: + """Resolve a checkpoint path, downloading from W&B if it is an artifact reference. + + Paths prefixed with ``wandb:`` are treated as W&B artifact references + (e.g. ``wandb:entity/project/name:version``) and downloaded locally + before returning the filesystem path. + + Args: + path: Local filesystem path or ``wandb:``-prefixed artifact + reference. + download_dir: Root directory for downloaded artifacts. When ``None``, + falls back to the wandb default (``./artifacts/``). + + Returns: + Local filesystem path to the checkpoint directory. + """ + if not path.startswith("wandb:"): + return str(Path(path).resolve()) + + import wandb + + artifact_ref = path.removeprefix("wandb:") + api = wandb.Api() + artifact = api.artifact(artifact_ref) + local_path = ( + artifact.download(root=download_dir) if download_dir else artifact.download() + ) + print(f"Downloaded W&B artifact '{artifact_ref}' -> '{local_path}'") + return local_path + + +def load_checkpoint( + model: DenoisingTransformer, + rng: jax.Array, + obs_dim: int, + plan_horizon: int, + path: str, +) -> Any: + """Load diffusion model parameters from an Orbax checkpoint. + + Args: + model: Flax module (used to build the abstract state structure). + rng: PRNG key for dummy initialisation. + obs_dim: Observation dimensionality. + plan_horizon: Number of action steps in a plan. + path: Path to the Orbax checkpoint directory. + + Returns: + Restored parameter pytree. + + Raises: + FileNotFoundError: If the checkpoint directory contains no saved steps. + """ + path = str(Path(path).resolve()) + params = init_params(model, rng, obs_dim, plan_horizon) + abstract_state = create_train_state(model=model, params=params, lr=1e-4, max_grad_norm=1.0) + + with ocp.CheckpointManager(path) as mgr: + step = mgr.latest_step() + if step is None: + raise FileNotFoundError(f"No checkpoint at {path}") + restored_state = mgr.restore( + step, + args=ocp.args.StandardRestore(item=abstract_state), + ) + + print(f"Loaded diffusion checkpoint from '{path}' (step {step})") + return restored_state.params + + +def create_train_state( + model: DenoisingTransformer, + params: Any, + lr: Union[float, Callable[[int], float]], + max_grad_norm: float, +) -> TrainState: + """Create a :class:`TrainState` with gradient clipping and Adam. + + Args: + model: Flax module (used only to bind ``apply_fn``). + params: Initialised parameter pytree. + lr: Constant learning rate or an optax schedule + (any callable ``step -> lr``). + max_grad_norm: Global gradient clipping threshold. + + Returns: + A Flax ``TrainState`` ready for ``apply_gradients``. + """ + tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), optax.adam(lr, eps=1e-5)) + return TrainState.create(apply_fn=model.apply, params=params, tx=tx) + + +def make_apply_fns( + model: DenoisingTransformer, +) -> tuple[Callable, Callable]: + """Return ``(apply_eval, apply_train)`` closures matching ``ModelApplyFn``. + + Args: + model: Flax module. + + Returns: + Tuple of ``(apply_eval, apply_train)`` where ``apply_train`` enables + dropout via ``rngs={"dropout": rng}``. + """ + + def apply_eval(params: Any, obs: jnp.ndarray, z_t: jnp.ndarray, t: jnp.ndarray, _rng=None): + return model.apply(params, obs, z_t, t) + + def apply_train(params: Any, obs: jnp.ndarray, z_t: jnp.ndarray, t: jnp.ndarray, rng=None): + return model.apply( + params, obs, z_t, t, + deterministic=False, + rngs={"dropout": rng} if rng is not None else {}, + ) + + return apply_eval, apply_train + + +# --------------------------------------------------------------------------- +# Checkpoint metadata sidecar +# --------------------------------------------------------------------------- + + +class _NumpyEncoder(json.JSONEncoder): + """JSON encoder that handles numpy scalar types.""" + + def default(self, o: Any) -> Any: + """Serialize numpy scalars to native Python types. + + Args: + o: Object to serialize. + + Returns: + JSON-serializable object. + """ + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + return super().default(o) + + +def save_checkpoint_metadata( + checkpoint_dir: str, + mode: str, + update_step: int, + total_gradient_steps: int, + wandb_run_id: str | None, + config: dict[str, Any], +) -> None: + """Write a JSON metadata sidecar alongside an Orbax checkpoint. + + Args: + checkpoint_dir: Root directory of the Orbax checkpoint manager. + mode: Training mode (``"offline"`` or ``"online"``). + update_step: Final update step index. + total_gradient_steps: Total gradient steps completed. + wandb_run_id: Current W&B run ID, or ``None``. + config: Full training config snapshot. + """ + metadata = { + "mode": mode, + "update_step": int(update_step), + "total_gradient_steps_completed": int(total_gradient_steps), + "wandb_run_id": wandb_run_id, + "config_snapshot": config, + } + path = Path(checkpoint_dir) / _METADATA_FILENAME + with open(path, "w") as f: + json.dump(metadata, f, indent=2, cls=_NumpyEncoder) + print(f"Saved checkpoint metadata to {path}") + + +def load_checkpoint_metadata( + checkpoint_dir: str, +) -> dict[str, Any] | None: + """Read the JSON metadata sidecar from a checkpoint directory. + + Args: + checkpoint_dir: Root directory of the Orbax checkpoint manager. + + Returns: + Parsed metadata dict, or ``None`` if the sidecar does not exist + (backward-compatible with checkpoints created before this feature). + """ + path = Path(checkpoint_dir) / _METADATA_FILENAME + if not path.exists(): + return None + with open(path) as f: + return json.load(f) + + +def load_checkpoint_for_resume( + model: DenoisingTransformer, + rng: jax.Array, + obs_dim: int, + plan_horizon: int, + path: str, + lr_schedule: Union[float, Callable[[int], float]], + max_grad_norm: float, +) -> TrainState: + """Load a full ``TrainState`` (params + optimizer state) for resume. + + Unlike :func:`load_checkpoint` which returns only params, this function + restores the complete ``TrainState`` including Adam moments so that + training can continue seamlessly. + + The ``lr_schedule`` and ``max_grad_norm`` must match the optimizer chain + structure used when the checkpoint was saved (same chain composition, + possibly different schedule values). + + Args: + model: Flax module (used to build the abstract state). + rng: PRNG key for dummy initialisation. + obs_dim: Observation dimensionality. + plan_horizon: Number of action steps in a plan. + path: Path to the Orbax checkpoint directory. + lr_schedule: Learning rate or schedule matching the current run's + optimizer (must produce the same ``opt_state`` structure). + max_grad_norm: Global gradient clipping threshold. + + Returns: + Restored ``TrainState`` with params, opt_state, and step from the + checkpoint. The caller should call ``.replace(step=...)`` to set the + correct LR offset for the resumed run. + + Raises: + FileNotFoundError: If the checkpoint directory contains no saved steps. + """ + path = str(Path(path).resolve()) + params = init_params(model, rng, obs_dim, plan_horizon) + abstract_state = create_train_state(model, params, lr_schedule, max_grad_norm) + + with ocp.CheckpointManager(path) as mgr: + step = mgr.latest_step() + if step is None: + raise FileNotFoundError( + f"No checkpoint found at {path}" + ) + restored_state = mgr.restore( + step, + args=ocp.args.StandardRestore(item=abstract_state), + ) + + print( + f"Loaded full TrainState for resume from '{path}' " + f"(step {step}, opt_state step {restored_state.step})" + ) + return restored_state diff --git a/src/planners/offline.py b/src/planners/offline.py new file mode 100644 index 0000000000000000000000000000000000000000..3963bab4ccfe68d5ff15bfa7a5b46d0e5f299306 --- /dev/null +++ b/src/planners/offline.py @@ -0,0 +1,357 @@ +"""Training loop: environment rollout -> diffusion window extraction -> gradient updates.""" + +from __future__ import annotations + +import os +import time +from typing import Any + +import jax +import jax.numpy as jnp +import optax +import orbax.checkpoint as ocp +import wandb + +from src.diffusion.schedules import SCHEDULE_MAP +from .common import ( + make_grad_step, + make_validate, + print_config_snapshot, + resolve_num_updates, + resolve_scaled_hyperparams, +) +from .env import Transition, make_env +from .model import ( + build_model, + init_params, + create_train_state, + load_checkpoint_for_resume, + make_apply_fns, + save_checkpoint_metadata, +) +from .ppo import PPOAgent, build_ppo_network, load_ppo_params +from .logging import init_wandb, make_wandb_callback + + +# --------------------------------------------------------------------------- +# make_train +# --------------------------------------------------------------------------- + +def make_train(config: dict[str, Any]): + """Build the offline diffusion training closure. + + All environment construction, model instantiation, and static pre-computation + happen here (outside the returned ``train`` closure) so they are not repeated + across ``jax.vmap`` replicas or JIT retraces. + + Args: + config: Upper-cased hyperparameter dict (see ``configs/defaults.yaml``). + + Returns: + A ``train(rng) -> dict`` closure that is safe to JIT and vmap. + """ + num_steps = config["NUM_STEPS"] + num_envs = config["NUM_ENVS"] + plan_horizon = config["PLAN_HORIZON"] + val_interval = config.get("VAL_INTERVAL", 50) + val_replan_every = config.get("VAL_REPLAN_EVERY", 4) + val_steps = config.get("VAL_STEPS", 128) + n_val_cycles = val_steps // val_replan_every + valid_per_rollout = num_steps - plan_horizon + 1 + num_samples = num_envs * valid_per_rollout + return_weight_cap = config.get("RETURN_WEIGHT_CAP", 5.0) + + # NUM_UPDATES and OFFLINE_TOTAL_TIMESTEPS are resolved in + # run_offline_diffusion before wandb.init so the run name can use + # OFFLINE_TOTAL_TIMESTEPS. We assume both are present here. + assert num_samples % config["NUM_MINIBATCHES"] == 0, ( + f"{num_samples} samples not divisible by {config['NUM_MINIBATCHES']} minibatches" + ) + config["MINIBATCH_SIZE"] = num_samples // config["NUM_MINIBATCHES"] + + # Environment + env, env_params = make_env(config, num_envs) + + num_actions = env.action_space(env_params).n + obs_shape = env.observation_space(env_params).shape + obs_dim = obs_shape[0] + + # PPO collector + model_type = config["PPO_MODEL_TYPE"] + ppo_net = build_ppo_network(model_type, num_actions, config["LAYER_SIZE"], config) + ppo_params = load_ppo_params( + config["PPO_CHECKPOINT_PATH"], ppo_net, model_type, num_envs, obs_shape, config["LAYER_SIZE"], + ) + ppo = PPOAgent(ppo_net, ppo_params, model_type, config["LAYER_SIZE"]) + + # Noise schedule + schedule_fn, schedule_deriv_fn = SCHEDULE_MAP[config["DIFFUSION_SCHEDULE"]] + + # Diffusion model — pure Flax dataclass, no randomness, safe to build once. + net = build_model(config, num_actions) + apply_eval, apply_train = make_apply_fns(net) + grad_step = make_grad_step( + apply_train, num_actions, schedule_fn, schedule_deriv_fn, + config.get("TRAIN_SIGMA", 0.0), config.get("LABEL_SMOOTHING", 0.0), + ) + + # Cosine LR decay over total gradient steps with optional linear warm-up. + total_grad_steps = config["NUM_UPDATES"] * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"] + warmup_steps = config.get("LR_WARMUP_STEPS", 0) + lr_schedule = ( + optax.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=config["LR"], + warmup_steps=warmup_steps, + decay_steps=total_grad_steps, + end_value=config["LR"] * 0.1, + ) + if warmup_steps > 0 + else optax.cosine_decay_schedule( + init_value=config["LR"], + decay_steps=total_grad_steps, + alpha=0.1, + ) + ) + + # Resume checkpoint (loaded outside JIT, captured by train closure) ------ + resume_step = config.get("RESUME_STEP") or 0 + resume_state = None + if config.get("RESUME_CHECKPOINT_PATH"): + resume_state = load_checkpoint_for_resume( + net, + jax.random.PRNGKey(0), + obs_dim, + plan_horizon, + config["RESUME_CHECKPOINT_PATH"], + lr_schedule, + config["MAX_GRAD_NORM"], + ) + # Set the optimizer step counter so the LR schedule picks up at the + # correct position. The schedule is indexed by gradient step, which + # equals update_step * update_epochs * num_minibatches. + target_opt_step = resume_step * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"] + resume_state = resume_state.replace(step=target_opt_step) + + scan_length = config["NUM_UPDATES"] - resume_step + + # W&B callback — one closure shared across vmap replicas (timing is per-call). + _wandb_log = ( + make_wandb_callback( + config, + steps_per_update=num_steps * num_envs, + val_interval=val_interval, + ) + if config["USE_WANDB"] else None + ) + + def train(rng: jax.Array) -> dict[str, Any]: + """JIT/vmap-compatible training loop. + + Args: + rng: JAX PRNG key (one per vmap replica). + + Returns: + Dict with ``runner_state`` (final scan carry) and ``metrics`` (all update metrics). + """ + rng, init_rng, env_rng = jax.random.split(rng, 3) + if resume_state is not None: + state = resume_state + else: + params = init_params(net, init_rng, obs_dim, plan_horizon) + state = create_train_state(net, params, lr_schedule, config["MAX_GRAD_NORM"]) + + obsv, env_state = env.reset(env_rng, env_params) + init_hstate = ppo.init_hidden(num_envs) + + # Shared validation closure (see common.py) + _validate = make_validate( + env, env_params, apply_eval, num_actions, + plan_horizon, schedule_fn, config, + val_replan_every, n_val_cycles, + ) + + # ------------------------------------------------------------------ + # Update step + # ------------------------------------------------------------------ + def _update_step(runner, _): + state, env_state, last_obs, last_done, hstate, rng, step_idx = runner + + # --- Trajectory collection (state excluded from carry: not modified here) --- + def _env_step(carry, _): + es, obs, done, hs, rng = carry + rng, act_rng, step_rng = jax.random.split(rng, 3) + action, new_hs = ppo.act( + obs, done, hs, act_rng, temperature=config.get("COLLECT_TEMPERATURE", 1.0), + ) + new_obs, es, reward, new_done, info = env.step(step_rng, es, action, env_params) + t = Transition(done=done, action=action, reward=reward, obs=obs, info=info) + return (es, new_obs, new_done, new_hs, rng), t + + (env_state, last_obs, last_done, hstate, rng), traj = jax.lax.scan( + _env_step, (env_state, last_obs, last_done, hstate, rng), None, num_steps, + ) + + # --- Diffusion window extraction --- + def _window(t_idx): + obs_t = traj.obs[t_idx] + acts = jax.lax.dynamic_slice(traj.action, (t_idx, 0), (plan_horizon, num_envs)) + # traj.done[t] marks a reset *before* step t, so traj.done[t_idx] + # only tells us obs_t is an episode-start — it does NOT invalidate the + # window. We check done flags strictly *inside* the action sequence. + dones = jax.lax.dynamic_slice( + traj.done, (t_idx + 1, 0), (plan_horizon - 1, num_envs), + ) + valid = ~jnp.any(dones, axis=0) + + rew_seq = jax.lax.dynamic_slice(traj.reward, (t_idx, 0), (plan_horizon, num_envs)) + window_return = jnp.sum(rew_seq, axis=0) # [num_envs] + + return obs_t, jnp.swapaxes(acts, 0, 1), valid, window_return + + obs_w, act_w, valid_w, returns_w = jax.vmap(_window)(jnp.arange(valid_per_rollout)) + + flat_obs = obs_w.reshape(-1, obs_dim) + flat_acts = act_w.reshape(-1, plan_horizon) + flat_valid = valid_w.reshape(-1) # bool: episode-boundary filter + + # Return-weighted advantages: normalise by batch mean, clip to [0.1, cap]. + # Passed as per-sample multipliers into compute_loss *after* per-position + # normalisation, so the weight correctly scales each sample's contribution. + flat_returns = returns_w.reshape(-1) + flat_returns_clipped = jnp.clip(flat_returns, 0.0, None) + return_weights = flat_returns_clipped / (jnp.mean(flat_returns_clipped) + 1e-8) + return_weights = jnp.clip(return_weights, 0.1, return_weight_cap) + + dataset = (flat_obs, flat_acts, flat_valid, return_weights) + + # --- Minibatch SGD over UPDATE_EPOCHS epochs --- + def _epoch(epoch_state, _): + state, ds, rng = epoch_state + rng, perm_rng = jax.random.split(rng) + perm = jax.random.permutation(perm_rng, num_samples) + shuffled = jax.tree.map(lambda x: jnp.take(x, perm, axis=0), ds) + batches = jax.tree.map( + lambda x: x.reshape(config["NUM_MINIBATCHES"], -1, *x.shape[1:]), shuffled, + ) + + def _mb(carry, batch): + st, rng = carry + rng, loss_rng = jax.random.split(rng) + obs_b, act_b, val_b, adv_b = batch + st, metrics = grad_step(st, act_b, obs_b, val_b, loss_rng, adv_b) + return (st, rng), metrics + + (state, rng), metrics = jax.lax.scan(_mb, (state, rng), batches) + return (state, ds, rng), metrics + + (state, _, rng), loss_info = jax.lax.scan( + _epoch, (state, dataset, rng), None, config["UPDATE_EPOCHS"], + ) + + # --- Metrics --- + metric = jax.tree.map(jnp.mean, loss_info) + returned = traj.info["returned_episode"] + env_metrics = jax.tree.map( + lambda x: (x * returned).sum() / (returned.sum() + 1e-8), traj.info, + ) + metric.update(env_metrics) + metric["valid_frac"] = jnp.mean(flat_valid.astype(jnp.float32)) + metric["mean_return_weight"] = jnp.mean(return_weights) + + # --- Periodic validation --- + rng, val_rng = jax.random.split(rng) + dummy = jax.tree.map( + jnp.zeros_like, {f"val/{k}": v for k, v in env_metrics.items()}, + ) + val_metrics = jax.lax.cond( + step_idx % val_interval == 0, + lambda: _validate(state, val_rng), + lambda: dummy, + ) + metric.update(val_metrics) + + if _wandb_log is not None: + jax.debug.callback(_wandb_log, metric, step_idx) + + runner = (state, env_state, last_obs, last_done, hstate, rng, step_idx + 1) + return runner, metric + + rng, run_rng = jax.random.split(rng) + runner_init = ( + state, env_state, obsv, jnp.zeros(num_envs, dtype=bool), + init_hstate, run_rng, resume_step, + ) + runner_final, metrics = jax.lax.scan(_update_step, runner_init, None, scan_length) + return {"runner_state": runner_final, "metrics": metrics} + + return train + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def run_offline_diffusion(config): + """Configure, compile, and run offline diffusion training. + + Args: + config: Mixed-case hyperparameter dict from ``defaults.yaml`` / CLI merge. + Keys are upper-cased on entry. + """ + config = {k.upper(): v for k, v in config.items()} + + # OFFLINE_TOTAL_TIMESTEPS (env frames) is the hardware-portable source of + # truth: invariant under num_envs changes, so the same config trains the + # same amount of environment experience on any GPU. OFFLINE_NUM_UPDATES + # is kept as a legacy fallback for configs that prefer the update form. + resolve_num_updates(config, "offline") + # Translate env-frame-denominated hyperparameters (LR_WARMUP_FRAMES, + # VAL_INTERVAL_FRAMES) into their update-step legacy keys. + resolve_scaled_hyperparams(config, "offline") + print_config_snapshot(config, "offline") + + if config["USE_WANDB"]: + init_wandb( + config, + name=f"{config['ENV_NAME']}-OfflineDiffusion-BC-{int(config['OFFLINE_TOTAL_TIMESTEPS'] // 1e6)}M", + resume_run_id=config.get("RESUME_WANDB_RUN_ID"), + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_REPEATS"]) + + train_fn = jax.jit(jax.vmap(make_train(config))) + + t0 = time.time() + out = train_fn(rngs) + elapsed = time.time() - t0 + print(f"Time: {elapsed:.1f}s SPS: {config['OFFLINE_TOTAL_TIMESTEPS'] / elapsed:.0f}") + + if config["USE_WANDB"] and config["SAVE_POLICY"]: + train_states = out["runner_state"][0] + train_state = jax.tree.map(lambda x: x[0], train_states) + path = os.path.join(wandb.run.dir, "policies") + with ocp.CheckpointManager(path, options=ocp.CheckpointManagerOptions(max_to_keep=1)) as mgr: + mgr.save(int(config["OFFLINE_TOTAL_TIMESTEPS"]), args=ocp.args.StandardSave(train_state)) + print(f"Saved policy to {path}") + + num_updates = config["NUM_UPDATES"] + save_checkpoint_metadata( + path, + mode="offline", + update_step=num_updates, + total_gradient_steps=num_updates * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"], + wandb_run_id=wandb.run.id if wandb.run else None, + config=config, + ) + + artifact = wandb.Artifact( + name=f"{config['ENV_NAME']}-policy", + type="model", + metadata=config + ) + artifact.add_dir(path) + wandb.log_artifact(artifact) + + print("Uploaded policy artifact to wandb") diff --git a/src/planners/online.py b/src/planners/online.py new file mode 100644 index 0000000000000000000000000000000000000000..cd16d40f80f09bd3ae0d7a80a3ff7ce7e908d689 --- /dev/null +++ b/src/planners/online.py @@ -0,0 +1,759 @@ +"""Online DAgger training: roll out learner, label with expert, aggregate, update. + +Implements Algorithm 3.1 from Ross et al. (2011) — 'A Reduction of Imitation +Learning and Structured Prediction to No-Regret Online Learning'. + +Each DAgger iteration: + 1. Roll out the mixed policy (beta * expert + (1-beta) * learner). + 2. At every visited state, query the expert for target actions. + 3. Aggregate (obs, expert_plan) pairs into a growing replay buffer. + 4. Train the diffusion model on the full buffer with BC loss (MDLM ELBO). + beta decays exponentially so the learner's own policy dominates rollouts. +""" + +from __future__ import annotations + +import os +import time +from typing import Any, NamedTuple + +import jax +import jax.numpy as jnp +import optax +import orbax.checkpoint as ocp +import wandb +from flax.training.train_state import TrainState + +from src.diffusion.sampling import sample_plan +from src.diffusion.schedules import SCHEDULE_MAP + +from .common import ( + make_grad_step, + make_validate, + print_config_snapshot, + resolve_num_updates, + resolve_scaled_hyperparams, +) +from .env import make_env +from .model import ( + build_model, + init_params, + load_checkpoint, + load_checkpoint_for_resume, + create_train_state, + make_apply_fns, + save_checkpoint_metadata, +) +from .ppo import PPOAgent, load_ppo_agent +from .logging import init_wandb, make_wandb_callback + + +# --------------------------------------------------------------------------- +# Scan carry +# --------------------------------------------------------------------------- + + +class DAggerCarry(NamedTuple): + """Carry state for the outer DAgger update scan.""" + + train_state: Any + env_state: Any + obs: jnp.ndarray # [E, obs_dim] + rng: jax.Array + step_idx: int + ppo_hs: Any # [E, layer_size] or None (non-RNN) + prev_done: jnp.ndarray # [E] bool + buf_obs: jnp.ndarray # [max_buf, obs_dim] + buf_plans: jnp.ndarray # [max_buf, plan_horizon] int32 + buf_valid: jnp.ndarray # [max_buf] float32 + buf_write_idx: int + buf_fill: int + best_params: Any # copy of params with highest val return + best_val_return: jnp.ndarray # scalar, -inf initially + + +# --------------------------------------------------------------------------- +# make_train_dagger +# --------------------------------------------------------------------------- + + +def make_train_dagger(config: dict[str, Any]): + """Build the DAgger train closure. + + All environment construction, model instantiation, and static + pre-computation happen here (outside the returned ``train`` closure) so + they are not repeated across ``jax.vmap`` replicas or JIT retraces. + + Args: + config: Upper-cased hyperparameter dict (see ``configs/defaults.yaml``). + + Returns: + A ``train(rng) -> dict`` closure that is safe to JIT and vmap. + """ + num_envs = config["NUM_ENVS"] + plan_horizon = config["PLAN_HORIZON"] + num_updates = config["NUM_UPDATES"] + update_epochs = config["UPDATE_EPOCHS"] + num_minibatches = config["NUM_MINIBATCHES"] + diffusion_steps = config["DIFFUSION_STEPS"] + num_steps = config["NUM_STEPS"] + + # Validation config + val_interval = config.get("VAL_INTERVAL", 50) + val_replan_every = config.get("VAL_REPLAN_EVERY", 4) + val_steps = config.get("VAL_STEPS", 128) + n_val_cycles = val_steps // val_replan_every + + # Environment ---------------------------------------------------------- + env, env_params = make_env(config, num_envs) + num_actions = env.action_space(env_params).n + obs_shape = env.observation_space(env_params).shape + obs_dim = obs_shape[0] + + # Expert (PPO) — required for DAgger ----------------------------------- + assert config.get("PPO_CHECKPOINT_PATH"), ( + "DAgger requires an expert policy; set PPO_CHECKPOINT_PATH." + ) + ppo: PPOAgent = load_ppo_agent( + config["PPO_CHECKPOINT_PATH"], + num_actions, + obs_dim, + config.get("LAYER_SIZE", 512), + config.get("PPO_MODEL_TYPE", "ppo_rnn"), + config, + num_envs=num_envs, + ) + + # Schedule ------------------------------------------------------------- + schedule_fn, schedule_deriv_fn = SCHEDULE_MAP[config["DIFFUSION_SCHEDULE"]] + + # Diffusion model / apply fns ------------------------------------------ + model = build_model(config, num_actions) + apply_eval, apply_train = make_apply_fns(model) + grad_step = make_grad_step( + apply_train, + num_actions, + schedule_fn, + schedule_deriv_fn, + config.get("TRAIN_SIGMA", 0.0), + config.get("LABEL_SMOOTHING", 0.0), + ) + + # Pretrained checkpoint ------------------------------------------------ + pretrained_params = None + if config.get("OFFLINE_CHECKPOINT_PATH"): + _tmp_rng = jax.random.PRNGKey(0) + pretrained_params = load_checkpoint( + model, + _tmp_rng, + obs_dim, + plan_horizon, + config["OFFLINE_CHECKPOINT_PATH"], + ) + + # B3: roll out num_steps env transitions across n_cycles plans, then + # extract sliding windows the same way offline.py does — yields + # W = num_steps - plan_horizon + 1 windows per env per update instead + # of only one window per cycle (~16x denser for the default config). + assert num_steps % plan_horizon == 0, ( + f"NUM_STEPS ({num_steps}) must be divisible by" + f" PLAN_HORIZON ({plan_horizon})" + ) + assert num_steps >= plan_horizon, ( + f"NUM_STEPS ({num_steps}) must be >= PLAN_HORIZON ({plan_horizon})" + ) + n_cycles = num_steps // plan_horizon + valid_per_rollout = num_steps - plan_horizon + 1 + samples_per_update = num_envs * valid_per_rollout + + assert samples_per_update % num_minibatches == 0, ( + f"samples_per_update ({samples_per_update}) not divisible by" + f" num_minibatches ({num_minibatches})" + ) + + # B1: circular replay buffer sizing. + # Theoretical max = num_updates * samples_per_update; cap to stay in + # GPU memory. Each sample stores obs (float32) + plan (int32) + valid. + max_buffer_size = min( + num_updates * samples_per_update, + config.get("DAGGER_BUFFER_MAX", 100_000), + ) + assert samples_per_update <= max_buffer_size, ( + f"samples_per_update ({samples_per_update}) exceeds" + f" max_buffer_size ({max_buffer_size}); raise DAGGER_BUFFER_MAX" + f" or shrink NUM_ENVS / NUM_STEPS" + ) + + # B1: multi-pass training over the buffer. Each pass redraws a + # fresh sample of size ``samples_per_update`` from the filled portion + # of D. Default = 1 so DAgger does exactly the same gradient work + # per update as offline BC (``update_epochs * num_minibatches`` grad + # steps over a sample of size ``samples_per_update``) — fair compute + # comparison. After the Bug 3 sliding-window fix, ``samples_per_update`` + # is already 16x larger than the legacy cycle-only count, so a single + # pass already yields ~34% buffer coverage and cumulative coverage + # across a few updates approaches 100%. Set DAGGER_TRAIN_PASSES > 1 + # explicitly to trade BC fairness for higher per-update D coverage. + _default_passes = 1 + n_train_passes = config.get("DAGGER_TRAIN_PASSES") or _default_passes + + # B2: deterministic-by-default expert. Sampling from ``pi.logits`` + # injects label noise — two queries to the same state can return + # different actions — which breaks DAgger's assumption of a fixed + # expert mapping s -> a*. Argmax keeps labels consistent. + expert_deterministic = config.get("DAGGER_EXPERT_DETERMINISTIC", True) + + # Beta schedule: probability of using expert for rollout actions. + # beta_i = beta_init * beta_decay^i -> 0 as i -> inf. + beta_init = config.get("DAGGER_BETA_INIT", 1.0) + beta_decay = config.get("DAGGER_BETA_DECAY", 0.95) + + # I8: cosine LR schedule matching offline training. Stretched to + # cover all gradient steps across train passes (B1). + total_grad_steps = ( + num_updates * n_train_passes * update_epochs * num_minibatches + ) + warmup_steps = config.get("LR_WARMUP_STEPS", 0) + lr_schedule = ( + optax.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=config["LR"], + warmup_steps=warmup_steps, + decay_steps=total_grad_steps, + end_value=config["LR"] * 0.1, + ) + if warmup_steps > 0 + else optax.cosine_decay_schedule( + init_value=config["LR"], + decay_steps=total_grad_steps, + alpha=0.1, + ) + ) + + # Resume checkpoint (loaded outside JIT, captured by train closure) ------ + resume_step = config.get("RESUME_STEP") or 0 + resume_state = None + if config.get("RESUME_CHECKPOINT_PATH"): + resume_state = load_checkpoint_for_resume( + model, + jax.random.PRNGKey(0), + obs_dim, + plan_horizon, + config["RESUME_CHECKPOINT_PATH"], + lr_schedule, + config["MAX_GRAD_NORM"], + ) + target_opt_step = ( + resume_step * n_train_passes * update_epochs * num_minibatches + ) + resume_state = resume_state.replace(step=target_opt_step) + + scan_length = num_updates - resume_step + + # W&B ------------------------------------------------------------------ + _wandb_log = ( + make_wandb_callback( + config, + steps_per_update=num_envs * num_steps, + val_interval=val_interval, + is_online=True, + ) + if config.get("USE_WANDB") + else None + ) + + # ------------------------------------------------------------------ + # Train closure + # ------------------------------------------------------------------ + + def train(rng: jax.Array) -> dict[str, Any]: + """JIT/vmap-compatible DAgger training loop. + + Args: + rng: JAX PRNG key (one per vmap replica). + + Returns: + Dict with ``runner_state`` (final DAggerCarry) and ``metrics``. + """ + rng, init_rng, env_rng = jax.random.split(rng, 3) + + if resume_state is not None: + state = resume_state + params = resume_state.params + elif pretrained_params is not None: + params = pretrained_params + state = create_train_state( + model, params, lr_schedule, config["MAX_GRAD_NORM"], + ) + else: + params = init_params(model, init_rng, obs_dim, plan_horizon) + state = create_train_state( + model, params, lr_schedule, config["MAX_GRAD_NORM"], + ) + + obs, env_state = env.reset(env_rng, env_params) + + # Pre-allocate DAgger replay buffer + buf_obs = jnp.zeros((max_buffer_size, obs_dim)) + buf_plans = jnp.zeros( + (max_buffer_size, plan_horizon), dtype=jnp.int32, + ) + buf_valid = jnp.zeros(max_buffer_size) + + # Shared validation closure (see common.py) + _validate = make_validate( + env, env_params, apply_eval, num_actions, + plan_horizon, schedule_fn, config, + val_replan_every, n_val_cycles, + ) + + # ---------------------------------------------------------- + # _update_step (one DAgger iteration) + # ---------------------------------------------------------- + def _update_step(carry: DAggerCarry, _): + ( + state, + env_state, + obs, + rng, + step_idx, + ppo_hs, + prev_done, + buf_obs, + buf_plans, + buf_valid, + buf_write_idx, + buf_fill, + best_params, + best_val_return, + ) = carry + + # Beta decays each update: expert -> learner + beta = beta_init * jnp.power(beta_decay, step_idx) + + # --- Roll out with mixed policy, collect expert labels - + def _plan_and_execute(outer_carry, _): + es, cur_obs, rng, hs, p_done = outer_carry + rng, plan_rng, sim_rng = jax.random.split(rng, 3) + + # Learner plan from the current diffusion policy + learner_plan = sample_plan( + apply_eval, + state.params, + plan_rng, + cur_obs, + num_actions, + plan_horizon, + diffusion_steps, + schedule_fn, + config.get("REMASK_STRATEGY", "rescale"), + config.get("ETA", 0.5), + config.get("USE_LOOP", True), + config.get("T_ON", 0.7), + config.get("T_OFF", 0.3), + config.get("TEMPERATURE", 1.0), + config.get("TOP_P", 0.95), + ) # [E, H] + + # B3: simulate plan_horizon steps, recording the visited + # obs at every state alongside the expert action. The + # outer code then extracts sliding windows from the full + # per-step trace, mirroring offline.py. + def _sim_step(c, step_i): + st, o, r, hs, p_done = c + r, s_rng, mix_rng, ppo_rng = jax.random.split( + r, 4, + ) + + # Expert action with the correct done flag so the + # PPO RNN hidden state resets on episode boundaries. + pi, new_hs = ppo.get_pi(o, p_done, hs) + if expert_deterministic: + # B2: argmax keeps the expert mapping s -> a + # deterministic, removing label noise from D. + expert_act = jnp.argmax( + pi.logits, axis=-1, + ).squeeze(0) + else: + expert_act = jax.random.categorical( + ppo_rng, pi.logits, + ).squeeze(0) + + # Learner action from the plan + learner_act = learner_plan[:, step_i] + + # Mixed execution: prob beta -> expert, else learner + use_expert = jax.random.bernoulli( + mix_rng, beta, shape=(num_envs,), + ) + exec_act = jnp.where( + use_expert, expert_act, learner_act, + ) + + o_next, st, rew, done, info = env.step( + s_rng, st, exec_act, env_params, + ) + # Yield the visited obs ``o`` (not ``o_next``) so the + # paired (obs_t, expert_act_t) is consistent. + return (st, o_next, r, new_hs, done), ( + o, + expert_act, + rew, + done, + info, + ) + + final_c, ( + obs_seq, expert_acts, rews, dones, infos, + ) = jax.lax.scan( + _sim_step, + (es, cur_obs, sim_rng, hs, p_done), + jnp.arange(plan_horizon), + ) + # obs_seq: [H, E, obs_dim] + # expert_acts: [H, E] + # dones: [H, E] + es_next, obs_next, _, hs_next, done_next = final_c + + return (es_next, obs_next, rng, hs_next, done_next), ( + obs_seq, + expert_acts, + rews, + dones, + infos, + ) + + # I7: ppo_hs and prev_done persist across cycles and + # updates via the scan carry. + (env_state, obs, rng, ppo_hs, prev_done), traj = ( + jax.lax.scan( + _plan_and_execute, + (env_state, obs, rng, ppo_hs, prev_done), + None, + n_cycles, + ) + ) + # traj_obs: [C, H, E, obs_dim] + # traj_expert_acts: [C, H, E] + # traj_dones: [C, H, E] + ( + traj_obs, + traj_expert_acts, + traj_rew, + traj_dones, + all_infos, + ) = traj + + # B3: concatenate cycles into one [T, E, ...] rollout. Cycles + # are contiguous in time so a flat reshape preserves order. + T = num_steps + obs_t = traj_obs.reshape(T, num_envs, obs_dim) # [T, E, D] + acts_t = traj_expert_acts.reshape(T, num_envs) # [T, E] + dones_t = traj_dones.reshape(T, num_envs) # [T, E] + + # B3 + B5: sliding-window extraction matching offline.py. + # Window (t, e) is valid iff actions [t..t+H-1] all came from + # one episode, i.e. dones[t..t+H-2] are all False. An episode + # boundary on the *last* action is allowed — the window's + # action sequence is still a coherent trajectory. + def _window(t_idx): + obs_w = obs_t[t_idx] # [E, D] + acts_w = jax.lax.dynamic_slice( + acts_t, (t_idx, 0), (plan_horizon, num_envs), + ) # [H, E] + dones_w = jax.lax.dynamic_slice( + dones_t, (t_idx, 0), (plan_horizon - 1, num_envs), + ) # [H-1, E] + valid = ~jnp.any(dones_w, axis=0) # [E] + return obs_w, jnp.swapaxes(acts_w, 0, 1), valid + + obs_w, act_w, valid_w = jax.vmap(_window)( + jnp.arange(valid_per_rollout), + ) + # obs_w: [W, E, D] + # act_w: [W, E, H] + # valid_w: [W, E] + flat_obs = obs_w.reshape(-1, obs_dim) + flat_plans = act_w.reshape(-1, plan_horizon) + flat_valid = valid_w.reshape(-1).astype(jnp.float32) + + # B1: write new samples into circular replay buffer + write_indices = ( + buf_write_idx + jnp.arange(samples_per_update) + ) % max_buffer_size + buf_obs = buf_obs.at[write_indices].set(flat_obs) + buf_plans = buf_plans.at[write_indices].set(flat_plans) + buf_valid = buf_valid.at[write_indices].set(flat_valid) + buf_write_idx = ( + buf_write_idx + samples_per_update + ) % max_buffer_size + buf_fill = jnp.minimum( + buf_fill + samples_per_update, max_buffer_size, + ) + + # B1: multi-pass training over the aggregated buffer. Each + # pass redraws a fresh sample of size ``samples_per_update`` + # from the filled portion of the buffer; with the default + # ``n_train_passes`` ≈ ⌊|D|/B⌋ one update covers ~|D| + # examples once the buffer is full, restoring DAgger's + # "train on all of D" requirement without inflating gather + # memory beyond a single per-pass batch. + def _pass(pass_state, _): + state, rng = pass_state + rng, sample_rng = jax.random.split(rng) + buf_indices = jax.random.randint( + sample_rng, (samples_per_update,), 0, buf_fill, + ) + dataset = ( + buf_obs[buf_indices], + buf_plans[buf_indices], + buf_valid[buf_indices], + ) + + def _epoch(epoch_state, _): + state, ds, rng = epoch_state + rng, perm_rng = jax.random.split(rng) + perm = jax.random.permutation( + perm_rng, samples_per_update, + ) + shuffled = jax.tree.map( + lambda x: jnp.take(x, perm, axis=0), ds, + ) + batches = jax.tree.map( + lambda x: x.reshape( + num_minibatches, -1, *x.shape[1:], + ), + shuffled, + ) + + def _mb(mb_carry, batch): + st, rng = mb_carry + rng, loss_rng = jax.random.split(rng) + obs_b, act_b, val_b = batch + adv_b = jnp.ones(act_b.shape[0]) + st, metrics = grad_step( + st, + act_b, + obs_b, + val_b, + loss_rng, + advantages=adv_b, + ) + return (st, rng), metrics + + (state, rng), metrics = jax.lax.scan( + _mb, (state, rng), batches, + ) + return (state, ds, rng), metrics + + (state, _, rng), loss_info = jax.lax.scan( + _epoch, (state, dataset, rng), None, update_epochs, + ) + return (state, rng), loss_info + + (state, rng), loss_info = jax.lax.scan( + _pass, (state, rng), None, n_train_passes, + ) + + # --- Metrics ------------------------------------------ + metric = jax.tree.map(jnp.mean, loss_info) + returned = all_infos["returned_episode"] + env_metrics = jax.tree.map( + lambda x: (x * returned).sum() + / (returned.sum() + 1e-8), + all_infos, + ) + metric.update(env_metrics) + metric["beta"] = beta + metric["reward_mean"] = jnp.mean(traj_rew) + metric["buffer_fill"] = buf_fill.astype(jnp.float32) + metric["valid_frac"] = jnp.mean(flat_valid) + + # --- Periodic validation ------------------------------ + rng, val_rng = jax.random.split(rng) + dummy = jax.tree.map( + jnp.zeros_like, + {f"val/{k}": v for k, v in env_metrics.items()}, + ) + val_metrics = jax.lax.cond( + step_idx % val_interval == 0, + lambda: _validate(state, val_rng), + lambda: dummy, + ) + metric.update(val_metrics) + + # Best-model tracking: update when validation improves. + val_ret = val_metrics.get( + "val/returned_episode_returns", + jnp.array(-jnp.inf), + ) + is_val_step = step_idx % val_interval == 0 + improved = is_val_step & (val_ret > best_val_return) + best_params = jax.tree.map( + lambda b, c: jnp.where(improved, c, b), + best_params, + state.params, + ) + best_val_return = jnp.where( + improved, val_ret, best_val_return, + ) + metric["best_val_return"] = best_val_return + + if _wandb_log is not None: + jax.debug.callback(_wandb_log, metric, step_idx) + + new_carry = DAggerCarry( + train_state=state, + env_state=env_state, + obs=obs, + rng=rng, + step_idx=step_idx + 1, + ppo_hs=ppo_hs, + prev_done=prev_done, + buf_obs=buf_obs, + buf_plans=buf_plans, + buf_valid=buf_valid, + buf_write_idx=buf_write_idx, + buf_fill=buf_fill, + best_params=best_params, + best_val_return=best_val_return, + ) + return new_carry, metric + + # --- Outer scan ------------------------------------------- + rng, run_rng = jax.random.split(rng) + runner_init = DAggerCarry( + train_state=state, + env_state=env_state, + obs=obs, + rng=run_rng, + step_idx=resume_step, + ppo_hs=ppo.init_hidden(num_envs), + prev_done=jnp.zeros(num_envs, dtype=bool), + buf_obs=buf_obs, + buf_plans=buf_plans, + buf_valid=buf_valid, + buf_write_idx=jnp.int32(0), + buf_fill=jnp.int32(0), + best_params=params, + best_val_return=jnp.array(-jnp.inf), + ) + runner_final, metrics = jax.lax.scan( + _update_step, runner_init, None, scan_length, + ) + return { + "runner_state": runner_final, + "metrics": metrics, + "best_params": runner_final.best_params, + } + + return train + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def run_online(config: dict[str, Any]) -> None: + """Configure, compile, and run DAgger online training. + + Args: + config: Mixed-case hyperparameter dict from ``defaults.yaml`` / CLI. + """ + config = {k.upper(): v for k, v in config.items()} + + # ONLINE_TOTAL_TIMESTEPS (env frames) is the hardware-portable source of + # truth: invariant under num_envs changes, so the same config trains the + # same amount of environment experience on any GPU. ONLINE_NUM_UPDATES is + # kept as a legacy fallback for configs that prefer the update form. + resolve_num_updates(config, "online") + # Translate env-frame-denominated hyperparameters (LR_WARMUP_FRAMES, + # VAL_INTERVAL_FRAMES, DAGGER_BETA_FINAL, DAGGER_BUFFER_CYCLES) into + # their update-step legacy keys. Must run AFTER resolve_num_updates + # because DAGGER_BETA_FINAL needs NUM_UPDATES. + resolve_scaled_hyperparams(config, "online") + print_config_snapshot(config, "online") + + if config.get("USE_WANDB"): + init_wandb( + config, + name=f"{config['ENV_NAME']}-OnlineDiffusion-DAgger-{int(config['ONLINE_TOTAL_TIMESTEPS'] // 1e6)}M", + resume_run_id=config.get("RESUME_WANDB_RUN_ID"), + ) + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config.get("NUM_REPEATS", 1)) + + train_fn = jax.jit(jax.vmap(make_train_dagger(config))) + + t0 = time.time() + out = train_fn(rngs) + elapsed = time.time() - t0 + print(f"Time: {elapsed:.1f}s SPS: {config['ONLINE_TOTAL_TIMESTEPS'] / elapsed:.0f}") + + if config.get("USE_WANDB") and config.get("SAVE_POLICY"): + # Final checkpoint (last iteration params) + train_states = out["runner_state"].train_state + train_state = jax.tree.map(lambda x: x[0], train_states) + path = os.path.join(wandb.run.dir, "policies") + with ocp.CheckpointManager( + path, + options=ocp.CheckpointManagerOptions(max_to_keep=1), + ) as mgr: + mgr.save( + int(config["NUM_UPDATES"]), + args=ocp.args.StandardSave(train_state), + ) + print(f"Saved final policy to {path}") + + num_updates = config["NUM_UPDATES"] + save_checkpoint_metadata( + path, + mode="online", + update_step=num_updates, + total_gradient_steps=num_updates * config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"], + wandb_run_id=wandb.run.id if wandb.run else None, + config=config, + ) + + artifact = wandb.Artifact( + name=f"{config['ENV_NAME']}-policy", + type="model", + metadata=config, + ) + artifact.add_dir(path) + wandb.log_artifact(artifact) + + # Best checkpoint (highest validation return). + # Wrap in a dummy TrainState so the Orbax structure matches + # the final checkpoint — load_checkpoint expects TrainState. + best_params = jax.tree.map( + lambda x: x[0], out["best_params"], + ) + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + best_state = TrainState.create( + apply_fn=lambda *a: None, + params=best_params, + tx=tx, + ) + best_path = os.path.join(wandb.run.dir, "policies_best") + with ocp.CheckpointManager( + best_path, + options=ocp.CheckpointManagerOptions(max_to_keep=1), + ) as mgr: + mgr.save(0, args=ocp.args.StandardSave(best_state)) + print(f"Saved best policy to {best_path}") + + best_artifact = wandb.Artifact( + name=f"{config['ENV_NAME']}-policy-best", + type="model", + metadata=config, + ) + best_artifact.add_dir(best_path) + wandb.log_artifact(best_artifact) + + print("Uploaded final + best policy artifacts to wandb") diff --git a/src/planners/ppo.py b/src/planners/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..fda1d481f92d51290f3549710975038b17969b88 --- /dev/null +++ b/src/planners/ppo.py @@ -0,0 +1,188 @@ +"""PPO agent adapter and checkpoint loading utilities.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint as ocp + +from Craftax_Baselines.ppo import ActorCritic +from Craftax_Baselines.ppo_rnn import ActorCriticRNN +from Craftax_Baselines.ppo_rnd import ActorCriticRND + + +def load_ppo_params( + path: str, + network: Any, + model_type: str, + num_envs: int, + obs_shape: tuple, + layer_size: int = 512, +) -> Any: + """Restore PPO parameters from an Orbax checkpoint. + + Args: + path: Path to the Orbax checkpoint directory. + network: Instantiated Flax network (used only for structure). + model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. + num_envs: Number of parallel environments (affects RNN init shape). + obs_shape: Observation shape tuple. + layer_size: Hidden layer size (for RNN hidden state init). + + Returns: + Restored parameter pytree. + """ + path = str(Path(path).resolve()) + rng = jax.random.PRNGKey(0) + if model_type == "ppo_rnn": + init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs))) + abstract = network.init(rng, jnp.zeros((num_envs, layer_size)), init_x) + else: + abstract = network.init(rng, jnp.zeros((1, *obs_shape))) + + with ocp.CheckpointManager(path) as mgr: + step = mgr.latest_step() + if step is None: + raise FileNotFoundError(f"No checkpoint at {path}") + restored = mgr.restore( + step, + args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True), + ) + print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})") + return restored["params"] + + +def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any: + """Instantiate the correct PPO architecture. + + Args: + model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. + num_actions: Size of the discrete action space. + layer_size: Hidden layer width. + config: Training config (forwarded to ``ActorCriticRNN``). + + Returns: + Flax module instance. + """ + model_type = model_type.lower() + if model_type == "ppo_rnn": + return ActorCriticRNN(num_actions, config=config) + if model_type == "ppo_rnd": + return ActorCriticRND(num_actions, layer_size) + return ActorCritic(num_actions, layer_size) + + +def load_ppo_agent( + path: str, + num_actions: int, + obs_dim: int, + layer_size: int, + model_type: str, + config: dict, + num_envs: int = 1, +) -> "PPOAgent": + """Build network, load params, and return a :class:`PPOAgent`. + + Args: + path: Path to the Orbax checkpoint directory. + num_actions: Size of the discrete action space. + obs_dim: Observation vector dimensionality. + layer_size: Hidden layer width. + model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. + config: Training config dict. + num_envs: Number of parallel environments. + + Returns: + A fully initialised :class:`PPOAgent`. + """ + net = build_ppo_network(model_type, num_actions, layer_size, config) + params = load_ppo_params(path, net, model_type, num_envs, (obs_dim,), layer_size) + return PPOAgent(net, params, model_type, layer_size) + + +class PPOAgent: + """Uniform interface over PPO-RNN / PPO / PPO-RND for action collection. + + Args: + network: Flax actor-critic module. + params: Loaded parameter pytree. + model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. + layer_size: Hidden layer width (used for RNN hidden-state shape). + """ + + def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None: + self.network = network + self.params = params + self.model_type = model_type.lower() + self.layer_size = layer_size + + def init_hidden(self, batch_size: int) -> jnp.ndarray | None: + """Return a zero hidden state for RNN models, else ``None``.""" + if self.model_type == "ppo_rnn": + return jnp.zeros((batch_size, self.layer_size)) + return None + + def act( + self, + obs: jnp.ndarray, + done: jnp.ndarray, + hidden: jnp.ndarray | None, + rng: jax.Array, + temperature: float = 1.0, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: + """Sample an action. + + Args: + obs: Observation array ``[B, obs_dim]``. + done: Episode-done flags ``[B]``. + hidden: RNN hidden state (``None`` for non-RNN models). + rng: PRNG key. + temperature: Softmax temperature for sampling. + + Returns: + ``(action, new_hidden)`` tuple. + """ + if self.model_type == "ppo_rnn": + ac_in = (obs[np.newaxis, :], done[np.newaxis, :]) + new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in) + elif self.model_type == "ppo_rnd": + pi, _, _ = self.network.apply(self.params, obs) + new_hidden = hidden + else: + pi, _ = self.network.apply(self.params, obs) + new_hidden = hidden + + action = jax.random.categorical(rng, pi.logits / temperature) + if self.model_type == "ppo_rnn": + action = action.squeeze(0) + return action, new_hidden + + def get_pi( + self, + obs: jnp.ndarray, + done: jnp.ndarray | None = None, + hidden: jnp.ndarray | None = None, + ) -> tuple[Any, jnp.ndarray | None]: + """Return the policy distribution (used in DAgger expert labelling). + + Args: + obs: Observation array ``[B, obs_dim]``. + done: Episode-done flags ``[B]`` (required for RNN models). + hidden: RNN hidden state. + + Returns: + ``(pi, new_hidden)`` tuple. + """ + if self.model_type == "ppo_rnn": + ac_in = (obs[np.newaxis, :], done[np.newaxis, :]) + new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in) + return pi, new_hidden + if self.model_type == "ppo_rnd": + pi, _, _ = self.network.apply(self.params, obs) + return pi, hidden + pi, _ = self.network.apply(self.params, obs) + return pi, hidden