Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +118 -0
- Craftax_Baselines/.gitignore +169 -0
- Craftax_Baselines/.pre-commit-config.yaml +6 -0
- Craftax_Baselines/Dockerfile +41 -0
- Craftax_Baselines/LICENSE +19 -0
- Craftax_Baselines/README.md +46 -0
- Craftax_Baselines/analysis/__init__.py +0 -0
- Craftax_Baselines/analysis/view_ppo_agent.py +151 -0
- Craftax_Baselines/build.sh +10 -0
- Craftax_Baselines/images/logo.png +0 -0
- Craftax_Baselines/logz/__init__.py +0 -0
- Craftax_Baselines/logz/batch_logging.py +115 -0
- Craftax_Baselines/models/__init__.py +0 -0
- Craftax_Baselines/models/actor_critic.py +256 -0
- Craftax_Baselines/models/icm.py +72 -0
- Craftax_Baselines/models/rnd.py +120 -0
- Craftax_Baselines/ppo.py +733 -0
- Craftax_Baselines/ppo_rnd.py +680 -0
- Craftax_Baselines/ppo_rnn.py +542 -0
- Craftax_Baselines/requirements.txt +16 -0
- Craftax_Baselines/run_docker.sh +24 -0
- Craftax_Baselines/wrappers.py +200 -0
- README.md +547 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA +1 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding +1 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 +1 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 +3 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf +3 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt +0 -0
- checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json +68 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA +1 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding +1 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 +1 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 +0 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a +3 -0
- checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt +0 -0
- checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA +1 -0
- checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA +1 -0
- checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding +1 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,121 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png filter=lfs diff=lfs merge=lfs -text
|
| 88 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
|
| 90 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
|
| 91 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png filter=lfs diff=lfs merge=lfs -text
|
| 92 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png filter=lfs diff=lfs merge=lfs -text
|
| 93 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png filter=lfs diff=lfs merge=lfs -text
|
| 96 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png filter=lfs diff=lfs merge=lfs -text
|
| 97 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 98 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png filter=lfs diff=lfs merge=lfs -text
|
| 100 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png filter=lfs diff=lfs merge=lfs -text
|
| 101 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
|
| 102 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png filter=lfs diff=lfs merge=lfs -text
|
| 103 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
|
| 104 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png filter=lfs diff=lfs merge=lfs -text
|
| 105 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
|
| 106 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png filter=lfs diff=lfs merge=lfs -text
|
| 107 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png filter=lfs diff=lfs merge=lfs -text
|
| 108 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
|
| 109 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
|
| 110 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png filter=lfs diff=lfs merge=lfs -text
|
| 111 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
|
| 112 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
|
| 113 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
|
| 114 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
|
| 115 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png filter=lfs diff=lfs merge=lfs -text
|
| 116 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png filter=lfs diff=lfs merge=lfs -text
|
| 117 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png filter=lfs diff=lfs merge=lfs -text
|
| 118 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
|
| 119 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
|
| 120 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
|
| 121 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png filter=lfs diff=lfs merge=lfs -text
|
| 122 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png filter=lfs diff=lfs merge=lfs -text
|
| 123 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
|
| 124 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
|
| 125 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png filter=lfs diff=lfs merge=lfs -text
|
| 126 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
|
| 127 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png filter=lfs diff=lfs merge=lfs -text
|
| 128 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
|
| 129 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png filter=lfs diff=lfs merge=lfs -text
|
| 130 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
|
| 131 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png filter=lfs diff=lfs merge=lfs -text
|
| 132 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
|
| 133 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png filter=lfs diff=lfs merge=lfs -text
|
| 134 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png filter=lfs diff=lfs merge=lfs -text
|
| 135 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
|
| 136 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
|
| 137 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png filter=lfs diff=lfs merge=lfs -text
|
| 138 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
|
| 139 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
|
| 140 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
|
| 141 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
|
| 142 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png filter=lfs diff=lfs merge=lfs -text
|
| 143 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png filter=lfs diff=lfs merge=lfs -text
|
| 144 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
|
| 145 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
|
| 146 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
|
| 147 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png filter=lfs diff=lfs merge=lfs -text
|
| 148 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png filter=lfs diff=lfs merge=lfs -text
|
| 149 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
|
| 150 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
|
| 151 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png filter=lfs diff=lfs merge=lfs -text
|
| 152 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png filter=lfs diff=lfs merge=lfs -text
|
| 153 |
+
experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png filter=lfs diff=lfs merge=lfs -text
|
Craftax_Baselines/.gitignore
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tmp/
|
| 2 |
+
wandb/
|
| 3 |
+
res/
|
| 4 |
+
runs/
|
| 5 |
+
|
| 6 |
+
play_data
|
| 7 |
+
|
| 8 |
+
# Byte-compiled / optimized / DLL files
|
| 9 |
+
__pycache__/
|
| 10 |
+
*.py[cod]
|
| 11 |
+
*$py.class
|
| 12 |
+
|
| 13 |
+
# C extensions
|
| 14 |
+
*.so
|
| 15 |
+
|
| 16 |
+
# Distribution / packaging
|
| 17 |
+
.Python
|
| 18 |
+
build/
|
| 19 |
+
develop-eggs/
|
| 20 |
+
dist/
|
| 21 |
+
downloads/
|
| 22 |
+
eggs/
|
| 23 |
+
.eggs/
|
| 24 |
+
lib/
|
| 25 |
+
lib64/
|
| 26 |
+
parts/
|
| 27 |
+
sdist/
|
| 28 |
+
var/
|
| 29 |
+
wheels/
|
| 30 |
+
share/python-wheels/
|
| 31 |
+
*.egg-info/
|
| 32 |
+
.installed.cfg
|
| 33 |
+
*.egg
|
| 34 |
+
MANIFEST
|
| 35 |
+
|
| 36 |
+
# PyInstaller
|
| 37 |
+
# Usually these files are written by a python script from a template
|
| 38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 39 |
+
*.manifest
|
| 40 |
+
*.spec
|
| 41 |
+
|
| 42 |
+
# Installer logs
|
| 43 |
+
pip-log.txt
|
| 44 |
+
pip-delete-this-directory.txt
|
| 45 |
+
|
| 46 |
+
# Unit test / coverage reports
|
| 47 |
+
htmlcov/
|
| 48 |
+
.tox/
|
| 49 |
+
.nox/
|
| 50 |
+
.coverage
|
| 51 |
+
.coverage.*
|
| 52 |
+
.cache
|
| 53 |
+
nosetests.xml
|
| 54 |
+
coverage.xml
|
| 55 |
+
*.cover
|
| 56 |
+
*.py,cover
|
| 57 |
+
.hypothesis/
|
| 58 |
+
.pytest_cache/
|
| 59 |
+
cover/
|
| 60 |
+
|
| 61 |
+
# Translations
|
| 62 |
+
*.mo
|
| 63 |
+
*.pot
|
| 64 |
+
|
| 65 |
+
# Django stuff:
|
| 66 |
+
*.log
|
| 67 |
+
local_settings.py
|
| 68 |
+
db.sqlite3
|
| 69 |
+
db.sqlite3-journal
|
| 70 |
+
|
| 71 |
+
# Flask stuff:
|
| 72 |
+
instance/
|
| 73 |
+
.webassets-cache
|
| 74 |
+
|
| 75 |
+
# Scrapy stuff:
|
| 76 |
+
.scrapy
|
| 77 |
+
|
| 78 |
+
# Sphinx documentation
|
| 79 |
+
docs/_build/
|
| 80 |
+
|
| 81 |
+
# PyBuilder
|
| 82 |
+
.pybuilder/
|
| 83 |
+
target/
|
| 84 |
+
|
| 85 |
+
# Jupyter Notebook
|
| 86 |
+
.ipynb_checkpoints
|
| 87 |
+
|
| 88 |
+
# IPython
|
| 89 |
+
profile_default/
|
| 90 |
+
ipython_config.py
|
| 91 |
+
|
| 92 |
+
# pyenv
|
| 93 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 94 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 95 |
+
# .python-version
|
| 96 |
+
|
| 97 |
+
# pipenv
|
| 98 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 99 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 100 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 101 |
+
# install all needed dependencies.
|
| 102 |
+
#Pipfile.lock
|
| 103 |
+
|
| 104 |
+
# poetry
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 106 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 107 |
+
# commonly ignored for libraries.
|
| 108 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 109 |
+
#poetry.lock
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
#pdm.lock
|
| 114 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 115 |
+
# in version control.
|
| 116 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 117 |
+
.pdm.toml
|
| 118 |
+
|
| 119 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 120 |
+
__pypackages__/
|
| 121 |
+
|
| 122 |
+
# Celery stuff
|
| 123 |
+
celerybeat-schedule
|
| 124 |
+
celerybeat.pid
|
| 125 |
+
|
| 126 |
+
# SageMath parsed files
|
| 127 |
+
*.sage.py
|
| 128 |
+
|
| 129 |
+
# Environments
|
| 130 |
+
.env
|
| 131 |
+
.venv
|
| 132 |
+
env/
|
| 133 |
+
venv/
|
| 134 |
+
ENV/
|
| 135 |
+
env.bak/
|
| 136 |
+
venv.bak/
|
| 137 |
+
|
| 138 |
+
# Spyder project settings
|
| 139 |
+
.spyderproject
|
| 140 |
+
.spyproject
|
| 141 |
+
|
| 142 |
+
# Rope project settings
|
| 143 |
+
.ropeproject
|
| 144 |
+
|
| 145 |
+
# mkdocs documentation
|
| 146 |
+
/site
|
| 147 |
+
|
| 148 |
+
# mypy
|
| 149 |
+
.mypy_cache/
|
| 150 |
+
.dmypy.json
|
| 151 |
+
dmypy.json
|
| 152 |
+
|
| 153 |
+
# Pyre type checker
|
| 154 |
+
.pyre/
|
| 155 |
+
|
| 156 |
+
# pytype static type analyzer
|
| 157 |
+
.pytype/
|
| 158 |
+
|
| 159 |
+
# Cython debug symbols
|
| 160 |
+
cython_debug/
|
| 161 |
+
|
| 162 |
+
# PyCharm
|
| 163 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 164 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 165 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 166 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 167 |
+
.idea/
|
| 168 |
+
texture_cache.pbz2
|
| 169 |
+
texture_cache*.pbz2
|
Craftax_Baselines/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/psf/black
|
| 3 |
+
rev: 22.3.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: black
|
| 6 |
+
language_version: python3
|
Craftax_Baselines/Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
ENV CUDA_PATH /usr/local/cuda
|
| 4 |
+
ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
|
| 5 |
+
ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64
|
| 6 |
+
|
| 7 |
+
# Set timezone
|
| 8 |
+
ENV TZ=Europe/London DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
|
| 10 |
+
# Add Python 3.8 to Ubuntu 22.04 and install dependencies
|
| 11 |
+
RUN apt update
|
| 12 |
+
RUN apt install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa
|
| 13 |
+
RUN apt install -y \
|
| 14 |
+
git \
|
| 15 |
+
python3.8 \
|
| 16 |
+
python3-pip \
|
| 17 |
+
python3.8-venv \
|
| 18 |
+
python3-setuptools \
|
| 19 |
+
python3-wheel
|
| 20 |
+
|
| 21 |
+
# Create local user
|
| 22 |
+
# https://jtreminio.com/blog/running-docker-containers-as-current-host-user/
|
| 23 |
+
ARG UID
|
| 24 |
+
ARG GID
|
| 25 |
+
RUN if [ ${UID:-0} -ne 0 ] && [ ${GID:-0} -ne 0 ]; then \
|
| 26 |
+
groupadd -g ${GID} duser &&\
|
| 27 |
+
useradd -l -u ${UID} -g duser duser &&\
|
| 28 |
+
install -d -m 0755 -o duser -g duser /home/duser &&\
|
| 29 |
+
chown --changes --silent --no-dereference --recursive ${UID}:${GID} /home/duser \
|
| 30 |
+
;fi
|
| 31 |
+
|
| 32 |
+
USER duser
|
| 33 |
+
WORKDIR /home/duser
|
| 34 |
+
|
| 35 |
+
# Install Python packages
|
| 36 |
+
ENV PATH="/home/duser/.local/bin:$PATH"
|
| 37 |
+
RUN python3 -m pip install --upgrade pip
|
| 38 |
+
ARG REQS
|
| 39 |
+
RUN pip install $REQS -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 40 |
+
|
| 41 |
+
WORKDIR /home/duser/Craftax
|
Craftax_Baselines/LICENSE
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2024 Michael Matthews
|
| 2 |
+
|
| 3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 5 |
+
in the Software without restriction, including without limitation the rights
|
| 6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 7 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 8 |
+
furnished to do so, subject to the following conditions:
|
| 9 |
+
|
| 10 |
+
The above copyright notice and this permission notice shall be included in all
|
| 11 |
+
copies or substantial portions of the Software.
|
| 12 |
+
|
| 13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 19 |
+
SOFTWARE.
|
Craftax_Baselines/README.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img width="80%" src="https://raw.githubusercontent.com/MichaelTMatthews/Craftax_Baselines/main/images/logo.png" />
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# Craftax Baselines
|
| 6 |
+
|
| 7 |
+
This repository contains the code for running the baselines from the [Craftax paper](https://arxiv.org/abs/2402.16801).
|
| 8 |
+
For packaging reasons, this is separate to the [main repository](https://github.com/MichaelTMatthews/Craftax/).
|
| 9 |
+
|
| 10 |
+
# Installation
|
| 11 |
+
```commandline
|
| 12 |
+
git clone https://github.com/MichaelTMatthews/Craftax_Baselines.git
|
| 13 |
+
cd Craftax_Baselines
|
| 14 |
+
pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
| 15 |
+
pre-commit install
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
# Run Experiments
|
| 19 |
+
|
| 20 |
+
### PPO
|
| 21 |
+
```commandline
|
| 22 |
+
python ppo.py
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### PPO-RNN
|
| 26 |
+
```commandline
|
| 27 |
+
python ppo_rnn.py
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### ICM
|
| 31 |
+
```commandline
|
| 32 |
+
python ppo.py --train_icm
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### E3B
|
| 36 |
+
```commandline
|
| 37 |
+
python ppo.py --train_icm --use_e3b --icm_reward_coeff 0
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### RND
|
| 41 |
+
```commandline
|
| 42 |
+
python ppo_rnd.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
# Visualisation
|
| 46 |
+
You can save trained policies with the `--save_policy` flag. These can then be viewed with the `view_ppo_agent` script (pass in the path up to the `files` directory).
|
Craftax_Baselines/analysis/__init__.py
ADDED
|
File without changes
|
Craftax_Baselines/analysis/view_ppo_agent.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import jax
|
| 6 |
+
import jax.numpy as jnp
|
| 7 |
+
import numpy as np
|
| 8 |
+
import optax
|
| 9 |
+
import yaml
|
| 10 |
+
from craftax.environment_base.wrappers import AutoResetEnvWrapper
|
| 11 |
+
from flax.training.train_state import TrainState
|
| 12 |
+
import orbax.checkpoint as ocp
|
| 13 |
+
|
| 14 |
+
from ..models.actor_critic import ActorCriticConv, ActorCritic
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main(args):
|
| 18 |
+
|
| 19 |
+
with open(os.path.join(args.path, "config.yaml")) as f:
|
| 20 |
+
raw_config = yaml.load(f, Loader=yaml.Loader)
|
| 21 |
+
|
| 22 |
+
config = {}
|
| 23 |
+
for key, value in raw_config.items():
|
| 24 |
+
if isinstance(value, dict) and "value" in value:
|
| 25 |
+
config[key] = value["value"]
|
| 26 |
+
|
| 27 |
+
config["NUM_ENVS"] = 1
|
| 28 |
+
|
| 29 |
+
options = ocp.CheckpointManagerOptions(max_to_keep=1)
|
| 30 |
+
checkpoint_manager = ocp.CheckpointManager(
|
| 31 |
+
os.path.join(args.path, "policies"),
|
| 32 |
+
options=options
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
is_classic = False
|
| 36 |
+
|
| 37 |
+
if config["ENV_NAME"] == "Craftax-Symbolic-v1":
|
| 38 |
+
from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
|
| 39 |
+
from craftax.craftax.constants import Action
|
| 40 |
+
|
| 41 |
+
env = CraftaxSymbolicEnv(CraftaxSymbolicEnv.default_static_params())
|
| 42 |
+
network = ActorCritic(len(Action), config["LAYER_SIZE"])
|
| 43 |
+
elif config["ENV_NAME"] == "Craftax-Pixels-v1":
|
| 44 |
+
from craftax.craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
|
| 45 |
+
from craftax.craftax.constants import Action
|
| 46 |
+
|
| 47 |
+
env = CraftaxPixelsEnv(CraftaxPixelsEnv.default_static_params())
|
| 48 |
+
network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
|
| 49 |
+
elif config["ENV_NAME"] == "Craftax-Classic-Symbolic-v1":
|
| 50 |
+
from craftax.craftax_classic.envs.craftax_symbolic_env import (
|
| 51 |
+
CraftaxClassicSymbolicEnv,
|
| 52 |
+
)
|
| 53 |
+
from craftax.craftax_classic.constants import Action
|
| 54 |
+
|
| 55 |
+
env = CraftaxClassicSymbolicEnv(
|
| 56 |
+
CraftaxClassicSymbolicEnv.default_static_params()
|
| 57 |
+
)
|
| 58 |
+
network = ActorCritic(len(Action), config["LAYER_SIZE"])
|
| 59 |
+
is_classic = True
|
| 60 |
+
elif config["ENV_NAME"] == "Craftax-Classic-Pixels-v1":
|
| 61 |
+
from craftax.craftax_classic.envs.craftax_pixels_env import (
|
| 62 |
+
CraftaxClassicPixelsEnv,
|
| 63 |
+
)
|
| 64 |
+
from craftax.craftax_classic.constants import Action
|
| 65 |
+
|
| 66 |
+
env = CraftaxClassicPixelsEnv(CraftaxClassicPixelsEnv.default_static_params())
|
| 67 |
+
network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
|
| 68 |
+
is_classic = True
|
| 69 |
+
else:
|
| 70 |
+
raise ValueError(f"Unknown env: {config['ENV_NAME']}")
|
| 71 |
+
|
| 72 |
+
env = AutoResetEnvWrapper(env)
|
| 73 |
+
env_params = env.default_params
|
| 74 |
+
|
| 75 |
+
init_x = jnp.zeros((config["NUM_ENVS"], *env.observation_space(env_params).shape))
|
| 76 |
+
|
| 77 |
+
rng = jax.random.PRNGKey(np.random.randint(2**31))
|
| 78 |
+
rng, _rng, __rng = jax.random.split(rng, 3)
|
| 79 |
+
network_params = network.init(_rng, init_x)
|
| 80 |
+
|
| 81 |
+
tx = optax.chain(
|
| 82 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 83 |
+
optax.adam(config["LR"], eps=1e-5),
|
| 84 |
+
)
|
| 85 |
+
train_state = TrainState.create(
|
| 86 |
+
apply_fn=network.apply,
|
| 87 |
+
params=network_params,
|
| 88 |
+
tx=tx,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
abstract_train_state = jax.eval_shape(lambda: train_state)
|
| 92 |
+
|
| 93 |
+
train_state = checkpoint_manager.restore(
|
| 94 |
+
config["TOTAL_TIMESTEPS"],
|
| 95 |
+
args=ocp.args.StandardRestore(abstract_train_state)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
obs, env_state = env.reset(key=__rng)
|
| 99 |
+
done = 0
|
| 100 |
+
|
| 101 |
+
if is_classic:
|
| 102 |
+
from craftax.craftax_classic.play_craftax_classic import CraftaxRenderer
|
| 103 |
+
from craftax.craftax_classic.constants import Achievement
|
| 104 |
+
else:
|
| 105 |
+
from craftax.craftax.play_craftax import CraftaxRenderer
|
| 106 |
+
from craftax.craftax.constants import Achievement
|
| 107 |
+
|
| 108 |
+
renderer = CraftaxRenderer(env, env_params, pixel_render_size=1)
|
| 109 |
+
|
| 110 |
+
while not renderer.is_quit_requested():
|
| 111 |
+
done = np.array([done], dtype=bool)
|
| 112 |
+
obs = jnp.expand_dims(obs, axis=0)
|
| 113 |
+
|
| 114 |
+
pi, value = network.apply(train_state.params, obs)
|
| 115 |
+
rng, _rng = jax.random.split(rng)
|
| 116 |
+
action = pi.sample(seed=_rng)[0]
|
| 117 |
+
# action = jnp.argmax(pi.probs[0, 0])
|
| 118 |
+
|
| 119 |
+
if action is not None:
|
| 120 |
+
rng, _rng = jax.random.split(rng)
|
| 121 |
+
old_achievements = env_state.achievements
|
| 122 |
+
obs, env_state, reward, done, info = env.step(
|
| 123 |
+
_rng, env_state, action, env_params
|
| 124 |
+
)
|
| 125 |
+
new_achievements = env_state.achievements
|
| 126 |
+
print_new_achievements(Achievement, old_achievements, new_achievements)
|
| 127 |
+
if done:
|
| 128 |
+
print("\n")
|
| 129 |
+
renderer.render(env_state)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def print_new_achievements(achievements_cls, old_achievements, new_achievements):
|
| 133 |
+
for i in range(len(old_achievements)):
|
| 134 |
+
if old_achievements[i] == 0 and new_achievements[i] == 1:
|
| 135 |
+
print(f"{achievements_cls(i).name} ({new_achievements.sum()}/{22})")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
parser = argparse.ArgumentParser()
|
| 140 |
+
parser.add_argument("--path", type=str)
|
| 141 |
+
parser.add_argument("--debug", action="store_true")
|
| 142 |
+
|
| 143 |
+
args, rest_args = parser.parse_known_args(sys.argv[1:])
|
| 144 |
+
if rest_args:
|
| 145 |
+
raise ValueError(f"Unknown args {rest_args}")
|
| 146 |
+
|
| 147 |
+
if args.debug:
|
| 148 |
+
with jax.disable_jit():
|
| 149 |
+
main(args)
|
| 150 |
+
else:
|
| 151 |
+
main(args)
|
Craftax_Baselines/build.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
echo 'Building Dockerfile with image name craftax'
|
| 4 |
+
docker build \
|
| 5 |
+
--build-arg UID=$(id -u ${USER}) \
|
| 6 |
+
--build-arg GID=1234 \
|
| 7 |
+
--build-arg REQS="$(cat requirements.txt)" \
|
| 8 |
+
-t craftax_baselines \
|
| 9 |
+
--no-cache \
|
| 10 |
+
.
|
Craftax_Baselines/images/logo.png
ADDED
|
Craftax_Baselines/logz/__init__.py
ADDED
|
File without changes
|
Craftax_Baselines/logz/batch_logging.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
import numpy as np
|
| 5 |
+
import wandb
|
| 6 |
+
|
| 7 |
+
batch_logs = {}
|
| 8 |
+
log_times = []
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_log_dict(info, config):
|
| 12 |
+
to_log = {
|
| 13 |
+
"episode_return": info["returned_episode_returns"],
|
| 14 |
+
"episode_length": info["returned_episode_lengths"],
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
diffusion_keys = [
|
| 18 |
+
"loss", "unweighted_loss", "accuracy", "mean_t",
|
| 19 |
+
"acc_t_low", "acc_t_mid", "acc_t_high", "grad_norm",
|
| 20 |
+
"action_entropy", "action_unique_frac"
|
| 21 |
+
]
|
| 22 |
+
for k in diffusion_keys:
|
| 23 |
+
if k in info:
|
| 24 |
+
to_log[f"diffusion/{k}"] = info[k]
|
| 25 |
+
|
| 26 |
+
sum_achievements = 0.0
|
| 27 |
+
sum_val_achievements = 0.0
|
| 28 |
+
has_val = False
|
| 29 |
+
|
| 30 |
+
for k, v in info.items():
|
| 31 |
+
if k.startswith("val/"):
|
| 32 |
+
has_val = True
|
| 33 |
+
to_log[k] = v
|
| 34 |
+
if "achievements" in k.lower() and k != "val/achievements":
|
| 35 |
+
sum_val_achievements += v / 100.0
|
| 36 |
+
elif "achievements" in k.lower():
|
| 37 |
+
to_log[k] = v
|
| 38 |
+
if k != "achievements":
|
| 39 |
+
sum_achievements += v / 100.0
|
| 40 |
+
|
| 41 |
+
to_log["achievements"] = sum_achievements
|
| 42 |
+
if has_val:
|
| 43 |
+
to_log["val/achievements"] = sum_val_achievements
|
| 44 |
+
|
| 45 |
+
if config.get("TRAIN_ICM") or config.get("USE_RND"):
|
| 46 |
+
to_log["intrinsic_reward"] = info.get("reward_i", 0.0)
|
| 47 |
+
to_log["extrinsic_reward"] = info.get("reward_e", 0.0)
|
| 48 |
+
|
| 49 |
+
if config.get("TRAIN_ICM"):
|
| 50 |
+
to_log["icm_inverse_loss"] = info.get("icm_inverse_loss", 0.0)
|
| 51 |
+
to_log["icm_forward_loss"] = info.get("icm_forward_loss", 0.0)
|
| 52 |
+
elif config.get("USE_RND"):
|
| 53 |
+
to_log["rnd_loss"] = info.get("rnd_loss", 0.0)
|
| 54 |
+
|
| 55 |
+
return to_log
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def batch_log(update_step, log, config):
|
| 59 |
+
update_step = int(update_step)
|
| 60 |
+
if update_step not in batch_logs:
|
| 61 |
+
batch_logs[update_step] = []
|
| 62 |
+
|
| 63 |
+
batch_logs[update_step].append(log)
|
| 64 |
+
|
| 65 |
+
if len(batch_logs[update_step]) == config.get("NUM_REPEATS", 1):
|
| 66 |
+
agg_logs = {}
|
| 67 |
+
for key in batch_logs[update_step][0]:
|
| 68 |
+
agg = []
|
| 69 |
+
if key in ["goal_heatmap"]:
|
| 70 |
+
agg = [batch_logs[update_step][0][key]]
|
| 71 |
+
else:
|
| 72 |
+
for i in range(config.get("NUM_REPEATS", 1)):
|
| 73 |
+
# Use .get() to prevent KeyErrors if repeats are out of sync
|
| 74 |
+
val = batch_logs[update_step][i].get(key, float("nan"))
|
| 75 |
+
if not jnp.isnan(val):
|
| 76 |
+
agg.append(val)
|
| 77 |
+
|
| 78 |
+
if len(agg) > 0:
|
| 79 |
+
if key in [
|
| 80 |
+
"episode_length",
|
| 81 |
+
"episode_return",
|
| 82 |
+
"exploration_bonus",
|
| 83 |
+
"e_mean",
|
| 84 |
+
"e_std",
|
| 85 |
+
"rnd_loss",
|
| 86 |
+
"diffusion/loss",
|
| 87 |
+
"diffusion/unweighted_loss",
|
| 88 |
+
"diffusion/accuracy",
|
| 89 |
+
"diffusion/acc_t_low",
|
| 90 |
+
"diffusion/acc_t_mid",
|
| 91 |
+
"diffusion/acc_t_high",
|
| 92 |
+
"diffusion/action_entropy",
|
| 93 |
+
"diffusion/grad_norm"
|
| 94 |
+
] or key.startswith("val/") or "achievement" in key.lower():
|
| 95 |
+
agg_logs[key] = np.mean(agg)
|
| 96 |
+
else:
|
| 97 |
+
agg_logs[key] = np.array(agg)
|
| 98 |
+
|
| 99 |
+
log_times.append(time.time())
|
| 100 |
+
|
| 101 |
+
if config.get("DEBUG"):
|
| 102 |
+
if len(log_times) == 1:
|
| 103 |
+
print("Started logging")
|
| 104 |
+
elif len(log_times) > 1:
|
| 105 |
+
dt = log_times[-1] - log_times[-2]
|
| 106 |
+
steps_between_updates = (
|
| 107 |
+
config["NUM_STEPS"] * config["NUM_ENVS"] * config.get("NUM_REPEATS", 1)
|
| 108 |
+
)
|
| 109 |
+
sps = steps_between_updates / dt
|
| 110 |
+
agg_logs["sps"] = sps
|
| 111 |
+
|
| 112 |
+
wandb.log(agg_logs)
|
| 113 |
+
|
| 114 |
+
# Clear buffer to prevent memory leaks
|
| 115 |
+
del batch_logs[update_step]
|
Craftax_Baselines/models/__init__.py
ADDED
|
File without changes
|
Craftax_Baselines/models/actor_critic.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax.numpy as jnp
|
| 2 |
+
import flax.linen as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from flax.linen.initializers import constant, orthogonal
|
| 5 |
+
from typing import Sequence
|
| 6 |
+
|
| 7 |
+
import distrax
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ActorCriticConvSymbolicCraftax(nn.Module):
|
| 11 |
+
action_dim: int
|
| 12 |
+
map_obs_shape: Sequence[int]
|
| 13 |
+
layer_width: int
|
| 14 |
+
|
| 15 |
+
@nn.compact
|
| 16 |
+
def __call__(self, obs):
|
| 17 |
+
# Split into map and flat obs
|
| 18 |
+
flat_map_obs_shape = (
|
| 19 |
+
self.map_obs_shape[0] * self.map_obs_shape[1] * self.map_obs_shape[2]
|
| 20 |
+
)
|
| 21 |
+
image_obs = obs[:, :flat_map_obs_shape]
|
| 22 |
+
image_dim = self.map_obs_shape
|
| 23 |
+
image_obs = image_obs.reshape((image_obs.shape[0], *image_dim))
|
| 24 |
+
|
| 25 |
+
flat_obs = obs[:, flat_map_obs_shape:]
|
| 26 |
+
|
| 27 |
+
# Convolutions on map
|
| 28 |
+
image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_obs)
|
| 29 |
+
image_embedding = nn.relu(image_embedding)
|
| 30 |
+
image_embedding = nn.max_pool(
|
| 31 |
+
image_embedding, window_shape=(2, 2), strides=(1, 1)
|
| 32 |
+
)
|
| 33 |
+
image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_embedding)
|
| 34 |
+
image_embedding = nn.relu(image_embedding)
|
| 35 |
+
image_embedding = nn.max_pool(
|
| 36 |
+
image_embedding, window_shape=(2, 2), strides=(1, 1)
|
| 37 |
+
)
|
| 38 |
+
image_embedding = image_embedding.reshape(image_embedding.shape[0], -1)
|
| 39 |
+
# image_embedding = jnp.concatenate([image_embedding, obs[:, : CraftaxEnv.get_flat_map_obs_shape()]], axis=-1)
|
| 40 |
+
|
| 41 |
+
# Combine embeddings
|
| 42 |
+
embedding = jnp.concatenate([image_embedding, flat_obs], axis=-1)
|
| 43 |
+
embedding = nn.Dense(
|
| 44 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 45 |
+
)(embedding)
|
| 46 |
+
embedding = nn.relu(embedding)
|
| 47 |
+
|
| 48 |
+
actor_mean = nn.Dense(
|
| 49 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 50 |
+
)(embedding)
|
| 51 |
+
actor_mean = nn.relu(actor_mean)
|
| 52 |
+
|
| 53 |
+
actor_mean = nn.Dense(
|
| 54 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 55 |
+
)(actor_mean)
|
| 56 |
+
actor_mean = nn.relu(actor_mean)
|
| 57 |
+
|
| 58 |
+
actor_mean = nn.Dense(
|
| 59 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 60 |
+
)(actor_mean)
|
| 61 |
+
|
| 62 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 63 |
+
|
| 64 |
+
critic = nn.Dense(
|
| 65 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 66 |
+
)(embedding)
|
| 67 |
+
critic = nn.relu(critic)
|
| 68 |
+
critic = nn.Dense(
|
| 69 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 70 |
+
)(critic)
|
| 71 |
+
critic = nn.relu(critic)
|
| 72 |
+
critic = nn.Dense(
|
| 73 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 74 |
+
)(critic)
|
| 75 |
+
critic = nn.relu(critic)
|
| 76 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 77 |
+
critic
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
return pi, jnp.squeeze(critic, axis=-1)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ActorCriticConv(nn.Module):
|
| 84 |
+
action_dim: int
|
| 85 |
+
layer_width: int
|
| 86 |
+
activation: str = "tanh"
|
| 87 |
+
|
| 88 |
+
@nn.compact
|
| 89 |
+
def __call__(self, obs):
|
| 90 |
+
x = nn.Conv(features=32, kernel_size=(5, 5))(obs)
|
| 91 |
+
x = nn.relu(x)
|
| 92 |
+
x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
|
| 93 |
+
x = nn.Conv(features=32, kernel_size=(5, 5))(x)
|
| 94 |
+
x = nn.relu(x)
|
| 95 |
+
x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
|
| 96 |
+
x = nn.Conv(features=32, kernel_size=(5, 5))(x)
|
| 97 |
+
x = nn.relu(x)
|
| 98 |
+
x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
|
| 99 |
+
|
| 100 |
+
embedding = x.reshape(x.shape[0], -1)
|
| 101 |
+
|
| 102 |
+
actor_mean = nn.Dense(
|
| 103 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 104 |
+
)(embedding)
|
| 105 |
+
actor_mean = nn.relu(actor_mean)
|
| 106 |
+
|
| 107 |
+
actor_mean = nn.Dense(
|
| 108 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 109 |
+
)(actor_mean)
|
| 110 |
+
actor_mean = nn.relu(actor_mean)
|
| 111 |
+
|
| 112 |
+
actor_mean = nn.Dense(
|
| 113 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 114 |
+
)(actor_mean)
|
| 115 |
+
|
| 116 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 117 |
+
|
| 118 |
+
critic = nn.Dense(
|
| 119 |
+
self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
|
| 120 |
+
)(embedding)
|
| 121 |
+
critic = nn.relu(critic)
|
| 122 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 123 |
+
critic
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
return pi, jnp.squeeze(critic, axis=-1)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ActorCritic(nn.Module):
|
| 130 |
+
action_dim: int
|
| 131 |
+
layer_width: int
|
| 132 |
+
activation: str = "tanh"
|
| 133 |
+
|
| 134 |
+
@nn.compact
|
| 135 |
+
def __call__(self, x):
|
| 136 |
+
if self.activation == "relu":
|
| 137 |
+
activation = nn.relu
|
| 138 |
+
else:
|
| 139 |
+
activation = nn.tanh
|
| 140 |
+
|
| 141 |
+
actor_mean = nn.Dense(
|
| 142 |
+
self.layer_width,
|
| 143 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 144 |
+
bias_init=constant(0.0),
|
| 145 |
+
)(x)
|
| 146 |
+
actor_mean = activation(actor_mean)
|
| 147 |
+
|
| 148 |
+
actor_mean = nn.Dense(
|
| 149 |
+
self.layer_width,
|
| 150 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 151 |
+
bias_init=constant(0.0),
|
| 152 |
+
)(actor_mean)
|
| 153 |
+
actor_mean = activation(actor_mean)
|
| 154 |
+
|
| 155 |
+
actor_mean = nn.Dense(
|
| 156 |
+
self.layer_width,
|
| 157 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 158 |
+
bias_init=constant(0.0),
|
| 159 |
+
)(actor_mean)
|
| 160 |
+
actor_mean = activation(actor_mean)
|
| 161 |
+
|
| 162 |
+
actor_mean = nn.Dense(
|
| 163 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 164 |
+
)(actor_mean)
|
| 165 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 166 |
+
|
| 167 |
+
critic = nn.Dense(
|
| 168 |
+
self.layer_width,
|
| 169 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 170 |
+
bias_init=constant(0.0),
|
| 171 |
+
)(x)
|
| 172 |
+
critic = activation(critic)
|
| 173 |
+
|
| 174 |
+
critic = nn.Dense(
|
| 175 |
+
self.layer_width,
|
| 176 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 177 |
+
bias_init=constant(0.0),
|
| 178 |
+
)(critic)
|
| 179 |
+
critic = activation(critic)
|
| 180 |
+
|
| 181 |
+
critic = nn.Dense(
|
| 182 |
+
self.layer_width,
|
| 183 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 184 |
+
bias_init=constant(0.0),
|
| 185 |
+
)(critic)
|
| 186 |
+
critic = activation(critic)
|
| 187 |
+
|
| 188 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 189 |
+
critic
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
return pi, jnp.squeeze(critic, axis=-1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ActorCriticWithEmbedding(nn.Module):
|
| 196 |
+
action_dim: int
|
| 197 |
+
layer_width: int
|
| 198 |
+
activation: str = "tanh"
|
| 199 |
+
|
| 200 |
+
@nn.compact
|
| 201 |
+
def __call__(self, x):
|
| 202 |
+
if self.activation == "relu":
|
| 203 |
+
activation = nn.relu
|
| 204 |
+
else:
|
| 205 |
+
activation = nn.tanh
|
| 206 |
+
|
| 207 |
+
actor_emb = nn.Dense(
|
| 208 |
+
self.layer_width,
|
| 209 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 210 |
+
bias_init=constant(0.0),
|
| 211 |
+
)(x)
|
| 212 |
+
actor_emb = activation(actor_emb)
|
| 213 |
+
|
| 214 |
+
actor_emb = nn.Dense(
|
| 215 |
+
self.layer_width,
|
| 216 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 217 |
+
bias_init=constant(0.0),
|
| 218 |
+
)(actor_emb)
|
| 219 |
+
actor_emb = activation(actor_emb)
|
| 220 |
+
|
| 221 |
+
actor_emb = nn.Dense(
|
| 222 |
+
128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
|
| 223 |
+
)(actor_emb)
|
| 224 |
+
actor_emb = activation(actor_emb)
|
| 225 |
+
|
| 226 |
+
actor_mean = nn.Dense(
|
| 227 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 228 |
+
)(actor_emb)
|
| 229 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 230 |
+
|
| 231 |
+
critic = nn.Dense(
|
| 232 |
+
self.layer_width,
|
| 233 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 234 |
+
bias_init=constant(0.0),
|
| 235 |
+
)(x)
|
| 236 |
+
critic = activation(critic)
|
| 237 |
+
|
| 238 |
+
critic = nn.Dense(
|
| 239 |
+
self.layer_width,
|
| 240 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 241 |
+
bias_init=constant(0.0),
|
| 242 |
+
)(critic)
|
| 243 |
+
critic = activation(critic)
|
| 244 |
+
|
| 245 |
+
critic = nn.Dense(
|
| 246 |
+
self.layer_width,
|
| 247 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 248 |
+
bias_init=constant(0.0),
|
| 249 |
+
)(critic)
|
| 250 |
+
critic = activation(critic)
|
| 251 |
+
|
| 252 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 253 |
+
critic
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
return pi, jnp.squeeze(critic, axis=-1), actor_emb
|
Craftax_Baselines/models/icm.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import flax.linen as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ICMEncoder(nn.Module):
|
| 7 |
+
layer_size: int
|
| 8 |
+
output_dim: int
|
| 9 |
+
num_layers: int
|
| 10 |
+
|
| 11 |
+
@nn.compact
|
| 12 |
+
def __call__(self, obs):
|
| 13 |
+
activation = nn.relu
|
| 14 |
+
|
| 15 |
+
# TODO Look at weight inits
|
| 16 |
+
|
| 17 |
+
emb = obs
|
| 18 |
+
for _ in range(self.num_layers):
|
| 19 |
+
emb = nn.Dense(
|
| 20 |
+
self.layer_size,
|
| 21 |
+
)(emb)
|
| 22 |
+
emb = activation(emb)
|
| 23 |
+
|
| 24 |
+
emb = nn.Dense(self.output_dim)(emb)
|
| 25 |
+
|
| 26 |
+
return emb
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ICMForward(nn.Module):
|
| 30 |
+
layer_size: int
|
| 31 |
+
output_dim: int
|
| 32 |
+
num_layers: int
|
| 33 |
+
num_actions: int
|
| 34 |
+
|
| 35 |
+
@nn.compact
|
| 36 |
+
def __call__(self, latent, action):
|
| 37 |
+
activation = nn.relu
|
| 38 |
+
|
| 39 |
+
action1h = jax.nn.one_hot(action, num_classes=self.num_actions)
|
| 40 |
+
emb = jnp.concatenate((latent, action1h), axis=-1)
|
| 41 |
+
for _ in range(self.num_layers):
|
| 42 |
+
emb = nn.Dense(
|
| 43 |
+
self.layer_size,
|
| 44 |
+
)(emb)
|
| 45 |
+
emb = activation(emb)
|
| 46 |
+
|
| 47 |
+
emb = nn.Dense(self.output_dim)(emb)
|
| 48 |
+
|
| 49 |
+
return emb
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ICMInverse(nn.Module):
|
| 53 |
+
layer_size: int
|
| 54 |
+
output_dim: int
|
| 55 |
+
num_layers: int
|
| 56 |
+
|
| 57 |
+
@nn.compact
|
| 58 |
+
def __call__(self, latent, next_latent):
|
| 59 |
+
activation = nn.relu
|
| 60 |
+
|
| 61 |
+
emb = jnp.concatenate((latent, next_latent), axis=-1)
|
| 62 |
+
for _ in range(self.num_layers):
|
| 63 |
+
emb = nn.Dense(
|
| 64 |
+
self.layer_size,
|
| 65 |
+
)(emb)
|
| 66 |
+
emb = activation(emb)
|
| 67 |
+
|
| 68 |
+
action_raw = nn.Dense(self.output_dim)(emb)
|
| 69 |
+
|
| 70 |
+
action_logits = jax.nn.log_softmax(action_raw)
|
| 71 |
+
|
| 72 |
+
return action_logits
|
Craftax_Baselines/models/rnd.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax.numpy as jnp
|
| 2 |
+
import flax.linen as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
from flax.linen.initializers import constant, orthogonal
|
| 5 |
+
|
| 6 |
+
import distrax
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RNDNetwork(nn.Module):
|
| 10 |
+
layer_size: int
|
| 11 |
+
output_dim: int
|
| 12 |
+
num_layers: int
|
| 13 |
+
|
| 14 |
+
@nn.compact
|
| 15 |
+
def __call__(self, x):
|
| 16 |
+
activation = nn.relu
|
| 17 |
+
|
| 18 |
+
emb = x
|
| 19 |
+
for _ in range(self.num_layers):
|
| 20 |
+
emb = nn.Dense(
|
| 21 |
+
self.layer_size,
|
| 22 |
+
)(emb)
|
| 23 |
+
emb = activation(emb)
|
| 24 |
+
|
| 25 |
+
emb = nn.Dense(self.output_dim)(emb)
|
| 26 |
+
|
| 27 |
+
return emb
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ActorCriticRND(nn.Module):
|
| 31 |
+
action_dim: int
|
| 32 |
+
layer_width: int
|
| 33 |
+
activation: str = "tanh"
|
| 34 |
+
|
| 35 |
+
@nn.compact
|
| 36 |
+
def __call__(self, x):
|
| 37 |
+
if self.activation == "relu":
|
| 38 |
+
activation = nn.relu
|
| 39 |
+
else:
|
| 40 |
+
activation = nn.tanh
|
| 41 |
+
|
| 42 |
+
actor_mean = nn.Dense(
|
| 43 |
+
self.layer_width,
|
| 44 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 45 |
+
bias_init=constant(0.0),
|
| 46 |
+
)(x)
|
| 47 |
+
actor_mean = activation(actor_mean)
|
| 48 |
+
|
| 49 |
+
actor_mean = nn.Dense(
|
| 50 |
+
self.layer_width,
|
| 51 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 52 |
+
bias_init=constant(0.0),
|
| 53 |
+
)(actor_mean)
|
| 54 |
+
actor_mean = activation(actor_mean)
|
| 55 |
+
|
| 56 |
+
actor_mean = nn.Dense(
|
| 57 |
+
self.layer_width,
|
| 58 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 59 |
+
bias_init=constant(0.0),
|
| 60 |
+
)(actor_mean)
|
| 61 |
+
actor_mean = activation(actor_mean)
|
| 62 |
+
|
| 63 |
+
actor_mean = nn.Dense(
|
| 64 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 65 |
+
)(actor_mean)
|
| 66 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 67 |
+
|
| 68 |
+
# Extrinsic reward
|
| 69 |
+
critic_e = nn.Dense(
|
| 70 |
+
self.layer_width,
|
| 71 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 72 |
+
bias_init=constant(0.0),
|
| 73 |
+
)(x)
|
| 74 |
+
critic_e = activation(critic_e)
|
| 75 |
+
|
| 76 |
+
critic_e = nn.Dense(
|
| 77 |
+
self.layer_width,
|
| 78 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 79 |
+
bias_init=constant(0.0),
|
| 80 |
+
)(critic_e)
|
| 81 |
+
critic_e = activation(critic_e)
|
| 82 |
+
|
| 83 |
+
critic_e = nn.Dense(
|
| 84 |
+
self.layer_width,
|
| 85 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 86 |
+
bias_init=constant(0.0),
|
| 87 |
+
)(critic_e)
|
| 88 |
+
critic_e = activation(critic_e)
|
| 89 |
+
|
| 90 |
+
critic_e = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 91 |
+
critic_e
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Intrinsic reward
|
| 95 |
+
critic_i = nn.Dense(
|
| 96 |
+
self.layer_width,
|
| 97 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 98 |
+
bias_init=constant(0.0),
|
| 99 |
+
)(x)
|
| 100 |
+
critic_i = activation(critic_i)
|
| 101 |
+
|
| 102 |
+
critic_i = nn.Dense(
|
| 103 |
+
self.layer_width,
|
| 104 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 105 |
+
bias_init=constant(0.0),
|
| 106 |
+
)(critic_i)
|
| 107 |
+
critic_i = activation(critic_i)
|
| 108 |
+
|
| 109 |
+
critic_i = nn.Dense(
|
| 110 |
+
self.layer_width,
|
| 111 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 112 |
+
bias_init=constant(0.0),
|
| 113 |
+
)(critic_i)
|
| 114 |
+
critic_i = activation(critic_i)
|
| 115 |
+
|
| 116 |
+
critic_i = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 117 |
+
critic_i
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
return pi, jnp.squeeze(critic_e, axis=-1), jnp.squeeze(critic_i, axis=-1)
|
Craftax_Baselines/ppo.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import jax
|
| 7 |
+
import jax.numpy as jnp
|
| 8 |
+
import numpy as np
|
| 9 |
+
import optax
|
| 10 |
+
from craftax.craftax_env import make_craftax_env_from_name
|
| 11 |
+
|
| 12 |
+
import wandb
|
| 13 |
+
from typing import NamedTuple
|
| 14 |
+
|
| 15 |
+
from flax.training.train_state import TrainState
|
| 16 |
+
import orbax.checkpoint as ocp
|
| 17 |
+
|
| 18 |
+
from logz.batch_logging import batch_log, create_log_dict
|
| 19 |
+
from models.actor_critic import (
|
| 20 |
+
ActorCritic,
|
| 21 |
+
ActorCriticConv,
|
| 22 |
+
)
|
| 23 |
+
from models.icm import ICMEncoder, ICMForward, ICMInverse
|
| 24 |
+
from wrappers import (
|
| 25 |
+
LogWrapper,
|
| 26 |
+
OptimisticResetVecEnvWrapper,
|
| 27 |
+
BatchEnvWrapper,
|
| 28 |
+
AutoResetEnvWrapper,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Code adapted from the original implementation made by Chris Lu
|
| 32 |
+
# Original code located at https://github.com/luchris429/purejaxrl
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Transition(NamedTuple):
|
| 36 |
+
done: jnp.ndarray
|
| 37 |
+
action: jnp.ndarray
|
| 38 |
+
value: jnp.ndarray
|
| 39 |
+
reward_e: jnp.ndarray
|
| 40 |
+
reward_i: jnp.ndarray
|
| 41 |
+
reward: jnp.ndarray
|
| 42 |
+
log_prob: jnp.ndarray
|
| 43 |
+
obs: jnp.ndarray
|
| 44 |
+
next_obs: jnp.ndarray
|
| 45 |
+
info: jnp.ndarray
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_train(config):
|
| 49 |
+
config["NUM_UPDATES"] = (
|
| 50 |
+
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
|
| 51 |
+
)
|
| 52 |
+
config["MINIBATCH_SIZE"] = (
|
| 53 |
+
config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
env = make_craftax_env_from_name(
|
| 57 |
+
config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
|
| 58 |
+
)
|
| 59 |
+
env_params = env.default_params
|
| 60 |
+
|
| 61 |
+
env = LogWrapper(env)
|
| 62 |
+
if config["USE_OPTIMISTIC_RESETS"]:
|
| 63 |
+
env = OptimisticResetVecEnvWrapper(
|
| 64 |
+
env,
|
| 65 |
+
num_envs=config["NUM_ENVS"],
|
| 66 |
+
reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
|
| 67 |
+
)
|
| 68 |
+
else:
|
| 69 |
+
env = AutoResetEnvWrapper(env)
|
| 70 |
+
env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
|
| 71 |
+
|
| 72 |
+
def linear_schedule(count):
|
| 73 |
+
frac = (
|
| 74 |
+
1.0
|
| 75 |
+
- (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
|
| 76 |
+
/ config["NUM_UPDATES"]
|
| 77 |
+
)
|
| 78 |
+
return config["LR"] * frac
|
| 79 |
+
|
| 80 |
+
def train(rng):
|
| 81 |
+
# INIT NETWORK
|
| 82 |
+
if "Symbolic" in config["ENV_NAME"]:
|
| 83 |
+
network = ActorCritic(env.action_space(env_params).n, config["LAYER_SIZE"])
|
| 84 |
+
else:
|
| 85 |
+
network = ActorCriticConv(
|
| 86 |
+
env.action_space(env_params).n, config["LAYER_SIZE"]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
rng, _rng = jax.random.split(rng)
|
| 90 |
+
init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
|
| 91 |
+
network_params = network.init(_rng, init_x)
|
| 92 |
+
if config["ANNEAL_LR"]:
|
| 93 |
+
tx = optax.chain(
|
| 94 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 95 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
tx = optax.chain(
|
| 99 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 100 |
+
optax.adam(config["LR"], eps=1e-5),
|
| 101 |
+
)
|
| 102 |
+
train_state = TrainState.create(
|
| 103 |
+
apply_fn=network.apply,
|
| 104 |
+
params=network_params,
|
| 105 |
+
tx=tx,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Exploration state
|
| 109 |
+
ex_state = {
|
| 110 |
+
"icm_encoder": None,
|
| 111 |
+
"icm_forward": None,
|
| 112 |
+
"icm_inverse": None,
|
| 113 |
+
"e3b_matrix": None,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
if config["TRAIN_ICM"]:
|
| 117 |
+
obs_shape = env.observation_space(env_params).shape
|
| 118 |
+
assert len(obs_shape) == 1, "Only configured for 1D observations"
|
| 119 |
+
obs_shape = obs_shape[0]
|
| 120 |
+
|
| 121 |
+
# Encoder
|
| 122 |
+
icm_encoder_network = ICMEncoder(
|
| 123 |
+
num_layers=3,
|
| 124 |
+
output_dim=config["ICM_LATENT_SIZE"],
|
| 125 |
+
layer_size=config["ICM_LAYER_SIZE"],
|
| 126 |
+
)
|
| 127 |
+
rng, _rng = jax.random.split(rng)
|
| 128 |
+
icm_encoder_network_params = icm_encoder_network.init(
|
| 129 |
+
_rng, jnp.zeros((1, obs_shape))
|
| 130 |
+
)
|
| 131 |
+
tx = optax.chain(
|
| 132 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 133 |
+
optax.adam(config["ICM_LR"], eps=1e-5),
|
| 134 |
+
)
|
| 135 |
+
ex_state["icm_encoder"] = TrainState.create(
|
| 136 |
+
apply_fn=icm_encoder_network.apply,
|
| 137 |
+
params=icm_encoder_network_params,
|
| 138 |
+
tx=tx,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Forward
|
| 142 |
+
icm_forward_network = ICMForward(
|
| 143 |
+
num_layers=3,
|
| 144 |
+
output_dim=config["ICM_LATENT_SIZE"],
|
| 145 |
+
layer_size=config["ICM_LAYER_SIZE"],
|
| 146 |
+
num_actions=env.num_actions,
|
| 147 |
+
)
|
| 148 |
+
rng, _rng = jax.random.split(rng)
|
| 149 |
+
icm_forward_network_params = icm_forward_network.init(
|
| 150 |
+
_rng, jnp.zeros((1, config["ICM_LATENT_SIZE"])), jnp.zeros((1,))
|
| 151 |
+
)
|
| 152 |
+
tx = optax.chain(
|
| 153 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 154 |
+
optax.adam(config["ICM_LR"], eps=1e-5),
|
| 155 |
+
)
|
| 156 |
+
ex_state["icm_forward"] = TrainState.create(
|
| 157 |
+
apply_fn=icm_forward_network.apply,
|
| 158 |
+
params=icm_forward_network_params,
|
| 159 |
+
tx=tx,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Inverse
|
| 163 |
+
icm_inverse_network = ICMInverse(
|
| 164 |
+
num_layers=3,
|
| 165 |
+
output_dim=env.num_actions,
|
| 166 |
+
layer_size=config["ICM_LAYER_SIZE"],
|
| 167 |
+
)
|
| 168 |
+
rng, _rng = jax.random.split(rng)
|
| 169 |
+
icm_inverse_network_params = icm_inverse_network.init(
|
| 170 |
+
_rng,
|
| 171 |
+
jnp.zeros((1, config["ICM_LATENT_SIZE"])),
|
| 172 |
+
jnp.zeros((1, config["ICM_LATENT_SIZE"])),
|
| 173 |
+
)
|
| 174 |
+
tx = optax.chain(
|
| 175 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 176 |
+
optax.adam(config["ICM_LR"], eps=1e-5),
|
| 177 |
+
)
|
| 178 |
+
ex_state["icm_inverse"] = TrainState.create(
|
| 179 |
+
apply_fn=icm_inverse_network.apply,
|
| 180 |
+
params=icm_inverse_network_params,
|
| 181 |
+
tx=tx,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if config["USE_E3B"]:
|
| 185 |
+
ex_state["e3b_matrix"] = (
|
| 186 |
+
jnp.repeat(
|
| 187 |
+
jnp.expand_dims(
|
| 188 |
+
jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
|
| 189 |
+
),
|
| 190 |
+
config["NUM_ENVS"],
|
| 191 |
+
axis=0,
|
| 192 |
+
)
|
| 193 |
+
/ config["E3B_LAMBDA"]
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# INIT ENV
|
| 197 |
+
rng, _rng = jax.random.split(rng)
|
| 198 |
+
obsv, env_state = env.reset(_rng, env_params)
|
| 199 |
+
|
| 200 |
+
# TRAIN LOOP
|
| 201 |
+
def _update_step(runner_state, unused):
|
| 202 |
+
# COLLECT TRAJECTORIES
|
| 203 |
+
def _env_step(runner_state, unused):
|
| 204 |
+
(
|
| 205 |
+
train_state,
|
| 206 |
+
env_state,
|
| 207 |
+
last_obs,
|
| 208 |
+
ex_state,
|
| 209 |
+
rng,
|
| 210 |
+
update_step,
|
| 211 |
+
) = runner_state
|
| 212 |
+
|
| 213 |
+
# SELECT ACTION
|
| 214 |
+
rng, _rng = jax.random.split(rng)
|
| 215 |
+
pi, value = network.apply(train_state.params, last_obs)
|
| 216 |
+
action = pi.sample(seed=_rng)
|
| 217 |
+
log_prob = pi.log_prob(action)
|
| 218 |
+
|
| 219 |
+
# STEP ENV
|
| 220 |
+
rng, _rng = jax.random.split(rng)
|
| 221 |
+
obsv, env_state, reward_e, done, info = env.step(
|
| 222 |
+
_rng, env_state, action, env_params
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
reward_i = jnp.zeros(config["NUM_ENVS"])
|
| 226 |
+
|
| 227 |
+
if config["TRAIN_ICM"]:
|
| 228 |
+
latent_obs = ex_state["icm_encoder"].apply_fn(
|
| 229 |
+
ex_state["icm_encoder"].params, last_obs
|
| 230 |
+
)
|
| 231 |
+
latent_next_obs = ex_state["icm_encoder"].apply_fn(
|
| 232 |
+
ex_state["icm_encoder"].params, obsv
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
|
| 236 |
+
ex_state["icm_forward"].params, latent_obs, action
|
| 237 |
+
)
|
| 238 |
+
error = (latent_next_obs - latent_next_obs_pred) * (
|
| 239 |
+
1 - done[:, None]
|
| 240 |
+
)
|
| 241 |
+
mse = jnp.square(error).mean(axis=-1)
|
| 242 |
+
|
| 243 |
+
reward_i = mse * config["ICM_REWARD_COEFF"]
|
| 244 |
+
|
| 245 |
+
if config["USE_E3B"]:
|
| 246 |
+
# Embedding is (NUM_ENVS, 128)
|
| 247 |
+
# e3b_matrix is (NUM_ENVS, 128, 128)
|
| 248 |
+
us = jax.vmap(jnp.matmul)(ex_state["e3b_matrix"], latent_obs)
|
| 249 |
+
bs = jax.vmap(jnp.dot)(latent_obs, us)
|
| 250 |
+
|
| 251 |
+
def update_c(c, b, u):
|
| 252 |
+
return c - (1.0 / (1 + b)) * jnp.outer(u, u)
|
| 253 |
+
|
| 254 |
+
updated_cs = jax.vmap(update_c)(ex_state["e3b_matrix"], bs, us)
|
| 255 |
+
new_cs = (
|
| 256 |
+
jnp.repeat(
|
| 257 |
+
jnp.expand_dims(
|
| 258 |
+
jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
|
| 259 |
+
),
|
| 260 |
+
config["NUM_ENVS"],
|
| 261 |
+
axis=0,
|
| 262 |
+
)
|
| 263 |
+
/ config["E3B_LAMBDA"]
|
| 264 |
+
)
|
| 265 |
+
ex_state["e3b_matrix"] = jnp.where(
|
| 266 |
+
done[:, None, None], new_cs, updated_cs
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
e3b_bonus = jnp.where(
|
| 270 |
+
done, jnp.zeros((config["NUM_ENVS"],)), bs
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
reward_i = e3b_bonus * config["E3B_REWARD_COEFF"]
|
| 274 |
+
|
| 275 |
+
reward = reward_e + reward_i
|
| 276 |
+
|
| 277 |
+
transition = Transition(
|
| 278 |
+
done=done,
|
| 279 |
+
action=action,
|
| 280 |
+
value=value,
|
| 281 |
+
reward=reward,
|
| 282 |
+
reward_i=reward_i,
|
| 283 |
+
reward_e=reward_e,
|
| 284 |
+
log_prob=log_prob,
|
| 285 |
+
obs=last_obs,
|
| 286 |
+
next_obs=obsv,
|
| 287 |
+
info=info,
|
| 288 |
+
)
|
| 289 |
+
runner_state = (
|
| 290 |
+
train_state,
|
| 291 |
+
env_state,
|
| 292 |
+
obsv,
|
| 293 |
+
ex_state,
|
| 294 |
+
rng,
|
| 295 |
+
update_step,
|
| 296 |
+
)
|
| 297 |
+
return runner_state, transition
|
| 298 |
+
|
| 299 |
+
runner_state, traj_batch = jax.lax.scan(
|
| 300 |
+
_env_step, runner_state, None, config["NUM_STEPS"]
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# CALCULATE ADVANTAGE
|
| 304 |
+
(
|
| 305 |
+
train_state,
|
| 306 |
+
env_state,
|
| 307 |
+
last_obs,
|
| 308 |
+
ex_state,
|
| 309 |
+
rng,
|
| 310 |
+
update_step,
|
| 311 |
+
) = runner_state
|
| 312 |
+
_, last_val = network.apply(train_state.params, last_obs)
|
| 313 |
+
|
| 314 |
+
def _calculate_gae(traj_batch, last_val):
|
| 315 |
+
def _get_advantages(gae_and_next_value, transition):
|
| 316 |
+
gae, next_value = gae_and_next_value
|
| 317 |
+
done, value, reward = (
|
| 318 |
+
transition.done,
|
| 319 |
+
transition.value,
|
| 320 |
+
transition.reward,
|
| 321 |
+
)
|
| 322 |
+
delta = reward + config["GAMMA"] * next_value * (1 - done) - value
|
| 323 |
+
gae = (
|
| 324 |
+
delta
|
| 325 |
+
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
|
| 326 |
+
)
|
| 327 |
+
return (gae, value), gae
|
| 328 |
+
|
| 329 |
+
_, advantages = jax.lax.scan(
|
| 330 |
+
_get_advantages,
|
| 331 |
+
(jnp.zeros_like(last_val), last_val),
|
| 332 |
+
traj_batch,
|
| 333 |
+
reverse=True,
|
| 334 |
+
unroll=16,
|
| 335 |
+
)
|
| 336 |
+
return advantages, advantages + traj_batch.value
|
| 337 |
+
|
| 338 |
+
advantages, targets = _calculate_gae(traj_batch, last_val)
|
| 339 |
+
|
| 340 |
+
# UPDATE NETWORK
|
| 341 |
+
def _update_epoch(update_state, unused):
|
| 342 |
+
def _update_minbatch(train_state, batch_info):
|
| 343 |
+
traj_batch, advantages, targets = batch_info
|
| 344 |
+
|
| 345 |
+
# Policy/value network
|
| 346 |
+
def _loss_fn(params, traj_batch, gae, targets):
|
| 347 |
+
# RERUN NETWORK
|
| 348 |
+
pi, value = network.apply(params, traj_batch.obs)
|
| 349 |
+
log_prob = pi.log_prob(traj_batch.action)
|
| 350 |
+
|
| 351 |
+
# CALCULATE VALUE LOSS
|
| 352 |
+
value_pred_clipped = traj_batch.value + (
|
| 353 |
+
value - traj_batch.value
|
| 354 |
+
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
|
| 355 |
+
value_losses = jnp.square(value - targets)
|
| 356 |
+
value_losses_clipped = jnp.square(value_pred_clipped - targets)
|
| 357 |
+
value_loss = (
|
| 358 |
+
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# CALCULATE ACTOR LOSS
|
| 362 |
+
ratio = jnp.exp(log_prob - traj_batch.log_prob)
|
| 363 |
+
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
|
| 364 |
+
loss_actor1 = ratio * gae
|
| 365 |
+
loss_actor2 = (
|
| 366 |
+
jnp.clip(
|
| 367 |
+
ratio,
|
| 368 |
+
1.0 - config["CLIP_EPS"],
|
| 369 |
+
1.0 + config["CLIP_EPS"],
|
| 370 |
+
)
|
| 371 |
+
* gae
|
| 372 |
+
)
|
| 373 |
+
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
|
| 374 |
+
loss_actor = loss_actor.mean()
|
| 375 |
+
entropy = pi.entropy().mean()
|
| 376 |
+
|
| 377 |
+
total_loss = (
|
| 378 |
+
loss_actor
|
| 379 |
+
+ config["VF_COEF"] * value_loss
|
| 380 |
+
- config["ENT_COEF"] * entropy
|
| 381 |
+
)
|
| 382 |
+
return total_loss, (value_loss, loss_actor, entropy)
|
| 383 |
+
|
| 384 |
+
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
| 385 |
+
total_loss, grads = grad_fn(
|
| 386 |
+
train_state.params, traj_batch, advantages, targets
|
| 387 |
+
)
|
| 388 |
+
train_state = train_state.apply_gradients(grads=grads)
|
| 389 |
+
|
| 390 |
+
losses = (total_loss, 0)
|
| 391 |
+
return train_state, losses
|
| 392 |
+
|
| 393 |
+
(
|
| 394 |
+
train_state,
|
| 395 |
+
traj_batch,
|
| 396 |
+
advantages,
|
| 397 |
+
targets,
|
| 398 |
+
rng,
|
| 399 |
+
) = update_state
|
| 400 |
+
rng, _rng = jax.random.split(rng)
|
| 401 |
+
batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
|
| 402 |
+
assert (
|
| 403 |
+
batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
|
| 404 |
+
), "batch size must be equal to number of steps * number of envs"
|
| 405 |
+
permutation = jax.random.permutation(_rng, batch_size)
|
| 406 |
+
batch = (traj_batch, advantages, targets)
|
| 407 |
+
batch = jax.tree.map(
|
| 408 |
+
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
|
| 409 |
+
)
|
| 410 |
+
shuffled_batch = jax.tree.map(
|
| 411 |
+
lambda x: jnp.take(x, permutation, axis=0), batch
|
| 412 |
+
)
|
| 413 |
+
minibatches = jax.tree.map(
|
| 414 |
+
lambda x: jnp.reshape(
|
| 415 |
+
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
|
| 416 |
+
),
|
| 417 |
+
shuffled_batch,
|
| 418 |
+
)
|
| 419 |
+
train_state, losses = jax.lax.scan(
|
| 420 |
+
_update_minbatch, train_state, minibatches
|
| 421 |
+
)
|
| 422 |
+
update_state = (
|
| 423 |
+
train_state,
|
| 424 |
+
traj_batch,
|
| 425 |
+
advantages,
|
| 426 |
+
targets,
|
| 427 |
+
rng,
|
| 428 |
+
)
|
| 429 |
+
return update_state, losses
|
| 430 |
+
|
| 431 |
+
update_state = (
|
| 432 |
+
train_state,
|
| 433 |
+
traj_batch,
|
| 434 |
+
advantages,
|
| 435 |
+
targets,
|
| 436 |
+
rng,
|
| 437 |
+
)
|
| 438 |
+
update_state, loss_info = jax.lax.scan(
|
| 439 |
+
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
train_state = update_state[0]
|
| 443 |
+
metric = jax.tree.map(
|
| 444 |
+
lambda x: (x * traj_batch.info["returned_episode"]).sum()
|
| 445 |
+
/ traj_batch.info["returned_episode"].sum(),
|
| 446 |
+
traj_batch.info,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
rng = update_state[-1]
|
| 450 |
+
|
| 451 |
+
# UPDATE EXPLORATION STATE
|
| 452 |
+
def _update_ex_epoch(update_state, unused):
|
| 453 |
+
def _update_ex_minbatch(ex_state, traj_batch):
|
| 454 |
+
def _inverse_loss_fn(
|
| 455 |
+
icm_encoder_params, icm_inverse_params, traj_batch
|
| 456 |
+
):
|
| 457 |
+
latent_obs = ex_state["icm_encoder"].apply_fn(
|
| 458 |
+
icm_encoder_params, traj_batch.obs
|
| 459 |
+
)
|
| 460 |
+
latent_next_obs = ex_state["icm_encoder"].apply_fn(
|
| 461 |
+
icm_encoder_params, traj_batch.next_obs
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
action_pred_logits = ex_state["icm_inverse"].apply_fn(
|
| 465 |
+
icm_inverse_params, latent_obs, latent_next_obs
|
| 466 |
+
)
|
| 467 |
+
true_action = jax.nn.one_hot(
|
| 468 |
+
traj_batch.action, num_classes=action_pred_logits.shape[-1]
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
bce = -jnp.mean(
|
| 472 |
+
jnp.sum(
|
| 473 |
+
action_pred_logits
|
| 474 |
+
* true_action
|
| 475 |
+
* (1 - traj_batch.done[:, None]),
|
| 476 |
+
axis=1,
|
| 477 |
+
)
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return bce * config["ICM_INVERSE_LOSS_COEF"]
|
| 481 |
+
|
| 482 |
+
inverse_grad_fn = jax.value_and_grad(
|
| 483 |
+
_inverse_loss_fn,
|
| 484 |
+
has_aux=False,
|
| 485 |
+
argnums=(
|
| 486 |
+
0,
|
| 487 |
+
1,
|
| 488 |
+
),
|
| 489 |
+
)
|
| 490 |
+
inverse_loss, grads = inverse_grad_fn(
|
| 491 |
+
ex_state["icm_encoder"].params,
|
| 492 |
+
ex_state["icm_inverse"].params,
|
| 493 |
+
traj_batch,
|
| 494 |
+
)
|
| 495 |
+
icm_encoder_grad, icm_inverse_grad = grads
|
| 496 |
+
ex_state["icm_encoder"] = ex_state["icm_encoder"].apply_gradients(
|
| 497 |
+
grads=icm_encoder_grad
|
| 498 |
+
)
|
| 499 |
+
ex_state["icm_inverse"] = ex_state["icm_inverse"].apply_gradients(
|
| 500 |
+
grads=icm_inverse_grad
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
def _forward_loss_fn(icm_forward_params, traj_batch):
|
| 504 |
+
latent_obs = ex_state["icm_encoder"].apply_fn(
|
| 505 |
+
ex_state["icm_encoder"].params, traj_batch.obs
|
| 506 |
+
)
|
| 507 |
+
latent_next_obs = ex_state["icm_encoder"].apply_fn(
|
| 508 |
+
ex_state["icm_encoder"].params, traj_batch.next_obs
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
|
| 512 |
+
icm_forward_params, latent_obs, traj_batch.action
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
error = (latent_next_obs - latent_next_obs_pred) * (
|
| 516 |
+
1 - traj_batch.done[:, None]
|
| 517 |
+
)
|
| 518 |
+
return (
|
| 519 |
+
jnp.square(error).mean() * config["ICM_FORWARD_LOSS_COEF"]
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
forward_grad_fn = jax.value_and_grad(
|
| 523 |
+
_forward_loss_fn, has_aux=False
|
| 524 |
+
)
|
| 525 |
+
forward_loss, icm_forward_grad = forward_grad_fn(
|
| 526 |
+
ex_state["icm_forward"].params, traj_batch
|
| 527 |
+
)
|
| 528 |
+
ex_state["icm_forward"] = ex_state["icm_forward"].apply_gradients(
|
| 529 |
+
grads=icm_forward_grad
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
losses = (inverse_loss, forward_loss)
|
| 533 |
+
return ex_state, losses
|
| 534 |
+
|
| 535 |
+
(ex_state, traj_batch, rng) = update_state
|
| 536 |
+
rng, _rng = jax.random.split(rng)
|
| 537 |
+
batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
|
| 538 |
+
assert (
|
| 539 |
+
batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
|
| 540 |
+
), "batch size must be equal to number of steps * number of envs"
|
| 541 |
+
permutation = jax.random.permutation(_rng, batch_size)
|
| 542 |
+
batch = jax.tree.map(
|
| 543 |
+
lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
|
| 544 |
+
)
|
| 545 |
+
shuffled_batch = jax.tree.map(
|
| 546 |
+
lambda x: jnp.take(x, permutation, axis=0), batch
|
| 547 |
+
)
|
| 548 |
+
minibatches = jax.tree.map(
|
| 549 |
+
lambda x: jnp.reshape(
|
| 550 |
+
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
|
| 551 |
+
),
|
| 552 |
+
shuffled_batch,
|
| 553 |
+
)
|
| 554 |
+
ex_state, losses = jax.lax.scan(
|
| 555 |
+
_update_ex_minbatch, ex_state, minibatches
|
| 556 |
+
)
|
| 557 |
+
update_state = (ex_state, traj_batch, rng)
|
| 558 |
+
return update_state, losses
|
| 559 |
+
|
| 560 |
+
if config["TRAIN_ICM"]:
|
| 561 |
+
ex_update_state = (ex_state, traj_batch, rng)
|
| 562 |
+
ex_update_state, ex_loss = jax.lax.scan(
|
| 563 |
+
_update_ex_epoch,
|
| 564 |
+
ex_update_state,
|
| 565 |
+
None,
|
| 566 |
+
config["EXPLORATION_UPDATE_EPOCHS"],
|
| 567 |
+
)
|
| 568 |
+
metric["icm_inverse_loss"] = ex_loss[0].mean()
|
| 569 |
+
metric["icm_forward_loss"] = ex_loss[1].mean()
|
| 570 |
+
metric["reward_i"] = traj_batch.reward_i.mean()
|
| 571 |
+
metric["reward_e"] = traj_batch.reward_e.mean()
|
| 572 |
+
|
| 573 |
+
ex_state = ex_update_state[0]
|
| 574 |
+
rng = ex_update_state[-1]
|
| 575 |
+
|
| 576 |
+
# wandb logging
|
| 577 |
+
if config["DEBUG"] and config["USE_WANDB"]:
|
| 578 |
+
|
| 579 |
+
def callback(metric, update_step):
|
| 580 |
+
to_log = create_log_dict(metric, config)
|
| 581 |
+
batch_log(update_step, to_log, config)
|
| 582 |
+
|
| 583 |
+
jax.debug.callback(
|
| 584 |
+
callback,
|
| 585 |
+
metric,
|
| 586 |
+
update_step,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
runner_state = (
|
| 590 |
+
train_state,
|
| 591 |
+
env_state,
|
| 592 |
+
last_obs,
|
| 593 |
+
ex_state,
|
| 594 |
+
rng,
|
| 595 |
+
update_step + 1,
|
| 596 |
+
)
|
| 597 |
+
return runner_state, metric
|
| 598 |
+
|
| 599 |
+
rng, _rng = jax.random.split(rng)
|
| 600 |
+
runner_state = (
|
| 601 |
+
train_state,
|
| 602 |
+
env_state,
|
| 603 |
+
obsv,
|
| 604 |
+
ex_state,
|
| 605 |
+
_rng,
|
| 606 |
+
0,
|
| 607 |
+
)
|
| 608 |
+
runner_state, metric = jax.lax.scan(
|
| 609 |
+
_update_step, runner_state, None, config["NUM_UPDATES"]
|
| 610 |
+
)
|
| 611 |
+
return {"runner_state": runner_state} # , "info": metric}
|
| 612 |
+
|
| 613 |
+
return train
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def run_ppo(config):
|
| 617 |
+
config = {k.upper(): v for k, v in config.__dict__.items()}
|
| 618 |
+
|
| 619 |
+
if config["USE_WANDB"]:
|
| 620 |
+
wandb.init(
|
| 621 |
+
project=config["WANDB_PROJECT"],
|
| 622 |
+
entity=config["WANDB_ENTITY"],
|
| 623 |
+
config=config,
|
| 624 |
+
name=config["ENV_NAME"]
|
| 625 |
+
+ "-"
|
| 626 |
+
+ str(int(config["TOTAL_TIMESTEPS"] // 1e6))
|
| 627 |
+
+ "M",
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
rng = jax.random.PRNGKey(config["SEED"])
|
| 631 |
+
rngs = jax.random.split(rng, config["NUM_REPEATS"])
|
| 632 |
+
|
| 633 |
+
train_jit = jax.jit(make_train(config))
|
| 634 |
+
train_vmap = jax.vmap(train_jit)
|
| 635 |
+
|
| 636 |
+
t0 = time.time()
|
| 637 |
+
out = train_vmap(rngs)
|
| 638 |
+
t1 = time.time()
|
| 639 |
+
print("Time to run experiment", t1 - t0)
|
| 640 |
+
print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
|
| 641 |
+
|
| 642 |
+
if config["USE_WANDB"]:
|
| 643 |
+
|
| 644 |
+
def _save_network(rs_index, dir_name):
|
| 645 |
+
train_states = out["runner_state"][rs_index]
|
| 646 |
+
train_state = jax.tree.map(lambda x: x[0], train_states)
|
| 647 |
+
|
| 648 |
+
path = os.path.join(wandb.run.dir, dir_name)
|
| 649 |
+
options = ocp.CheckpointManagerOptions(max_to_keep=1)
|
| 650 |
+
|
| 651 |
+
with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
|
| 652 |
+
checkpoint_manager.save(
|
| 653 |
+
int(config["TOTAL_TIMESTEPS"]),
|
| 654 |
+
args=ocp.args.StandardSave(train_state)
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
print(f"saved runner state to {path}")
|
| 658 |
+
|
| 659 |
+
if config["SAVE_POLICY"]:
|
| 660 |
+
_save_network(0, "policies")
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
if __name__ == "__main__":
|
| 664 |
+
parser = argparse.ArgumentParser()
|
| 665 |
+
parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
|
| 666 |
+
parser.add_argument(
|
| 667 |
+
"--num_envs",
|
| 668 |
+
type=int,
|
| 669 |
+
default=1024,
|
| 670 |
+
)
|
| 671 |
+
parser.add_argument(
|
| 672 |
+
"--total_timesteps", type=lambda x: int(float(x)), default=1e9
|
| 673 |
+
) # Allow scientific notation
|
| 674 |
+
parser.add_argument("--lr", type=float, default=2e-4)
|
| 675 |
+
parser.add_argument("--num_steps", type=int, default=64)
|
| 676 |
+
parser.add_argument("--update_epochs", type=int, default=4)
|
| 677 |
+
parser.add_argument("--num_minibatches", type=int, default=8)
|
| 678 |
+
parser.add_argument("--gamma", type=float, default=0.99)
|
| 679 |
+
parser.add_argument("--gae_lambda", type=float, default=0.8)
|
| 680 |
+
parser.add_argument("--clip_eps", type=float, default=0.2)
|
| 681 |
+
parser.add_argument("--ent_coef", type=float, default=0.01)
|
| 682 |
+
parser.add_argument("--vf_coef", type=float, default=0.5)
|
| 683 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 684 |
+
parser.add_argument("--activation", type=str, default="tanh")
|
| 685 |
+
parser.add_argument(
|
| 686 |
+
"--anneal_lr", action=argparse.BooleanOptionalAction, default=True
|
| 687 |
+
)
|
| 688 |
+
parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
|
| 689 |
+
parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
|
| 690 |
+
parser.add_argument("--seed", type=int)
|
| 691 |
+
parser.add_argument(
|
| 692 |
+
"--use_wandb", action=argparse.BooleanOptionalAction, default=True
|
| 693 |
+
)
|
| 694 |
+
parser.add_argument("--save_policy", action="store_true")
|
| 695 |
+
parser.add_argument("--num_repeats", type=int, default=1)
|
| 696 |
+
parser.add_argument("--layer_size", type=int, default=512)
|
| 697 |
+
parser.add_argument("--wandb_project", type=str)
|
| 698 |
+
parser.add_argument("--wandb_entity", type=str)
|
| 699 |
+
parser.add_argument(
|
| 700 |
+
"--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
|
| 701 |
+
)
|
| 702 |
+
parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
|
| 703 |
+
|
| 704 |
+
# EXPLORATION
|
| 705 |
+
parser.add_argument("--exploration_update_epochs", type=int, default=4)
|
| 706 |
+
# ICM
|
| 707 |
+
parser.add_argument("--icm_reward_coeff", type=float, default=1.0)
|
| 708 |
+
parser.add_argument("--train_icm", action="store_true")
|
| 709 |
+
parser.add_argument("--icm_lr", type=float, default=3e-4)
|
| 710 |
+
parser.add_argument("--icm_forward_loss_coef", type=float, default=1.0)
|
| 711 |
+
parser.add_argument("--icm_inverse_loss_coef", type=float, default=1.0)
|
| 712 |
+
parser.add_argument("--icm_layer_size", type=int, default=256)
|
| 713 |
+
parser.add_argument("--icm_latent_size", type=int, default=32)
|
| 714 |
+
# E3B
|
| 715 |
+
parser.add_argument("--e3b_reward_coeff", type=float, default=1.0)
|
| 716 |
+
parser.add_argument("--use_e3b", action="store_true")
|
| 717 |
+
parser.add_argument("--e3b_lambda", type=float, default=0.1)
|
| 718 |
+
|
| 719 |
+
args, rest_args = parser.parse_known_args(sys.argv[1:])
|
| 720 |
+
if rest_args:
|
| 721 |
+
raise ValueError(f"Unknown args {rest_args}")
|
| 722 |
+
|
| 723 |
+
if args.use_e3b:
|
| 724 |
+
assert args.train_icm
|
| 725 |
+
assert args.icm_reward_coeff == 0
|
| 726 |
+
if args.seed is None:
|
| 727 |
+
args.seed = np.random.randint(2**31)
|
| 728 |
+
|
| 729 |
+
if args.jit:
|
| 730 |
+
run_ppo(args)
|
| 731 |
+
else:
|
| 732 |
+
with jax.disable_jit():
|
| 733 |
+
run_ppo(args)
|
Craftax_Baselines/ppo_rnd.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import jax
|
| 7 |
+
import jax.numpy as jnp
|
| 8 |
+
import numpy as np
|
| 9 |
+
import optax
|
| 10 |
+
from craftax.craftax_env import make_craftax_env_from_name
|
| 11 |
+
|
| 12 |
+
import wandb
|
| 13 |
+
from typing import NamedTuple
|
| 14 |
+
|
| 15 |
+
from flax.training.train_state import TrainState
|
| 16 |
+
import orbax.checkpoint as ocp
|
| 17 |
+
|
| 18 |
+
from logz.batch_logging import batch_log, create_log_dict
|
| 19 |
+
from wrappers import (
|
| 20 |
+
LogWrapper,
|
| 21 |
+
OptimisticResetVecEnvWrapper,
|
| 22 |
+
AutoResetEnvWrapper,
|
| 23 |
+
BatchEnvWrapper,
|
| 24 |
+
)
|
| 25 |
+
from models.rnd import RNDNetwork, ActorCriticRND
|
| 26 |
+
|
| 27 |
+
# Code adapted from the original implementation made by Chris Lu
|
| 28 |
+
# Original code located at https://github.com/luchris429/purejaxrl
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Transition(NamedTuple):
|
| 32 |
+
done: jnp.ndarray
|
| 33 |
+
action: jnp.ndarray
|
| 34 |
+
value_e: jnp.ndarray
|
| 35 |
+
value_i: jnp.ndarray
|
| 36 |
+
reward_e: jnp.ndarray
|
| 37 |
+
reward_i: jnp.ndarray
|
| 38 |
+
reward: jnp.ndarray
|
| 39 |
+
log_prob: jnp.ndarray
|
| 40 |
+
obs: jnp.ndarray
|
| 41 |
+
next_obs: jnp.ndarray
|
| 42 |
+
info: jnp.ndarray
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def make_train(config):
|
| 46 |
+
config["NUM_UPDATES"] = (
|
| 47 |
+
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
|
| 48 |
+
)
|
| 49 |
+
config["MINIBATCH_SIZE"] = (
|
| 50 |
+
config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
env = make_craftax_env_from_name(
|
| 54 |
+
config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
|
| 55 |
+
)
|
| 56 |
+
env_params = env.default_params
|
| 57 |
+
|
| 58 |
+
env = LogWrapper(env)
|
| 59 |
+
if config["USE_OPTIMISTIC_RESETS"]:
|
| 60 |
+
env = OptimisticResetVecEnvWrapper(
|
| 61 |
+
env,
|
| 62 |
+
num_envs=config["NUM_ENVS"],
|
| 63 |
+
reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
env = AutoResetEnvWrapper(env)
|
| 67 |
+
env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
|
| 68 |
+
|
| 69 |
+
def linear_schedule(count):
|
| 70 |
+
frac = (
|
| 71 |
+
1.0
|
| 72 |
+
- (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
|
| 73 |
+
/ config["NUM_UPDATES"]
|
| 74 |
+
)
|
| 75 |
+
return config["LR"] * frac
|
| 76 |
+
|
| 77 |
+
def train(rng):
|
| 78 |
+
# INIT NETWORK
|
| 79 |
+
if "Symbolic" in config["ENV_NAME"]:
|
| 80 |
+
network = ActorCriticRND(
|
| 81 |
+
env.action_space(env_params).n, config["LAYER_SIZE"]
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError
|
| 85 |
+
# network = ActorCriticConv(
|
| 86 |
+
# env.action_space(env_params).n, config["LAYER_SIZE"]
|
| 87 |
+
# )
|
| 88 |
+
|
| 89 |
+
rng, _rng = jax.random.split(rng)
|
| 90 |
+
init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
|
| 91 |
+
network_params = network.init(_rng, init_x)
|
| 92 |
+
if config["ANNEAL_LR"]:
|
| 93 |
+
tx = optax.chain(
|
| 94 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 95 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
tx = optax.chain(
|
| 99 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 100 |
+
optax.adam(config["LR"], eps=1e-5),
|
| 101 |
+
)
|
| 102 |
+
train_state = TrainState.create(
|
| 103 |
+
apply_fn=network.apply,
|
| 104 |
+
params=network_params,
|
| 105 |
+
tx=tx,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Exploration state
|
| 109 |
+
ex_state = {
|
| 110 |
+
"rnd_model": None,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
if config["USE_RND"]:
|
| 114 |
+
obs_shape = env.observation_space(env_params).shape
|
| 115 |
+
assert len(obs_shape) == 1, "Only configured for 1D observations"
|
| 116 |
+
obs_shape = obs_shape[0]
|
| 117 |
+
|
| 118 |
+
# Random network
|
| 119 |
+
rnd_random_network = RNDNetwork(
|
| 120 |
+
num_layers=3,
|
| 121 |
+
output_dim=config["RND_OUTPUT_SIZE"],
|
| 122 |
+
layer_size=config["RND_LAYER_SIZE"],
|
| 123 |
+
)
|
| 124 |
+
rng, _rng = jax.random.split(rng)
|
| 125 |
+
rnd_random_network_params = rnd_random_network.init(
|
| 126 |
+
_rng, jnp.zeros((1, obs_shape))
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Distillation Network
|
| 130 |
+
rnd_distillation_network = RNDNetwork(
|
| 131 |
+
num_layers=3,
|
| 132 |
+
output_dim=config["RND_OUTPUT_SIZE"],
|
| 133 |
+
layer_size=config["RND_LAYER_SIZE"],
|
| 134 |
+
)
|
| 135 |
+
rng, _rng = jax.random.split(rng)
|
| 136 |
+
rnd_distillation_network_params = rnd_distillation_network.init(
|
| 137 |
+
_rng, jnp.zeros((1, obs_shape))
|
| 138 |
+
)
|
| 139 |
+
tx = optax.chain(
|
| 140 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 141 |
+
optax.adam(config["RND_LR"], eps=1e-5),
|
| 142 |
+
)
|
| 143 |
+
ex_state["rnd_distillation_network"] = TrainState.create(
|
| 144 |
+
apply_fn=rnd_distillation_network.apply,
|
| 145 |
+
params=rnd_distillation_network_params,
|
| 146 |
+
tx=tx,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# INIT ENV
|
| 150 |
+
rng, _rng = jax.random.split(rng)
|
| 151 |
+
obsv, env_state = env.reset(_rng, env_params)
|
| 152 |
+
|
| 153 |
+
# TRAIN LOOP
|
| 154 |
+
def _update_step(runner_state, unused):
|
| 155 |
+
# COLLECT TRAJECTORIES
|
| 156 |
+
def _env_step(runner_state, unused):
|
| 157 |
+
(
|
| 158 |
+
train_state,
|
| 159 |
+
env_state,
|
| 160 |
+
last_obs,
|
| 161 |
+
ex_state,
|
| 162 |
+
rng,
|
| 163 |
+
update_step,
|
| 164 |
+
) = runner_state
|
| 165 |
+
|
| 166 |
+
# SELECT ACTION
|
| 167 |
+
rng, _rng = jax.random.split(rng)
|
| 168 |
+
pi, value_e, value_i = network.apply(train_state.params, last_obs)
|
| 169 |
+
action = pi.sample(seed=_rng)
|
| 170 |
+
log_prob = pi.log_prob(action)
|
| 171 |
+
|
| 172 |
+
# STEP ENV
|
| 173 |
+
rng, _rng = jax.random.split(rng)
|
| 174 |
+
obsv, env_state, reward_e, done, info = env.step(
|
| 175 |
+
_rng, env_state, action, env_params
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
reward_i = jnp.zeros(config["NUM_ENVS"])
|
| 179 |
+
|
| 180 |
+
if config["USE_RND"]:
|
| 181 |
+
random_pred = rnd_random_network.apply(
|
| 182 |
+
rnd_random_network_params, obsv
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
distill_pred = ex_state["rnd_distillation_network"].apply_fn(
|
| 186 |
+
ex_state["rnd_distillation_network"].params, obsv
|
| 187 |
+
)
|
| 188 |
+
error = (random_pred - distill_pred) * (1 - done[:, None])
|
| 189 |
+
mse = jnp.square(error).mean(axis=-1)
|
| 190 |
+
|
| 191 |
+
reward_i = mse * config["RND_REWARD_COEFF"]
|
| 192 |
+
|
| 193 |
+
reward = reward_e + reward_i
|
| 194 |
+
|
| 195 |
+
transition = Transition(
|
| 196 |
+
done=done,
|
| 197 |
+
action=action,
|
| 198 |
+
value_e=value_e,
|
| 199 |
+
value_i=value_i,
|
| 200 |
+
reward=reward,
|
| 201 |
+
reward_i=reward_i,
|
| 202 |
+
reward_e=reward_e,
|
| 203 |
+
log_prob=log_prob,
|
| 204 |
+
obs=last_obs,
|
| 205 |
+
next_obs=obsv,
|
| 206 |
+
info=info,
|
| 207 |
+
)
|
| 208 |
+
runner_state = (
|
| 209 |
+
train_state,
|
| 210 |
+
env_state,
|
| 211 |
+
obsv,
|
| 212 |
+
ex_state,
|
| 213 |
+
rng,
|
| 214 |
+
update_step,
|
| 215 |
+
)
|
| 216 |
+
return runner_state, transition
|
| 217 |
+
|
| 218 |
+
runner_state, traj_batch = jax.lax.scan(
|
| 219 |
+
_env_step, runner_state, None, config["NUM_STEPS"]
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# CALCULATE ADVANTAGE
|
| 223 |
+
(
|
| 224 |
+
train_state,
|
| 225 |
+
env_state,
|
| 226 |
+
last_obs,
|
| 227 |
+
ex_state,
|
| 228 |
+
rng,
|
| 229 |
+
update_step,
|
| 230 |
+
) = runner_state
|
| 231 |
+
_, last_val_e, last_val_i = network.apply(train_state.params, last_obs)
|
| 232 |
+
|
| 233 |
+
def _calculate_gae(traj_batch, last_val, is_extrinsic):
|
| 234 |
+
def _get_advantages(gae_and_next_value, transition):
|
| 235 |
+
gae, next_value, is_extrinsic = gae_and_next_value
|
| 236 |
+
done, value, reward = (
|
| 237 |
+
transition.done,
|
| 238 |
+
jax.lax.select(
|
| 239 |
+
is_extrinsic, transition.value_e, transition.value_i
|
| 240 |
+
),
|
| 241 |
+
jax.lax.select(
|
| 242 |
+
is_extrinsic, transition.reward_e, transition.reward_i
|
| 243 |
+
),
|
| 244 |
+
)
|
| 245 |
+
done = jnp.logical_and(
|
| 246 |
+
done, jnp.logical_or(config["RND_IS_EPISODIC"], is_extrinsic)
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
delta = reward + config["GAMMA"] * next_value * (1 - done) - value
|
| 250 |
+
gae = (
|
| 251 |
+
delta
|
| 252 |
+
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
|
| 253 |
+
)
|
| 254 |
+
return (gae, value, is_extrinsic), gae
|
| 255 |
+
|
| 256 |
+
_, advantages = jax.lax.scan(
|
| 257 |
+
_get_advantages,
|
| 258 |
+
(jnp.zeros_like(last_val), last_val, is_extrinsic),
|
| 259 |
+
traj_batch,
|
| 260 |
+
reverse=True,
|
| 261 |
+
unroll=16,
|
| 262 |
+
)
|
| 263 |
+
return advantages, advantages + jax.lax.select(
|
| 264 |
+
is_extrinsic, traj_batch.value_e, traj_batch.value_i
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
advantages_e, targets_e = _calculate_gae(traj_batch, last_val_e, True)
|
| 268 |
+
advantages_i, targets_i = _calculate_gae(traj_batch, last_val_i, False)
|
| 269 |
+
|
| 270 |
+
# UPDATE NETWORK
|
| 271 |
+
def _update_epoch(update_state, unused):
|
| 272 |
+
def _update_minbatch(train_state, batch_info):
|
| 273 |
+
(
|
| 274 |
+
traj_batch,
|
| 275 |
+
advantages_e,
|
| 276 |
+
targets_e,
|
| 277 |
+
advantages_i,
|
| 278 |
+
targets_i,
|
| 279 |
+
) = batch_info
|
| 280 |
+
|
| 281 |
+
# Policy/value network
|
| 282 |
+
def _loss_fn(
|
| 283 |
+
params, traj_batch, gae_e, targets_e, gae_i, targets_i
|
| 284 |
+
):
|
| 285 |
+
# RERUN NETWORK
|
| 286 |
+
pi, value_e, value_i = network.apply(params, traj_batch.obs)
|
| 287 |
+
log_prob = pi.log_prob(traj_batch.action)
|
| 288 |
+
|
| 289 |
+
# CALCULATE EXTRINSIC VALUE LOSS
|
| 290 |
+
value_pred_clipped_e = traj_batch.value_e + (
|
| 291 |
+
value_e - traj_batch.value_e
|
| 292 |
+
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
|
| 293 |
+
value_losses_e = jnp.square(value_e - targets_e)
|
| 294 |
+
value_losses_clipped_e = jnp.square(
|
| 295 |
+
value_pred_clipped_e - targets_e
|
| 296 |
+
)
|
| 297 |
+
value_loss_e = (
|
| 298 |
+
0.5
|
| 299 |
+
* jnp.maximum(value_losses_e, value_losses_clipped_e).mean()
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# CALCULATE INTRINSIC VALUE LOSS
|
| 303 |
+
value_pred_clipped_i = traj_batch.value_i + (
|
| 304 |
+
value_i - traj_batch.value_i
|
| 305 |
+
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
|
| 306 |
+
value_losses_i = jnp.square(value_i - targets_i)
|
| 307 |
+
value_losses_clipped_i = jnp.square(
|
| 308 |
+
value_pred_clipped_i - targets_i
|
| 309 |
+
)
|
| 310 |
+
value_loss_i = (
|
| 311 |
+
0.5
|
| 312 |
+
* jnp.maximum(value_losses_i, value_losses_clipped_i).mean()
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# CALCULATE ACTOR LOSS
|
| 316 |
+
gae = gae_e
|
| 317 |
+
if config["USE_RND"]:
|
| 318 |
+
gae += gae_i * config["RND_GAE_COEFF"]
|
| 319 |
+
ratio = jnp.exp(log_prob - traj_batch.log_prob)
|
| 320 |
+
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
|
| 321 |
+
loss_actor1 = ratio * gae
|
| 322 |
+
loss_actor2 = (
|
| 323 |
+
jnp.clip(
|
| 324 |
+
ratio,
|
| 325 |
+
1.0 - config["CLIP_EPS"],
|
| 326 |
+
1.0 + config["CLIP_EPS"],
|
| 327 |
+
)
|
| 328 |
+
* gae
|
| 329 |
+
)
|
| 330 |
+
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
|
| 331 |
+
loss_actor = loss_actor.mean()
|
| 332 |
+
entropy = pi.entropy().mean()
|
| 333 |
+
|
| 334 |
+
value_loss = value_loss_e
|
| 335 |
+
if config["USE_RND"]:
|
| 336 |
+
value_loss += value_loss_i
|
| 337 |
+
|
| 338 |
+
total_loss = (
|
| 339 |
+
loss_actor
|
| 340 |
+
+ config["VF_COEF"] * value_loss
|
| 341 |
+
- config["ENT_COEF"] * entropy
|
| 342 |
+
)
|
| 343 |
+
return total_loss, (
|
| 344 |
+
value_loss_e,
|
| 345 |
+
value_loss_i,
|
| 346 |
+
loss_actor,
|
| 347 |
+
entropy,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
| 351 |
+
total_loss, grads = grad_fn(
|
| 352 |
+
train_state.params,
|
| 353 |
+
traj_batch,
|
| 354 |
+
advantages_e,
|
| 355 |
+
targets_e,
|
| 356 |
+
advantages_i,
|
| 357 |
+
targets_i,
|
| 358 |
+
)
|
| 359 |
+
train_state = train_state.apply_gradients(grads=grads)
|
| 360 |
+
|
| 361 |
+
losses = (total_loss, 0)
|
| 362 |
+
return train_state, losses
|
| 363 |
+
|
| 364 |
+
(
|
| 365 |
+
train_state,
|
| 366 |
+
traj_batch,
|
| 367 |
+
advantages_e,
|
| 368 |
+
targets_e,
|
| 369 |
+
advantages_i,
|
| 370 |
+
targets_i,
|
| 371 |
+
rng,
|
| 372 |
+
) = update_state
|
| 373 |
+
rng, _rng = jax.random.split(rng)
|
| 374 |
+
batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
|
| 375 |
+
assert (
|
| 376 |
+
batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
|
| 377 |
+
), "batch size must be equal to number of steps * number of envs"
|
| 378 |
+
permutation = jax.random.permutation(_rng, batch_size)
|
| 379 |
+
batch = (
|
| 380 |
+
traj_batch,
|
| 381 |
+
advantages_e,
|
| 382 |
+
targets_e,
|
| 383 |
+
advantages_i,
|
| 384 |
+
targets_i,
|
| 385 |
+
)
|
| 386 |
+
batch = jax.tree.map(
|
| 387 |
+
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
|
| 388 |
+
)
|
| 389 |
+
shuffled_batch = jax.tree.map(
|
| 390 |
+
lambda x: jnp.take(x, permutation, axis=0), batch
|
| 391 |
+
)
|
| 392 |
+
minibatches = jax.tree.map(
|
| 393 |
+
lambda x: jnp.reshape(
|
| 394 |
+
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
|
| 395 |
+
),
|
| 396 |
+
shuffled_batch,
|
| 397 |
+
)
|
| 398 |
+
train_state, losses = jax.lax.scan(
|
| 399 |
+
_update_minbatch, train_state, minibatches
|
| 400 |
+
)
|
| 401 |
+
update_state = (
|
| 402 |
+
train_state,
|
| 403 |
+
traj_batch,
|
| 404 |
+
advantages_e,
|
| 405 |
+
targets_e,
|
| 406 |
+
advantages_i,
|
| 407 |
+
targets_i,
|
| 408 |
+
rng,
|
| 409 |
+
)
|
| 410 |
+
return update_state, losses
|
| 411 |
+
|
| 412 |
+
update_state = (
|
| 413 |
+
train_state,
|
| 414 |
+
traj_batch,
|
| 415 |
+
advantages_e,
|
| 416 |
+
targets_e,
|
| 417 |
+
advantages_i,
|
| 418 |
+
targets_i,
|
| 419 |
+
rng,
|
| 420 |
+
)
|
| 421 |
+
update_state, loss_info = jax.lax.scan(
|
| 422 |
+
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
train_state = update_state[0]
|
| 426 |
+
metric = jax.tree.map(
|
| 427 |
+
lambda x: (x * traj_batch.info["returned_episode"]).sum()
|
| 428 |
+
/ traj_batch.info["returned_episode"].sum(),
|
| 429 |
+
traj_batch.info,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
rng = update_state[-1]
|
| 433 |
+
|
| 434 |
+
# UPDATE EXPLORATION STATE
|
| 435 |
+
def _update_ex_epoch(update_state, unused):
|
| 436 |
+
def _update_ex_minbatch(ex_state, traj_batch):
|
| 437 |
+
rnd_loss = 0
|
| 438 |
+
|
| 439 |
+
if config["USE_RND"]:
|
| 440 |
+
|
| 441 |
+
def _rnd_loss_fn(rnd_distillation_params, traj_batch):
|
| 442 |
+
random_network_out = rnd_random_network.apply(
|
| 443 |
+
rnd_random_network_params, traj_batch.next_obs
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
distillation_network_out = ex_state[
|
| 447 |
+
"rnd_distillation_network"
|
| 448 |
+
].apply_fn(rnd_distillation_params, traj_batch.next_obs)
|
| 449 |
+
|
| 450 |
+
error = (random_network_out - distillation_network_out) * (
|
| 451 |
+
1 - traj_batch.done[:, None]
|
| 452 |
+
)
|
| 453 |
+
return jnp.square(error).mean() * config["RND_LOSS_COEFF"]
|
| 454 |
+
|
| 455 |
+
rnd_grad_fn = jax.value_and_grad(_rnd_loss_fn, has_aux=False)
|
| 456 |
+
rnd_loss, rnd_grad = rnd_grad_fn(
|
| 457 |
+
ex_state["rnd_distillation_network"].params, traj_batch
|
| 458 |
+
)
|
| 459 |
+
ex_state["rnd_distillation_network"] = ex_state[
|
| 460 |
+
"rnd_distillation_network"
|
| 461 |
+
].apply_gradients(grads=rnd_grad)
|
| 462 |
+
|
| 463 |
+
losses = (rnd_loss,)
|
| 464 |
+
return ex_state, losses
|
| 465 |
+
|
| 466 |
+
(ex_state, traj_batch, rng) = update_state
|
| 467 |
+
rng, _rng = jax.random.split(rng)
|
| 468 |
+
batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
|
| 469 |
+
assert (
|
| 470 |
+
batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
|
| 471 |
+
), "batch size must be equal to number of steps * number of envs"
|
| 472 |
+
permutation = jax.random.permutation(_rng, batch_size)
|
| 473 |
+
batch = jax.tree.map(
|
| 474 |
+
lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
|
| 475 |
+
)
|
| 476 |
+
shuffled_batch = jax.tree.map(
|
| 477 |
+
lambda x: jnp.take(x, permutation, axis=0), batch
|
| 478 |
+
)
|
| 479 |
+
minibatches = jax.tree.map(
|
| 480 |
+
lambda x: jnp.reshape(
|
| 481 |
+
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
|
| 482 |
+
),
|
| 483 |
+
shuffled_batch,
|
| 484 |
+
)
|
| 485 |
+
ex_state, losses = jax.lax.scan(
|
| 486 |
+
_update_ex_minbatch, ex_state, minibatches
|
| 487 |
+
)
|
| 488 |
+
update_state = (ex_state, traj_batch, rng)
|
| 489 |
+
return update_state, losses
|
| 490 |
+
|
| 491 |
+
if config["USE_RND"]:
|
| 492 |
+
ex_update_state = (ex_state, traj_batch, rng)
|
| 493 |
+
ex_update_state, ex_loss = jax.lax.scan(
|
| 494 |
+
_update_ex_epoch,
|
| 495 |
+
ex_update_state,
|
| 496 |
+
None,
|
| 497 |
+
config["EXPLORATION_UPDATE_EPOCHS"],
|
| 498 |
+
)
|
| 499 |
+
metric["rnd_loss"] = ex_loss[0].mean()
|
| 500 |
+
metric["reward_i"] = traj_batch.reward_i.mean()
|
| 501 |
+
metric["reward_e"] = traj_batch.reward_e.mean()
|
| 502 |
+
|
| 503 |
+
ex_state = ex_update_state[0]
|
| 504 |
+
rng = ex_update_state[-1]
|
| 505 |
+
|
| 506 |
+
# wandb logging
|
| 507 |
+
if config["DEBUG"] and config["USE_WANDB"]:
|
| 508 |
+
|
| 509 |
+
def callback(
|
| 510 |
+
metric, update_step
|
| 511 |
+
): # , loss_info, traj_batch, ex_state, advantages_i, targets_i):
|
| 512 |
+
to_log = create_log_dict(metric, config)
|
| 513 |
+
batch_log(update_step, to_log, config)
|
| 514 |
+
|
| 515 |
+
jax.debug.callback(
|
| 516 |
+
callback,
|
| 517 |
+
metric,
|
| 518 |
+
update_step,
|
| 519 |
+
# loss_info, traj_batch, ex_state, advantages_i, targets_i
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
runner_state = (
|
| 523 |
+
train_state,
|
| 524 |
+
env_state,
|
| 525 |
+
last_obs,
|
| 526 |
+
ex_state,
|
| 527 |
+
rng,
|
| 528 |
+
update_step + 1,
|
| 529 |
+
)
|
| 530 |
+
return runner_state, metric
|
| 531 |
+
|
| 532 |
+
rng, _rng = jax.random.split(rng)
|
| 533 |
+
runner_state = (
|
| 534 |
+
train_state,
|
| 535 |
+
env_state,
|
| 536 |
+
obsv,
|
| 537 |
+
ex_state,
|
| 538 |
+
_rng,
|
| 539 |
+
0,
|
| 540 |
+
)
|
| 541 |
+
runner_state, metric = jax.lax.scan(
|
| 542 |
+
_update_step, runner_state, None, config["NUM_UPDATES"]
|
| 543 |
+
)
|
| 544 |
+
return {"runner_state": runner_state} # , "info": metric}
|
| 545 |
+
|
| 546 |
+
return train
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def run_ppo(config):
|
| 550 |
+
config = {k.upper(): v for k, v in config.__dict__.items()}
|
| 551 |
+
|
| 552 |
+
if config["USE_WANDB"]:
|
| 553 |
+
wandb.init(
|
| 554 |
+
project=config["WANDB_PROJECT"],
|
| 555 |
+
entity=config["WANDB_ENTITY"],
|
| 556 |
+
config=config,
|
| 557 |
+
name=config["ENV_NAME"]
|
| 558 |
+
+ "-PPO_RND-"
|
| 559 |
+
+ str(int(config["TOTAL_TIMESTEPS"] // 1e6))
|
| 560 |
+
+ "M",
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
rng = jax.random.PRNGKey(config["SEED"])
|
| 564 |
+
rngs = jax.random.split(rng, config["NUM_REPEATS"])
|
| 565 |
+
|
| 566 |
+
train_jit = jax.jit(make_train(config))
|
| 567 |
+
train_vmap = jax.vmap(train_jit)
|
| 568 |
+
|
| 569 |
+
t0 = time.time()
|
| 570 |
+
out = train_vmap(rngs)
|
| 571 |
+
t1 = time.time()
|
| 572 |
+
print("Time to run experiment", t1 - t0)
|
| 573 |
+
print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
|
| 574 |
+
# t1 = time.time()
|
| 575 |
+
# out = train_vmap(rngs)
|
| 576 |
+
# t2 = time.time()
|
| 577 |
+
# print("t2", t2 - t1)
|
| 578 |
+
# print("SPS2: ", config["TOTAL_TIMESTEPS"] / (t2 - t1))
|
| 579 |
+
|
| 580 |
+
if config["USE_WANDB"]:
|
| 581 |
+
# if config["DEBUG"] == "end":
|
| 582 |
+
# info = out["info"]
|
| 583 |
+
# for update in range(info["timestep"].shape[1]):
|
| 584 |
+
# if update % 10 == 0:
|
| 585 |
+
# for repeat in range(info["timestep"].shape[0]):
|
| 586 |
+
# update_info = jax.tree.map(lambda x: x[repeat, update], info)
|
| 587 |
+
# to_log = create_log_dict(update_info)
|
| 588 |
+
# batch_log(update, to_log, config)
|
| 589 |
+
#
|
| 590 |
+
# t2 = time.time()
|
| 591 |
+
# print("Time to log to wandb", t2 - t1)
|
| 592 |
+
|
| 593 |
+
def _save_network(rs_index, dir_name):
|
| 594 |
+
train_states = out["runner_state"][rs_index]
|
| 595 |
+
train_state = jax.tree.map(lambda x: x[0], train_states)
|
| 596 |
+
|
| 597 |
+
path = os.path.join(wandb.run.dir, dir_name)
|
| 598 |
+
options = ocp.CheckpointManagerOptions(max_to_keep=1)
|
| 599 |
+
|
| 600 |
+
with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
|
| 601 |
+
checkpoint_manager.save(
|
| 602 |
+
int(config["TOTAL_TIMESTEPS"]),
|
| 603 |
+
args=ocp.args.StandardSave(train_state)
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
print(f"saved runner state to {path}")
|
| 607 |
+
|
| 608 |
+
if config["SAVE_POLICY"]:
|
| 609 |
+
_save_network(0, "policies")
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
if __name__ == "__main__":
|
| 613 |
+
parser = argparse.ArgumentParser()
|
| 614 |
+
parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
|
| 615 |
+
parser.add_argument(
|
| 616 |
+
"--num_envs",
|
| 617 |
+
type=int,
|
| 618 |
+
default=1024,
|
| 619 |
+
)
|
| 620 |
+
parser.add_argument(
|
| 621 |
+
"--total_timesteps", type=lambda x: int(float(x)), default=1e9
|
| 622 |
+
) # Allow scientific notation
|
| 623 |
+
parser.add_argument("--lr", type=float, default=2e-4)
|
| 624 |
+
parser.add_argument("--num_steps", type=int, default=64)
|
| 625 |
+
parser.add_argument("--update_epochs", type=int, default=4)
|
| 626 |
+
parser.add_argument("--num_minibatches", type=int, default=8)
|
| 627 |
+
parser.add_argument("--gamma", type=float, default=0.99)
|
| 628 |
+
parser.add_argument("--gae_lambda", type=float, default=0.8)
|
| 629 |
+
parser.add_argument("--clip_eps", type=float, default=0.2)
|
| 630 |
+
parser.add_argument("--ent_coef", type=float, default=0.01)
|
| 631 |
+
parser.add_argument("--vf_coef", type=float, default=0.5)
|
| 632 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 633 |
+
parser.add_argument("--activation", type=str, default="tanh")
|
| 634 |
+
parser.add_argument(
|
| 635 |
+
"--anneal_lr", action=argparse.BooleanOptionalAction, default=True
|
| 636 |
+
)
|
| 637 |
+
parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
|
| 638 |
+
parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
|
| 639 |
+
parser.add_argument("--seed", type=int)
|
| 640 |
+
parser.add_argument(
|
| 641 |
+
"--use_wandb", action=argparse.BooleanOptionalAction, default=True
|
| 642 |
+
)
|
| 643 |
+
parser.add_argument("--save_policy", action="store_true")
|
| 644 |
+
parser.add_argument("--num_repeats", type=int, default=1)
|
| 645 |
+
parser.add_argument("--layer_size", type=int, default=512)
|
| 646 |
+
parser.add_argument("--wandb_project", type=str)
|
| 647 |
+
parser.add_argument("--wandb_entity", type=str)
|
| 648 |
+
parser.add_argument(
|
| 649 |
+
"--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
|
| 650 |
+
)
|
| 651 |
+
parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
|
| 652 |
+
|
| 653 |
+
# EXPLORATION
|
| 654 |
+
parser.add_argument("--exploration_update_epochs", type=int, default=1)
|
| 655 |
+
# RND
|
| 656 |
+
parser.add_argument(
|
| 657 |
+
"--use_rnd", action=argparse.BooleanOptionalAction, default=True
|
| 658 |
+
)
|
| 659 |
+
parser.add_argument("--rnd_layer_size", type=int, default=256)
|
| 660 |
+
parser.add_argument("--rnd_output_size", type=int, default=512)
|
| 661 |
+
parser.add_argument("--rnd_lr", type=float, default=3e-4)
|
| 662 |
+
parser.add_argument("--rnd_reward_coeff", type=float, default=1.0)
|
| 663 |
+
parser.add_argument("--rnd_loss_coeff", type=float, default=0.01)
|
| 664 |
+
parser.add_argument("--rnd_gae_coeff", type=float, default=0.01)
|
| 665 |
+
parser.add_argument(
|
| 666 |
+
"--rnd_is_episodic", action=argparse.BooleanOptionalAction, default=False
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
args, rest_args = parser.parse_known_args(sys.argv[1:])
|
| 670 |
+
if rest_args:
|
| 671 |
+
raise ValueError(f"Unknown args {rest_args}")
|
| 672 |
+
|
| 673 |
+
if args.seed is None:
|
| 674 |
+
args.seed = np.random.randint(2**31)
|
| 675 |
+
|
| 676 |
+
if args.jit:
|
| 677 |
+
run_ppo(args)
|
| 678 |
+
else:
|
| 679 |
+
with jax.disable_jit():
|
| 680 |
+
run_ppo(args)
|
Craftax_Baselines/ppo_rnn.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import jax
|
| 6 |
+
import jax.numpy as jnp
|
| 7 |
+
import flax.linen as nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
import optax
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import orbax.checkpoint as ocp
|
| 13 |
+
|
| 14 |
+
import wandb
|
| 15 |
+
from flax.linen.initializers import constant, orthogonal
|
| 16 |
+
from typing import NamedTuple, Dict
|
| 17 |
+
from flax.training.train_state import TrainState
|
| 18 |
+
import distrax
|
| 19 |
+
import functools
|
| 20 |
+
|
| 21 |
+
from wrappers import (
|
| 22 |
+
LogWrapper,
|
| 23 |
+
OptimisticResetVecEnvWrapper,
|
| 24 |
+
BatchEnvWrapper,
|
| 25 |
+
AutoResetEnvWrapper,
|
| 26 |
+
)
|
| 27 |
+
from logz.batch_logging import create_log_dict, batch_log
|
| 28 |
+
|
| 29 |
+
from craftax.craftax_env import make_craftax_env_from_name
|
| 30 |
+
|
| 31 |
+
# Code adapted from the original implementation made by Chris Lu
|
| 32 |
+
# Original code located at https://github.com/luchris429/purejaxrl
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ScannedRNN(nn.Module):
|
| 36 |
+
@functools.partial(
|
| 37 |
+
nn.scan,
|
| 38 |
+
variable_broadcast="params",
|
| 39 |
+
in_axes=0,
|
| 40 |
+
out_axes=0,
|
| 41 |
+
split_rngs={"params": False},
|
| 42 |
+
)
|
| 43 |
+
@nn.compact
|
| 44 |
+
def __call__(self, carry, x):
|
| 45 |
+
"""Applies the module."""
|
| 46 |
+
rnn_state = carry
|
| 47 |
+
ins, resets = x
|
| 48 |
+
rnn_state = jnp.where(
|
| 49 |
+
resets[:, np.newaxis],
|
| 50 |
+
self.initialize_carry(ins.shape[0], ins.shape[1]),
|
| 51 |
+
rnn_state,
|
| 52 |
+
)
|
| 53 |
+
new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
|
| 54 |
+
return new_rnn_state, y
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def initialize_carry(batch_size, hidden_size):
|
| 58 |
+
# Use a dummy key since the default state init fn is just zeros.
|
| 59 |
+
cell = nn.GRUCell(features=hidden_size)
|
| 60 |
+
return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ActorCriticRNN(nn.Module):
|
| 64 |
+
action_dim: int
|
| 65 |
+
config: Dict
|
| 66 |
+
|
| 67 |
+
@nn.compact
|
| 68 |
+
def __call__(self, hidden, x):
|
| 69 |
+
obs, dones = x
|
| 70 |
+
embedding = nn.Dense(
|
| 71 |
+
self.config["LAYER_SIZE"],
|
| 72 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
| 73 |
+
bias_init=constant(0.0),
|
| 74 |
+
)(obs)
|
| 75 |
+
embedding = nn.relu(embedding)
|
| 76 |
+
|
| 77 |
+
rnn_in = (embedding, dones)
|
| 78 |
+
hidden, embedding = ScannedRNN()(hidden, rnn_in)
|
| 79 |
+
|
| 80 |
+
actor_mean = nn.Dense(
|
| 81 |
+
self.config["LAYER_SIZE"],
|
| 82 |
+
kernel_init=orthogonal(2),
|
| 83 |
+
bias_init=constant(0.0),
|
| 84 |
+
)(embedding)
|
| 85 |
+
actor_mean = nn.relu(actor_mean)
|
| 86 |
+
actor_mean = nn.Dense(
|
| 87 |
+
self.config["LAYER_SIZE"],
|
| 88 |
+
kernel_init=orthogonal(2),
|
| 89 |
+
bias_init=constant(0.0),
|
| 90 |
+
)(actor_mean)
|
| 91 |
+
actor_mean = nn.relu(actor_mean)
|
| 92 |
+
actor_mean = nn.Dense(
|
| 93 |
+
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
| 94 |
+
)(actor_mean)
|
| 95 |
+
|
| 96 |
+
pi = distrax.Categorical(logits=actor_mean)
|
| 97 |
+
|
| 98 |
+
critic = nn.Dense(
|
| 99 |
+
self.config["LAYER_SIZE"],
|
| 100 |
+
kernel_init=orthogonal(2),
|
| 101 |
+
bias_init=constant(0.0),
|
| 102 |
+
)(embedding)
|
| 103 |
+
critic = nn.relu(critic)
|
| 104 |
+
critic = nn.Dense(
|
| 105 |
+
self.config["LAYER_SIZE"],
|
| 106 |
+
kernel_init=orthogonal(2),
|
| 107 |
+
bias_init=constant(0.0),
|
| 108 |
+
)(critic)
|
| 109 |
+
critic = nn.relu(critic)
|
| 110 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
|
| 111 |
+
critic
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
return hidden, pi, jnp.squeeze(critic, axis=-1)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Transition(NamedTuple):
|
| 118 |
+
done: jnp.ndarray
|
| 119 |
+
action: jnp.ndarray
|
| 120 |
+
value: jnp.ndarray
|
| 121 |
+
reward: jnp.ndarray
|
| 122 |
+
log_prob: jnp.ndarray
|
| 123 |
+
obs: jnp.ndarray
|
| 124 |
+
info: jnp.ndarray
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def make_train(config):
|
| 128 |
+
config["NUM_UPDATES"] = (
|
| 129 |
+
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
|
| 130 |
+
)
|
| 131 |
+
config["MINIBATCH_SIZE"] = (
|
| 132 |
+
config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Create environment
|
| 136 |
+
env = make_craftax_env_from_name(
|
| 137 |
+
config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
|
| 138 |
+
)
|
| 139 |
+
env_params = env.default_params
|
| 140 |
+
|
| 141 |
+
# Wrap with some extra logging
|
| 142 |
+
env = LogWrapper(env)
|
| 143 |
+
|
| 144 |
+
# Wrap with a batcher, maybe using optimistic resets
|
| 145 |
+
if config["USE_OPTIMISTIC_RESETS"]:
|
| 146 |
+
env = OptimisticResetVecEnvWrapper(
|
| 147 |
+
env,
|
| 148 |
+
num_envs=config["NUM_ENVS"],
|
| 149 |
+
reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
env = AutoResetEnvWrapper(env)
|
| 153 |
+
env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
|
| 154 |
+
|
| 155 |
+
def linear_schedule(count):
|
| 156 |
+
frac = (
|
| 157 |
+
1.0
|
| 158 |
+
- (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
|
| 159 |
+
/ config["NUM_UPDATES"]
|
| 160 |
+
)
|
| 161 |
+
return config["LR"] * frac
|
| 162 |
+
|
| 163 |
+
def train(rng):
|
| 164 |
+
# INIT NETWORK
|
| 165 |
+
network = ActorCriticRNN(env.action_space(env_params).n, config=config)
|
| 166 |
+
rng, _rng = jax.random.split(rng)
|
| 167 |
+
init_x = (
|
| 168 |
+
jnp.zeros(
|
| 169 |
+
(1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
|
| 170 |
+
),
|
| 171 |
+
jnp.zeros((1, config["NUM_ENVS"])),
|
| 172 |
+
)
|
| 173 |
+
init_hstate = ScannedRNN.initialize_carry(
|
| 174 |
+
config["NUM_ENVS"], config["LAYER_SIZE"]
|
| 175 |
+
)
|
| 176 |
+
network_params = network.init(_rng, init_hstate, init_x)
|
| 177 |
+
if config["ANNEAL_LR"]:
|
| 178 |
+
tx = optax.chain(
|
| 179 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 180 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
tx = optax.chain(
|
| 184 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 185 |
+
optax.adam(config["LR"], eps=1e-5),
|
| 186 |
+
)
|
| 187 |
+
train_state = TrainState.create(
|
| 188 |
+
apply_fn=network.apply,
|
| 189 |
+
params=network_params,
|
| 190 |
+
tx=tx,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# INIT ENV
|
| 194 |
+
rng, _rng = jax.random.split(rng)
|
| 195 |
+
obsv, env_state = env.reset(_rng, env_params)
|
| 196 |
+
init_hstate = ScannedRNN.initialize_carry(
|
| 197 |
+
config["NUM_ENVS"], config["LAYER_SIZE"]
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# TRAIN LOOP
|
| 201 |
+
def _update_step(runner_state, unused):
|
| 202 |
+
# COLLECT TRAJECTORIES
|
| 203 |
+
def _env_step(runner_state, unused):
|
| 204 |
+
(
|
| 205 |
+
train_state,
|
| 206 |
+
env_state,
|
| 207 |
+
last_obs,
|
| 208 |
+
last_done,
|
| 209 |
+
hstate,
|
| 210 |
+
rng,
|
| 211 |
+
update_step,
|
| 212 |
+
) = runner_state
|
| 213 |
+
rng, _rng = jax.random.split(rng)
|
| 214 |
+
|
| 215 |
+
# SELECT ACTION
|
| 216 |
+
ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
|
| 217 |
+
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
|
| 218 |
+
action = pi.sample(seed=_rng)
|
| 219 |
+
log_prob = pi.log_prob(action)
|
| 220 |
+
value, action, log_prob = (
|
| 221 |
+
value.squeeze(0),
|
| 222 |
+
action.squeeze(0),
|
| 223 |
+
log_prob.squeeze(0),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# STEP ENV
|
| 227 |
+
rng, _rng = jax.random.split(rng)
|
| 228 |
+
obsv, env_state, reward, done, info = env.step(
|
| 229 |
+
_rng, env_state, action, env_params
|
| 230 |
+
)
|
| 231 |
+
transition = Transition(
|
| 232 |
+
last_done, action, value, reward, log_prob, last_obs, info
|
| 233 |
+
)
|
| 234 |
+
runner_state = (
|
| 235 |
+
train_state,
|
| 236 |
+
env_state,
|
| 237 |
+
obsv,
|
| 238 |
+
done,
|
| 239 |
+
hstate,
|
| 240 |
+
rng,
|
| 241 |
+
update_step,
|
| 242 |
+
)
|
| 243 |
+
return runner_state, transition
|
| 244 |
+
|
| 245 |
+
initial_hstate = runner_state[-3]
|
| 246 |
+
runner_state, traj_batch = jax.lax.scan(
|
| 247 |
+
_env_step, runner_state, None, config["NUM_STEPS"]
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# CALCULATE ADVANTAGE
|
| 251 |
+
(
|
| 252 |
+
train_state,
|
| 253 |
+
env_state,
|
| 254 |
+
last_obs,
|
| 255 |
+
last_done,
|
| 256 |
+
hstate,
|
| 257 |
+
rng,
|
| 258 |
+
update_step,
|
| 259 |
+
) = runner_state
|
| 260 |
+
ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
|
| 261 |
+
_, _, last_val = network.apply(train_state.params, hstate, ac_in)
|
| 262 |
+
last_val = last_val.squeeze(0)
|
| 263 |
+
|
| 264 |
+
def _calculate_gae(traj_batch, last_val, last_done):
|
| 265 |
+
def _get_advantages(carry, transition):
|
| 266 |
+
gae, next_value, next_done = carry
|
| 267 |
+
done, value, reward = (
|
| 268 |
+
transition.done,
|
| 269 |
+
transition.value,
|
| 270 |
+
transition.reward,
|
| 271 |
+
)
|
| 272 |
+
delta = (
|
| 273 |
+
reward + config["GAMMA"] * next_value * (1 - next_done) - value
|
| 274 |
+
)
|
| 275 |
+
gae = (
|
| 276 |
+
delta
|
| 277 |
+
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
|
| 278 |
+
)
|
| 279 |
+
return (gae, value, done), gae
|
| 280 |
+
|
| 281 |
+
_, advantages = jax.lax.scan(
|
| 282 |
+
_get_advantages,
|
| 283 |
+
(jnp.zeros_like(last_val), last_val, last_done),
|
| 284 |
+
traj_batch,
|
| 285 |
+
reverse=True,
|
| 286 |
+
unroll=16,
|
| 287 |
+
)
|
| 288 |
+
return advantages, advantages + traj_batch.value
|
| 289 |
+
|
| 290 |
+
advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
|
| 291 |
+
|
| 292 |
+
# UPDATE NETWORK
|
| 293 |
+
def _update_epoch(update_state, unused):
|
| 294 |
+
def _update_minbatch(train_state, batch_info):
|
| 295 |
+
init_hstate, traj_batch, advantages, targets = batch_info
|
| 296 |
+
|
| 297 |
+
def _loss_fn(params, init_hstate, traj_batch, gae, targets):
|
| 298 |
+
# RERUN NETWORK
|
| 299 |
+
_, pi, value = network.apply(
|
| 300 |
+
params, init_hstate[0], (traj_batch.obs, traj_batch.done)
|
| 301 |
+
)
|
| 302 |
+
log_prob = pi.log_prob(traj_batch.action)
|
| 303 |
+
|
| 304 |
+
# CALCULATE VALUE LOSS
|
| 305 |
+
value_pred_clipped = traj_batch.value + (
|
| 306 |
+
value - traj_batch.value
|
| 307 |
+
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
|
| 308 |
+
value_losses = jnp.square(value - targets)
|
| 309 |
+
value_losses_clipped = jnp.square(value_pred_clipped - targets)
|
| 310 |
+
value_loss = (
|
| 311 |
+
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# CALCULATE ACTOR LOSS
|
| 315 |
+
ratio = jnp.exp(log_prob - traj_batch.log_prob)
|
| 316 |
+
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
|
| 317 |
+
loss_actor1 = ratio * gae
|
| 318 |
+
loss_actor2 = (
|
| 319 |
+
jnp.clip(
|
| 320 |
+
ratio,
|
| 321 |
+
1.0 - config["CLIP_EPS"],
|
| 322 |
+
1.0 + config["CLIP_EPS"],
|
| 323 |
+
)
|
| 324 |
+
* gae
|
| 325 |
+
)
|
| 326 |
+
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
|
| 327 |
+
loss_actor = loss_actor.mean()
|
| 328 |
+
entropy = pi.entropy().mean()
|
| 329 |
+
|
| 330 |
+
total_loss = (
|
| 331 |
+
loss_actor
|
| 332 |
+
+ config["VF_COEF"] * value_loss
|
| 333 |
+
- config["ENT_COEF"] * entropy
|
| 334 |
+
)
|
| 335 |
+
return total_loss, (value_loss, loss_actor, entropy)
|
| 336 |
+
|
| 337 |
+
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
| 338 |
+
total_loss, grads = grad_fn(
|
| 339 |
+
train_state.params, init_hstate, traj_batch, advantages, targets
|
| 340 |
+
)
|
| 341 |
+
train_state = train_state.apply_gradients(grads=grads)
|
| 342 |
+
return train_state, total_loss
|
| 343 |
+
|
| 344 |
+
(
|
| 345 |
+
train_state,
|
| 346 |
+
init_hstate,
|
| 347 |
+
traj_batch,
|
| 348 |
+
advantages,
|
| 349 |
+
targets,
|
| 350 |
+
rng,
|
| 351 |
+
) = update_state
|
| 352 |
+
|
| 353 |
+
rng, _rng = jax.random.split(rng)
|
| 354 |
+
permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
|
| 355 |
+
batch = (init_hstate, traj_batch, advantages, targets)
|
| 356 |
+
|
| 357 |
+
shuffled_batch = jax.tree.map(
|
| 358 |
+
lambda x: jnp.take(x, permutation, axis=1), batch
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
minibatches = jax.tree.map(
|
| 362 |
+
lambda x: jnp.swapaxes(
|
| 363 |
+
jnp.reshape(
|
| 364 |
+
x,
|
| 365 |
+
[x.shape[0], config["NUM_MINIBATCHES"], -1]
|
| 366 |
+
+ list(x.shape[2:]),
|
| 367 |
+
),
|
| 368 |
+
1,
|
| 369 |
+
0,
|
| 370 |
+
),
|
| 371 |
+
shuffled_batch,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
train_state, total_loss = jax.lax.scan(
|
| 375 |
+
_update_minbatch, train_state, minibatches
|
| 376 |
+
)
|
| 377 |
+
update_state = (
|
| 378 |
+
train_state,
|
| 379 |
+
init_hstate,
|
| 380 |
+
traj_batch,
|
| 381 |
+
advantages,
|
| 382 |
+
targets,
|
| 383 |
+
rng,
|
| 384 |
+
)
|
| 385 |
+
return update_state, total_loss
|
| 386 |
+
|
| 387 |
+
init_hstate = initial_hstate[None, :] # TBH
|
| 388 |
+
update_state = (
|
| 389 |
+
train_state,
|
| 390 |
+
init_hstate,
|
| 391 |
+
traj_batch,
|
| 392 |
+
advantages,
|
| 393 |
+
targets,
|
| 394 |
+
rng,
|
| 395 |
+
)
|
| 396 |
+
update_state, loss_info = jax.lax.scan(
|
| 397 |
+
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
|
| 398 |
+
)
|
| 399 |
+
train_state = update_state[0]
|
| 400 |
+
metric = jax.tree.map(
|
| 401 |
+
lambda x: (x * traj_batch.info["returned_episode"]).sum()
|
| 402 |
+
/ traj_batch.info["returned_episode"].sum(),
|
| 403 |
+
traj_batch.info,
|
| 404 |
+
)
|
| 405 |
+
rng = update_state[-1]
|
| 406 |
+
if config["DEBUG"] and config["USE_WANDB"]:
|
| 407 |
+
|
| 408 |
+
def callback(metric, update_step):
|
| 409 |
+
to_log = create_log_dict(metric, config)
|
| 410 |
+
batch_log(update_step, to_log, config)
|
| 411 |
+
|
| 412 |
+
jax.debug.callback(callback, metric, update_step)
|
| 413 |
+
|
| 414 |
+
runner_state = (
|
| 415 |
+
train_state,
|
| 416 |
+
env_state,
|
| 417 |
+
last_obs,
|
| 418 |
+
last_done,
|
| 419 |
+
hstate,
|
| 420 |
+
rng,
|
| 421 |
+
update_step + 1,
|
| 422 |
+
)
|
| 423 |
+
return runner_state, metric
|
| 424 |
+
|
| 425 |
+
rng, _rng = jax.random.split(rng)
|
| 426 |
+
runner_state = (
|
| 427 |
+
train_state,
|
| 428 |
+
env_state,
|
| 429 |
+
obsv,
|
| 430 |
+
jnp.zeros((config["NUM_ENVS"]), dtype=bool),
|
| 431 |
+
init_hstate,
|
| 432 |
+
_rng,
|
| 433 |
+
0,
|
| 434 |
+
)
|
| 435 |
+
runner_state, metric = jax.lax.scan(
|
| 436 |
+
_update_step, runner_state, None, config["NUM_UPDATES"]
|
| 437 |
+
)
|
| 438 |
+
return {"runner_state": runner_state, "metric": metric}
|
| 439 |
+
|
| 440 |
+
return train
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def run_ppo(config):
|
| 444 |
+
config = {k.upper(): v for k, v in config.__dict__.items()}
|
| 445 |
+
|
| 446 |
+
if config["USE_WANDB"]:
|
| 447 |
+
wandb.init(
|
| 448 |
+
project=config["WANDB_PROJECT"],
|
| 449 |
+
entity=config["WANDB_ENTITY"],
|
| 450 |
+
config=config,
|
| 451 |
+
name=config["ENV_NAME"]
|
| 452 |
+
+ "-PPO_RNN-"
|
| 453 |
+
+ str(int(config["TOTAL_TIMESTEPS"] // 1e6))
|
| 454 |
+
+ "M",
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
rng = jax.random.PRNGKey(config["SEED"])
|
| 458 |
+
rngs = jax.random.split(rng, config["NUM_REPEATS"])
|
| 459 |
+
|
| 460 |
+
train_jit = jax.jit(make_train(config))
|
| 461 |
+
train_vmap = jax.vmap(train_jit)
|
| 462 |
+
|
| 463 |
+
t0 = time.time()
|
| 464 |
+
out = train_vmap(rngs)
|
| 465 |
+
t1 = time.time()
|
| 466 |
+
print("Time to run experiment", t1 - t0)
|
| 467 |
+
print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
|
| 468 |
+
|
| 469 |
+
if config["USE_WANDB"]:
|
| 470 |
+
|
| 471 |
+
def _save_network(rs_index, dir_name):
|
| 472 |
+
train_states = out["runner_state"][rs_index]
|
| 473 |
+
train_state = jax.tree.map(lambda x: x[0], train_states)
|
| 474 |
+
|
| 475 |
+
path = os.path.join(wandb.run.dir, dir_name)
|
| 476 |
+
options = ocp.CheckpointManagerOptions(max_to_keep=1)
|
| 477 |
+
|
| 478 |
+
with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
|
| 479 |
+
checkpoint_manager.save(
|
| 480 |
+
int(config["TOTAL_TIMESTEPS"]),
|
| 481 |
+
args=ocp.args.StandardSave(train_state)
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
print(f"saved runner state to {path}")
|
| 485 |
+
|
| 486 |
+
if config["SAVE_POLICY"]:
|
| 487 |
+
_save_network(0, "policies")
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
if __name__ == "__main__":
|
| 491 |
+
parser = argparse.ArgumentParser()
|
| 492 |
+
parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
|
| 493 |
+
parser.add_argument(
|
| 494 |
+
"--num_envs",
|
| 495 |
+
type=int,
|
| 496 |
+
default=1024,
|
| 497 |
+
)
|
| 498 |
+
parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=1e9)
|
| 499 |
+
parser.add_argument("--lr", type=float, default=2e-4)
|
| 500 |
+
parser.add_argument("--num_steps", type=int, default=64)
|
| 501 |
+
parser.add_argument("--update_epochs", type=int, default=4)
|
| 502 |
+
parser.add_argument("--num_minibatches", type=int, default=8)
|
| 503 |
+
parser.add_argument("--gamma", type=float, default=0.99)
|
| 504 |
+
parser.add_argument("--gae_lambda", type=float, default=0.8)
|
| 505 |
+
parser.add_argument("--clip_eps", type=float, default=0.2)
|
| 506 |
+
parser.add_argument("--ent_coef", type=float, default=0.01)
|
| 507 |
+
parser.add_argument("--vf_coef", type=float, default=0.5)
|
| 508 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
| 509 |
+
parser.add_argument("--activation", type=str, default="tanh")
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
"--anneal_lr", action=argparse.BooleanOptionalAction, default=True
|
| 512 |
+
)
|
| 513 |
+
parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
|
| 514 |
+
parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
|
| 515 |
+
parser.add_argument("--seed", type=int, default=np.random.randint(2**31))
|
| 516 |
+
parser.add_argument(
|
| 517 |
+
"--use_wandb", action=argparse.BooleanOptionalAction, default=True
|
| 518 |
+
)
|
| 519 |
+
parser.add_argument(
|
| 520 |
+
"--save_policy", action=argparse.BooleanOptionalAction, default=False
|
| 521 |
+
)
|
| 522 |
+
parser.add_argument("--num_repeats", type=int, default=1)
|
| 523 |
+
parser.add_argument("--layer_size", type=int, default=512)
|
| 524 |
+
parser.add_argument("--wandb_project", type=str)
|
| 525 |
+
parser.add_argument("--wandb_entity", type=str)
|
| 526 |
+
parser.add_argument(
|
| 527 |
+
"--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
|
| 528 |
+
)
|
| 529 |
+
parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
|
| 530 |
+
|
| 531 |
+
args, rest_args = parser.parse_known_args(sys.argv[1:])
|
| 532 |
+
if rest_args:
|
| 533 |
+
raise ValueError(f"Unknown args {rest_args}")
|
| 534 |
+
|
| 535 |
+
if args.seed is None:
|
| 536 |
+
args.seed = np.random.randint(2**31)
|
| 537 |
+
|
| 538 |
+
if args.jit:
|
| 539 |
+
run_ppo(args)
|
| 540 |
+
else:
|
| 541 |
+
with jax.disable_jit():
|
| 542 |
+
run_ppo(args)
|
Craftax_Baselines/requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
jax[cuda12_pip]
|
| 2 |
+
distrax
|
| 3 |
+
optax
|
| 4 |
+
flax
|
| 5 |
+
numpy
|
| 6 |
+
black
|
| 7 |
+
pre-commit
|
| 8 |
+
argparse
|
| 9 |
+
wandb
|
| 10 |
+
orbax-checkpoint==0.5.0
|
| 11 |
+
pygame
|
| 12 |
+
gymnax
|
| 13 |
+
chex
|
| 14 |
+
matplotlib
|
| 15 |
+
imageio
|
| 16 |
+
craftax
|
Craftax_Baselines/run_docker.sh
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
WANDB_API_KEY=$(cat ./wandb_key)
|
| 3 |
+
# git pull
|
| 4 |
+
|
| 5 |
+
script_and_args="${@:2}"
|
| 6 |
+
if [ $1 == "all" ]; then
|
| 7 |
+
gpus="0 1 2 3 4 5 6 7"
|
| 8 |
+
else
|
| 9 |
+
gpus=$1
|
| 10 |
+
fi
|
| 11 |
+
|
| 12 |
+
for gpu in $gpus; do
|
| 13 |
+
echo "Launching container craftax_$gpu on GPU $gpu"
|
| 14 |
+
docker run \
|
| 15 |
+
--gpus device=$gpu \
|
| 16 |
+
-e WANDB_API_KEY=$WANDB_API_KEY \
|
| 17 |
+
-v $(pwd):/home/duser/Craftax \
|
| 18 |
+
--name craftax_$gpu \
|
| 19 |
+
--user $(id -u) \
|
| 20 |
+
--rm \
|
| 21 |
+
-d \
|
| 22 |
+
-t craftax_baselines \
|
| 23 |
+
/bin/bash -c "$script_and_args"
|
| 24 |
+
done
|
Craftax_Baselines/wrappers.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jax
|
| 2 |
+
import jax.numpy as jnp
|
| 3 |
+
import chex
|
| 4 |
+
import numpy as np
|
| 5 |
+
from flax import struct
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Optional, Tuple, Union, Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GymnaxWrapper(object):
|
| 11 |
+
"""Base class for Gymnax wrappers."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, env):
|
| 14 |
+
self._env = env
|
| 15 |
+
|
| 16 |
+
# provide proxy access to regular attributes of wrapped object
|
| 17 |
+
def __getattr__(self, name):
|
| 18 |
+
return getattr(self._env, name)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BatchEnvWrapper(GymnaxWrapper):
|
| 22 |
+
"""Batches reset and step functions"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, env, num_envs: int):
|
| 25 |
+
super().__init__(env)
|
| 26 |
+
|
| 27 |
+
self.num_envs = num_envs
|
| 28 |
+
|
| 29 |
+
self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
|
| 30 |
+
self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
|
| 31 |
+
|
| 32 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
| 33 |
+
def reset(self, rng, params=None):
|
| 34 |
+
rng, _rng = jax.random.split(rng)
|
| 35 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
| 36 |
+
obs, env_state = self.reset_fn(rngs, params)
|
| 37 |
+
return obs, env_state
|
| 38 |
+
|
| 39 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
| 40 |
+
def step(self, rng, state, action, params=None):
|
| 41 |
+
rng, _rng = jax.random.split(rng)
|
| 42 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
| 43 |
+
obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
|
| 44 |
+
|
| 45 |
+
return obs, state, reward, done, info
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class AutoResetEnvWrapper(GymnaxWrapper):
|
| 49 |
+
"""Provides standard auto-reset functionality, providing the same behaviour as Gymnax-default."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, env):
|
| 52 |
+
super().__init__(env)
|
| 53 |
+
|
| 54 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
| 55 |
+
def reset(self, key, params=None):
|
| 56 |
+
return self._env.reset(key, params)
|
| 57 |
+
|
| 58 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
| 59 |
+
def step(self, rng, state, action, params=None):
|
| 60 |
+
|
| 61 |
+
rng, _rng = jax.random.split(rng)
|
| 62 |
+
obs_st, state_st, reward, done, info = self._env.step(
|
| 63 |
+
_rng, state, action, params
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
rng, _rng = jax.random.split(rng)
|
| 67 |
+
obs_re, state_re = self._env.reset(_rng, params)
|
| 68 |
+
|
| 69 |
+
# Auto-reset environment based on termination
|
| 70 |
+
def auto_reset(done, state_re, state_st, obs_re, obs_st):
|
| 71 |
+
state = jax.tree.map(
|
| 72 |
+
lambda x, y: jax.lax.select(done, x, y), state_re, state_st
|
| 73 |
+
)
|
| 74 |
+
obs = jax.lax.select(done, obs_re, obs_st)
|
| 75 |
+
|
| 76 |
+
return obs, state
|
| 77 |
+
|
| 78 |
+
obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
|
| 79 |
+
|
| 80 |
+
return obs, state, reward, done, info
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class OptimisticResetVecEnvWrapper(GymnaxWrapper):
|
| 84 |
+
"""
|
| 85 |
+
Provides efficient 'optimistic' resets.
|
| 86 |
+
The wrapper also necessarily handles the batching of environment steps and resetting.
|
| 87 |
+
reset_ratio: the number of environment workers per environment reset. Higher means more efficient but a higher
|
| 88 |
+
chance of duplicate resets.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, env, num_envs: int, reset_ratio: int):
|
| 92 |
+
super().__init__(env)
|
| 93 |
+
|
| 94 |
+
self.num_envs = num_envs
|
| 95 |
+
self.reset_ratio = reset_ratio
|
| 96 |
+
assert (
|
| 97 |
+
num_envs % reset_ratio == 0
|
| 98 |
+
), "Reset ratio must perfectly divide num envs."
|
| 99 |
+
self.num_resets = self.num_envs // reset_ratio
|
| 100 |
+
|
| 101 |
+
self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
|
| 102 |
+
self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
|
| 103 |
+
|
| 104 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
| 105 |
+
def reset(self, rng, params=None):
|
| 106 |
+
rng, _rng = jax.random.split(rng)
|
| 107 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
| 108 |
+
obs, env_state = self.reset_fn(rngs, params)
|
| 109 |
+
return obs, env_state
|
| 110 |
+
|
| 111 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
| 112 |
+
def step(self, rng, state, action, params=None):
|
| 113 |
+
|
| 114 |
+
rng, _rng = jax.random.split(rng)
|
| 115 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
| 116 |
+
obs_st, state_st, reward, done, info = self.step_fn(rngs, state, action, params)
|
| 117 |
+
|
| 118 |
+
rng, _rng = jax.random.split(rng)
|
| 119 |
+
rngs = jax.random.split(_rng, self.num_resets)
|
| 120 |
+
obs_re, state_re = self.reset_fn(rngs, params)
|
| 121 |
+
|
| 122 |
+
rng, _rng = jax.random.split(rng)
|
| 123 |
+
reset_indexes = jnp.arange(self.num_resets).repeat(self.reset_ratio)
|
| 124 |
+
|
| 125 |
+
being_reset = jax.random.choice(
|
| 126 |
+
_rng,
|
| 127 |
+
jnp.arange(self.num_envs),
|
| 128 |
+
shape=(self.num_resets,),
|
| 129 |
+
p=done,
|
| 130 |
+
replace=False,
|
| 131 |
+
)
|
| 132 |
+
reset_indexes = reset_indexes.at[being_reset].set(jnp.arange(self.num_resets))
|
| 133 |
+
|
| 134 |
+
obs_re = obs_re[reset_indexes]
|
| 135 |
+
state_re = jax.tree.map(lambda x: x[reset_indexes], state_re)
|
| 136 |
+
|
| 137 |
+
# Auto-reset environment based on termination
|
| 138 |
+
def auto_reset(done, state_re, state_st, obs_re, obs_st):
|
| 139 |
+
state = jax.tree.map(
|
| 140 |
+
lambda x, y: jax.lax.select(done, x, y), state_re, state_st
|
| 141 |
+
)
|
| 142 |
+
obs = jax.lax.select(done, obs_re, obs_st)
|
| 143 |
+
|
| 144 |
+
return state, obs
|
| 145 |
+
|
| 146 |
+
state, obs = jax.vmap(auto_reset)(done, state_re, state_st, obs_re, obs_st)
|
| 147 |
+
|
| 148 |
+
return obs, state, reward, done, info
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@struct.dataclass
|
| 152 |
+
class LogEnvState:
|
| 153 |
+
env_state: Any
|
| 154 |
+
episode_returns: float
|
| 155 |
+
episode_lengths: int
|
| 156 |
+
returned_episode_returns: float
|
| 157 |
+
returned_episode_lengths: int
|
| 158 |
+
timestep: int
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class LogWrapper(GymnaxWrapper):
|
| 162 |
+
"""Log the episode returns and lengths."""
|
| 163 |
+
|
| 164 |
+
def __init__(self, env):
|
| 165 |
+
super().__init__(env)
|
| 166 |
+
|
| 167 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
| 168 |
+
def reset(self, key: chex.PRNGKey, params=None):
|
| 169 |
+
obs, env_state = self._env.reset(key, params)
|
| 170 |
+
state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
|
| 171 |
+
return obs, state
|
| 172 |
+
|
| 173 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
| 174 |
+
def step(
|
| 175 |
+
self,
|
| 176 |
+
key: chex.PRNGKey,
|
| 177 |
+
state,
|
| 178 |
+
action: Union[int, float],
|
| 179 |
+
params=None,
|
| 180 |
+
):
|
| 181 |
+
obs, env_state, reward, done, info = self._env.step(
|
| 182 |
+
key, state.env_state, action, params
|
| 183 |
+
)
|
| 184 |
+
new_episode_return = state.episode_returns + reward
|
| 185 |
+
new_episode_length = state.episode_lengths + 1
|
| 186 |
+
state = LogEnvState(
|
| 187 |
+
env_state=env_state,
|
| 188 |
+
episode_returns=new_episode_return * (1 - done),
|
| 189 |
+
episode_lengths=new_episode_length * (1 - done),
|
| 190 |
+
returned_episode_returns=state.returned_episode_returns * (1 - done)
|
| 191 |
+
+ new_episode_return * done,
|
| 192 |
+
returned_episode_lengths=state.returned_episode_lengths * (1 - done)
|
| 193 |
+
+ new_episode_length * done,
|
| 194 |
+
timestep=state.timestep + 1,
|
| 195 |
+
)
|
| 196 |
+
info["returned_episode_returns"] = state.returned_episode_returns
|
| 197 |
+
info["returned_episode_lengths"] = state.returned_episode_lengths
|
| 198 |
+
info["timestep"] = state.timestep
|
| 199 |
+
info["returned_episode"] = done
|
| 200 |
+
return obs, state, reward, done, info
|
README.md
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReMDM Planner — Discrete Diffusion Planning on Craftax
|
| 2 |
+
|
| 3 |
+
A JAX implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in the [Craftax](https://github.com/MichaelTMatthews/Craftax) environment. A bidirectional transformer learns to generate action plans by iteratively denoising masked token sequences, conditioned on the current environment observation.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## Description
|
| 8 |
+
|
| 9 |
+
The planner starts from a fully-masked action sequence and iteratively unmasks tokens over `T` denoising steps, producing a `plan_horizon`-length plan. The ReMDM framework extends standard Masked Discrete Language Modelling (MDLM) with remasking strategies that allow committed tokens to be re-predicted, improving plan coherence.
|
| 10 |
+
|
| 11 |
+
Two independent training pipelines are available — **Offline BC** and **Online DAgger** — both supervised by a pre-trained PPO expert but otherwise separate. Neither depends on the other; the paper compares them head-to-head.
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
[Shared] Train PPO agent Craftax_Baselines/ppo_rnn.py | ppo_rnd.py
|
| 15 |
+
|
|
| 16 |
+
v checkpoint
|
| 17 |
+
┌───────┴────────┐
|
| 18 |
+
│ │
|
| 19 |
+
[Offline BC] [Online DAgger]
|
| 20 |
+
main.py main.py
|
| 21 |
+
--mode offline --mode online
|
| 22 |
+
(train on live (train from scratch;
|
| 23 |
+
PPO rollouts) mixed policy + expert
|
| 24 |
+
│ labels into replay buffer)
|
| 25 |
+
v v
|
| 26 |
+
checkpoint checkpoint
|
| 27 |
+
│ │
|
| 28 |
+
└───────┬────────┘
|
| 29 |
+
v
|
| 30 |
+
[Evaluate] main.py --mode inference --checkpoint_path ...
|
| 31 |
+
|
| 32 |
+
Optional: an offline BC checkpoint can warm-start DAgger
|
| 33 |
+
via --offline_checkpoint_path (not used in the paper).
|
| 34 |
+
|
| 35 |
+
[Offline BC] ──checkpoint──> [Online DAgger]
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
**Optional utility modes:**
|
| 39 |
+
```
|
| 40 |
+
[Collect] Save PPO rollouts to disk main.py --mode collect
|
| 41 |
+
[Smoke test] Quick end-to-end check main.py --mode smoke
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## Installation
|
| 47 |
+
|
| 48 |
+
### Prerequisites (system-level)
|
| 49 |
+
|
| 50 |
+
`uv` manages Python packages only. The following must be installed at the OS level before
|
| 51 |
+
running on a GPU node — they are **not** in `pyproject.toml`:
|
| 52 |
+
|
| 53 |
+
- **CUDA 13** driver and toolkit (`libcuda.so`, `libcudnn`)
|
| 54 |
+
|
| 55 |
+
On HPC clusters these are typically loaded via `module load cuda/13.x`.
|
| 56 |
+
|
| 57 |
+
### 1. Create the virtual environment
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# CPU-only (local development / macOS)
|
| 61 |
+
uv sync
|
| 62 |
+
|
| 63 |
+
# NVIDIA CUDA 13 (GPU node — Linux only)
|
| 64 |
+
uv sync --extra cuda
|
| 65 |
+
|
| 66 |
+
# Activate
|
| 67 |
+
source .venv/bin/activate
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
`uv sync` reads `pyproject.toml`, resolves a fully-reproducible lockfile (`uv.lock`),
|
| 71 |
+
and installs into `.venv/`. Commit `uv.lock` to pin the exact dependency graph.
|
| 72 |
+
|
| 73 |
+
### 2. Initialise the submodule
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
git submodule update --init --recursive
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Dependencies
|
| 82 |
+
|
| 83 |
+
| Package | Version | Role |
|
| 84 |
+
|---------|---------|------|
|
| 85 |
+
| `jax` | >=0.9.2 | JIT compilation and functional arrays |
|
| 86 |
+
| `flax` | >=0.12.6 | Neural network definitions |
|
| 87 |
+
| `optax` | >=0.2.8 | Adam optimiser and gradient clipping |
|
| 88 |
+
| `craftax` | >=1.5.0 | Procedurally-generated Minecraft-like environment |
|
| 89 |
+
| `chex` | >=0.1.91 | JAX testing and assertion utilities |
|
| 90 |
+
| `distrax` | >=0.1.7 | Probability distributions |
|
| 91 |
+
| `orbax` | >=0.1.9 | Model checkpointing |
|
| 92 |
+
| `wandb` | >=0.25.1 | Experiment logging |
|
| 93 |
+
| `numpy` | >=2.4.4 | Array operations |
|
| 94 |
+
| `matplotlib` | >=3.10.8 | Plotting |
|
| 95 |
+
| `polars` | >=1.39.3 | DataFrame analysis |
|
| 96 |
+
| `orjson` | >=3.11.8 | Fast JSON serialisation |
|
| 97 |
+
| `pyyaml` | >=6.0.3 | Config file parsing |
|
| 98 |
+
|
| 99 |
+
Full specification in `pyproject.toml`. Exact transitive pins are in `uv.lock`.
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## Usage
|
| 104 |
+
|
| 105 |
+
All modes share the same entry point. Defaults are loaded from `configs/defaults.yaml`; any value can be overridden on the command line.
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
python main.py --mode <MODE> [--config PATH] [OVERRIDES...]
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Pass `--no-jit` to disable JIT compilation (useful for debugging):
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
python main.py --mode offline --no-jit --num_envs 4
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Stage 1 — Train a PPO agent
|
| 118 |
+
|
| 119 |
+
PPO training is handled by the `Craftax_Baselines` submodule and produces the checkpoint consumed by all downstream stages.
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
cd Craftax_Baselines
|
| 123 |
+
|
| 124 |
+
# PPO with GRU hidden state (recommended)
|
| 125 |
+
python ppo_rnn.py \
|
| 126 |
+
--env_name Craftax-Classic-Symbolic-v1 \
|
| 127 |
+
--total_timesteps 500000000 \
|
| 128 |
+
--save_policy --use_wandb
|
| 129 |
+
|
| 130 |
+
# PPO with Random Network Distillation
|
| 131 |
+
python ppo_rnd.py \
|
| 132 |
+
--env_name Craftax-Classic-Symbolic-v1 \
|
| 133 |
+
--total_timesteps 500000000 \
|
| 134 |
+
--save_policy --use_wandb
|
| 135 |
+
|
| 136 |
+
cd ..
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### Stage 2a — Collect trajectories to disk
|
| 140 |
+
|
| 141 |
+
Roll out the PPO checkpoint and save `(obs, actions, rewards, dones)` as a `.npz` file for reuse across multiple diffusion training runs.
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
python main.py --mode collect \
|
| 145 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 146 |
+
--offline_data_path data/trajectories.npz \
|
| 147 |
+
--collect_num_steps 1000000 \
|
| 148 |
+
--collect_num_envs 128
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
The file stores arrays shaped `[num_envs, num_iters, ...]`, preserving per-environment contiguity so episode boundaries are respected during window sampling.
|
| 152 |
+
|
| 153 |
+
### Stage 2b — Train offline from live PPO rollouts
|
| 154 |
+
|
| 155 |
+
Roll out the PPO agent live at each update step and train the diffusion model on the collected windows. Windows that cross episode boundaries are masked out; windows with higher cumulative reward receive proportionally larger gradient contributions (clipped to `[0.1, return_weight_cap]`).
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
python main.py --mode offline \
|
| 159 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 160 |
+
--offline_total_timesteps 100000000 \
|
| 161 |
+
--save_policy
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
### Online DAgger Training
|
| 165 |
+
|
| 166 |
+
The diffusion model is trained **from scratch** via DAgger (Dataset Aggregation). At each iteration a mixed policy blends the PPO expert and the diffusion learner (controlled by an exponentially decaying `beta`). The mixed policy rolls out trajectories; the expert labels every visited state with the action it would take. These `(obs, expert_plan)` pairs are appended to a growing circular replay buffer, and the diffusion model is trained on the full buffer with the standard MDLM ELBO loss (pure behavioural cloning — no reward weighting).
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
# From scratch (requires PPO expert checkpoint)
|
| 170 |
+
python main.py --mode online \
|
| 171 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 172 |
+
--online_num_updates 1000 \
|
| 173 |
+
--save_policy
|
| 174 |
+
|
| 175 |
+
# Optional: warm-start from a pre-trained offline checkpoint
|
| 176 |
+
# (not used in the paper — both methods are compared independently)
|
| 177 |
+
python main.py --mode online \
|
| 178 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 179 |
+
--offline_checkpoint_path /path/to/offline_checkpoint \
|
| 180 |
+
--online_num_updates 1000 \
|
| 181 |
+
--save_policy
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
When `save_policy=true`, online training uploads **two** W&B artifacts: `{env_name}-policy` (final weights) and `{env_name}-policy-best` (weights from the validation iteration with the highest return). Either artifact can be consumed downstream by `--checkpoint_path wandb:…`.
|
| 185 |
+
|
| 186 |
+
### Stage 4 — Evaluate
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
python main.py --mode inference \
|
| 190 |
+
--checkpoint_path /path/to/checkpoint \
|
| 191 |
+
--eval_steps 10000 \
|
| 192 |
+
--eval_num_envs 32
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Prints mean episode return, completed episodes, steps per second, and per-achievement unlock counts. Uses historical inpainting: the first `hist_len` plan positions are locked to observed history.
|
| 196 |
+
|
| 197 |
+
### Loading checkpoints from W&B artifacts
|
| 198 |
+
|
| 199 |
+
Any checkpoint path argument (`--checkpoint_path`, `--offline_checkpoint_path`, `--ppo_checkpoint_path`) accepts a W&B artifact reference prefixed with `wandb:`. The artifact is downloaded automatically before training or evaluation begins.
|
| 200 |
+
|
| 201 |
+
```bash
|
| 202 |
+
# Fully qualified: entity/project/artifact_name:version_or_alias
|
| 203 |
+
python main.py --mode inference \
|
| 204 |
+
--checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:latest
|
| 205 |
+
|
| 206 |
+
# Online fine-tuning from a W&B offline checkpoint
|
| 207 |
+
python main.py --mode online \
|
| 208 |
+
--offline_checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:v3
|
| 209 |
+
|
| 210 |
+
# PPO checkpoint from W&B
|
| 211 |
+
python main.py --mode offline \
|
| 212 |
+
--ppo_checkpoint_path wandb:my-team/ppo-craftax/ppo-rnn-policy:best
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
Control the download location with `--wandb_download_dir` (defaults to `./artifacts/`).
|
| 216 |
+
|
| 217 |
+
### Resuming a Training Run
|
| 218 |
+
|
| 219 |
+
A completed training checkpoint can be used as the starting point for a new run that continues where the previous one left off. This is useful when extending the training budget or when a preempted job needs to be restarted.
|
| 220 |
+
|
| 221 |
+
**Offline resume:**
|
| 222 |
+
|
| 223 |
+
```bash
|
| 224 |
+
# Auto-detect step and wandb run ID from checkpoint metadata
|
| 225 |
+
python main.py --mode offline \
|
| 226 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 227 |
+
--resume_checkpoint_path /path/to/completed_offline_checkpoint \
|
| 228 |
+
--offline_total_timesteps 200000000 \
|
| 229 |
+
--save_policy
|
| 230 |
+
|
| 231 |
+
# Explicit step and wandb run ID override
|
| 232 |
+
python main.py --mode offline \
|
| 233 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 234 |
+
--resume_checkpoint_path /path/to/completed_offline_checkpoint \
|
| 235 |
+
--resume_step 1525 \
|
| 236 |
+
--resume_wandb_run_id abc123xyz \
|
| 237 |
+
--offline_total_timesteps 200000000 \
|
| 238 |
+
--save_policy
|
| 239 |
+
|
| 240 |
+
# Resume from a W&B artifact
|
| 241 |
+
python main.py --mode offline \
|
| 242 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 243 |
+
--resume_checkpoint_path wandb:my-team/remdm-craftax/policy:latest \
|
| 244 |
+
--offline_total_timesteps 200000000 \
|
| 245 |
+
--save_policy
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**Online resume:**
|
| 249 |
+
|
| 250 |
+
```bash
|
| 251 |
+
python main.py --mode online \
|
| 252 |
+
--ppo_checkpoint_path /path/to/ppo_checkpoint \
|
| 253 |
+
--resume_checkpoint_path /path/to/completed_online_checkpoint \
|
| 254 |
+
--online_num_updates 2000 \
|
| 255 |
+
--save_policy
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
**Notes:**
|
| 259 |
+
- The DAgger replay buffer is **not** persisted across resumes. It starts empty and refills within the first few iterations.
|
| 260 |
+
- JIT compilation is fully preserved. Resume only affects initialisation outside `jax.jit` (loading checkpoint, setting the optimizer step counter, adjusting scan length).
|
| 261 |
+
- The cosine LR schedule is constructed for the full `num_updates` range. The optimizer step counter is set to the resume offset so the learning rate picks up exactly where the previous run stopped.
|
| 262 |
+
- When `resume_checkpoint_path` points to a checkpoint with a metadata sidecar, `resume_step` and `resume_wandb_run_id` are auto-detected. Explicit CLI flags override the metadata values.
|
| 263 |
+
- Checkpoints without a metadata sidecar (created before this feature) still load; provide `--resume_step` explicitly.
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
---
|
| 267 |
+
|
| 268 |
+
## Configuration
|
| 269 |
+
|
| 270 |
+
All hyperparameters are in `configs/defaults.yaml`. Override any value on the command line:
|
| 271 |
+
|
| 272 |
+
```bash
|
| 273 |
+
python main.py --mode offline --lr 1e-4 --plan_horizon 64 --num_minibatches 16
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
Point to a custom config file:
|
| 277 |
+
|
| 278 |
+
```bash
|
| 279 |
+
python main.py --mode online --config configs/my_experiment.yaml
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
Preset configs for larger runs are provided in `configs/`:
|
| 283 |
+
|
| 284 |
+
| File | Purpose |
|
| 285 |
+
|------|---------|
|
| 286 |
+
| `configs/defaults.yaml` | Base defaults for all modes |
|
| 287 |
+
| `configs/classic_exp_a_beta_fix.yaml` | Craftax Classic DAgger — beta decay fix only (isolates data quality) |
|
| 288 |
+
| `configs/classic_exp_b_beta_big_model.yaml` | Craftax Classic DAgger — beta fix + 3.5× larger transformer |
|
| 289 |
+
| `configs/classic_exp_c_full_recipe.yaml` | Craftax Classic DAgger — beta + big model + training dynamics |
|
| 290 |
+
| `configs/craftax_exp_a_beta_fix.yaml` | Full Craftax DAgger — beta decay fix only |
|
| 291 |
+
| `configs/craftax_exp_b_beta_big_model.yaml` | Full Craftax DAgger — beta fix + larger transformer |
|
| 292 |
+
| `configs/craftax_exp_c_full_recipe.yaml` | Full Craftax DAgger — full recipe |
|
| 293 |
+
| `configs/final_classic_ucl.yaml` | Final Craftax Classic DAgger — UCL 3090 Ti, seed 42 (produces the Classic checkpoint consumed by the ablation suite) |
|
| 294 |
+
| `configs/final_classic_qmul.yaml` | Env-frame-matched second seed of `final_classic_ucl.yaml` (QMUL H200, seed 43) |
|
| 295 |
+
| `configs/final_craftax_ucl.yaml` | Final Full Craftax DAgger — UCL 4090, seed 42 (produces the Full Craftax checkpoint consumed by the ablation suite) |
|
| 296 |
+
| `configs/final_craftax_qmul.yaml` | Env-frame-matched second seed of `final_craftax_ucl.yaml` (QMUL H200, seed 43) |
|
| 297 |
+
|
| 298 |
+
RL fine-tuning ablation hyperparameters live under `experiments/rl_finetuning/configs/` and are loaded by `run_ablations.py`, not by `main.py`. See `experiments/README.md`.
|
| 299 |
+
|
| 300 |
+
The `final_*_qmul.yaml` presets differ from their UCL counterparts only in `num_envs` (smaller partition) and `seed`. All fairness-critical hyperparameters are denominated in env frames or update cycles and automatically rescaled by `resolve_scaled_hyperparams()` at load time, so no manual derivation is needed when moving between hardware tiers.
|
| 301 |
+
|
| 302 |
+
### Key hyperparameters
|
| 303 |
+
|
| 304 |
+
**Environment**
|
| 305 |
+
|
| 306 |
+
| Parameter | Default | Description |
|
| 307 |
+
|-----------|---------|-------------|
|
| 308 |
+
| `env_name` | `Craftax-Classic-Symbolic-v1` | Craftax environment ID. Use `Craftax-Symbolic-v1` for Full Craftax. |
|
| 309 |
+
| `use_optimistic_resets` | `false` | Use `OptimisticResetVecEnvWrapper` instead of `AutoResetEnvWrapper` |
|
| 310 |
+
| `optimistic_reset_ratio` | 16 | Fraction of envs reset per step when optimistic resets are enabled |
|
| 311 |
+
|
| 312 |
+
**Diffusion model**
|
| 313 |
+
|
| 314 |
+
| Parameter | Default | Description |
|
| 315 |
+
|-----------|---------|-------------|
|
| 316 |
+
| `plan_horizon` | 32 | Action plan length H |
|
| 317 |
+
| `diffusion_steps` | 15 | Denoising steps T at inference |
|
| 318 |
+
| `diffusion_schedule` | `cosine` | Noise schedule: `cosine` or `linear` |
|
| 319 |
+
| `remask_strategy` | `rescale` | Remasking strategy: `rescale`, `cap`, or `conf` |
|
| 320 |
+
| `train_sigma` | 0.0 | Per-token remasking correction during training (0 = standard MDLM) |
|
| 321 |
+
| `label_smoothing` | 0.0 | Cross-entropy label smoothing epsilon (0 = exact ELBO) |
|
| 322 |
+
| `eta` | 0.5 | Remasking strength |
|
| 323 |
+
| `use_loop` | `true` | Three-phase loop remasking (Algorithm 3) |
|
| 324 |
+
| `t_on` / `t_off` | 0.7 / 0.3 | Time window boundaries for loop remasking |
|
| 325 |
+
| `temperature` | 0.5 | Softmax temperature for token sampling |
|
| 326 |
+
| `top_p` | 0.95 | Nucleus sampling threshold |
|
| 327 |
+
|
| 328 |
+
**Transformer architecture**
|
| 329 |
+
|
| 330 |
+
| Parameter | Default | Description |
|
| 331 |
+
|-----------|---------|-------------|
|
| 332 |
+
| `d_model` | 256 | Hidden dimension |
|
| 333 |
+
| `n_heads` | 4 | Attention heads |
|
| 334 |
+
| `n_layers` | 4 | Transformer blocks |
|
| 335 |
+
| `d_ff` | 512 | FFN inner dimension |
|
| 336 |
+
| `obs_encoder_layers` | 2 | MLP layers in the observation encoder |
|
| 337 |
+
| `obs_encoder_width` | 512 | Observation encoder hidden width |
|
| 338 |
+
| `dropout_rate` | 0.1 | Dropout rate (disabled at inference) |
|
| 339 |
+
|
| 340 |
+
**Offline training**
|
| 341 |
+
|
| 342 |
+
| Parameter | Default | Description |
|
| 343 |
+
|-----------|---------|-------------|
|
| 344 |
+
| `offline_total_timesteps` | 1e8 | **PRIMARY** env-frame budget for live-PPO data collection. Derives `num_updates` as `offline_total_timesteps // (num_envs * num_steps)`, making the run hardware-portable across `num_envs` changes. |
|
| 345 |
+
| `offline_num_updates` | `null` | **LEGACY** outer update count; used only when `offline_total_timesteps` is unset. |
|
| 346 |
+
| `num_envs` | 1024 | Parallel environments |
|
| 347 |
+
| `num_steps` | 64 | Environment steps collected per update |
|
| 348 |
+
| `num_minibatches` | 8 | Gradient minibatches per epoch |
|
| 349 |
+
| `update_epochs` | 4 | SGD epochs per update step |
|
| 350 |
+
| `num_repeats` | 1 | Independent training seeds (vmapped) |
|
| 351 |
+
| `lr` | 3e-4 | Adam learning rate (cosine-decayed to 10% over all gradient steps) |
|
| 352 |
+
| `lr_warmup_frames` | `null` | **PRIMARY** env-frame warm-up budget. Derives `lr_warmup_steps` as `lr_warmup_frames // (num_envs * num_steps)`. |
|
| 353 |
+
| `lr_warmup_steps` | 0 | **LEGACY** linear warm-up steps before cosine decay (used when `lr_warmup_frames` is unset; 0 = disabled). |
|
| 354 |
+
| `max_grad_norm` | 1.0 | Global gradient clipping norm |
|
| 355 |
+
| `return_weight_cap` | 5.0 | Clip ceiling for per-window return weights (lower clip is fixed at 0.1) |
|
| 356 |
+
| `collect_temperature` | 1.0 | Softmax temperature on PPO logits during live data collection |
|
| 357 |
+
| `val_interval_frames` | `null` | **PRIMARY** env-frames between validation rollouts. Overrides `val_interval` via `val_interval = val_interval_frames // (num_envs * num_steps)`. |
|
| 358 |
+
| `val_interval` | 50 | **LEGACY** validation frequency in update steps (used when `val_interval_frames` is unset). |
|
| 359 |
+
| `val_diffusion_steps` | 50 | Denoising steps used during validation rollouts |
|
| 360 |
+
| `val_replan_every` | 4 | Environment steps executed per diffusion plan during validation |
|
| 361 |
+
| `val_steps` | 128 | Total environment steps per validation rollout |
|
| 362 |
+
|
| 363 |
+
**Online DAgger training**
|
| 364 |
+
|
| 365 |
+
| Parameter | Default | Description |
|
| 366 |
+
|-----------|---------|-------------|
|
| 367 |
+
| `online_total_timesteps` | `null` | **PRIMARY** env-frame budget for online DAgger (hardware-portable). Derives `num_updates` as `online_total_timesteps // (num_envs * num_steps)`. |
|
| 368 |
+
| `online_num_updates` | 1000 | **LEGACY** outer DAgger iterations (used when `online_total_timesteps` is unset). |
|
| 369 |
+
| `dagger_beta_init` | 1.0 | Initial expert mixing probability `beta_1` (1.0 = pure expert on the first iteration). |
|
| 370 |
+
| `dagger_beta_final` | `null` | **PRIMARY** target mixing ratio at the end of training. Overrides `dagger_beta_decay` via `decay = (beta_final / beta_init) ** (1 / num_updates)`. |
|
| 371 |
+
| `dagger_beta_decay` | 0.95 | **LEGACY** per-update decay: `beta_i = beta_init * decay^i` (used when `dagger_beta_final` is unset). |
|
| 372 |
+
| `dagger_buffer_cycles` | `null` | **PRIMARY** buffer capacity denominated in update cycles of history (1 cycle = `num_envs * num_steps` frames). Overrides `dagger_buffer_max` via `buffer_max = cycles * (num_envs * num_steps)`. |
|
| 373 |
+
| `dagger_buffer_max` | 100000 | **LEGACY** max samples in the DAgger replay buffer (circular eviction when full). |
|
| 374 |
+
| `dagger_train_passes` | `null` | Passes per update over the aggregated buffer. `null` = 1 pass (matches offline BC per-update gradient work exactly for fair compute comparison). Raise to >1 to trade BC fairness for wider per-update buffer coverage. |
|
| 375 |
+
| `dagger_expert_deterministic` | `true` | If `true`, the PPO expert takes the argmax action (fixed `s → a*` map); if `false`, it samples categorically. Deterministic removes label noise from the aggregated dataset. |
|
| 376 |
+
|
| 377 |
+
**Data collection**
|
| 378 |
+
|
| 379 |
+
| Parameter | Default | Description |
|
| 380 |
+
|-----------|---------|-------------|
|
| 381 |
+
| `collect_num_steps` | 10000000 | Total environment steps to collect |
|
| 382 |
+
| `collect_num_envs` | 128 | Parallel environments during collection |
|
| 383 |
+
| `ppo_model_type` | `ppo_rnn` | PPO architecture: `ppo`, `ppo_rnn`, or `ppo_rnd` |
|
| 384 |
+
| `layer_size` | 512 | PPO actor-critic hidden layer width |
|
| 385 |
+
|
| 386 |
+
**Inference**
|
| 387 |
+
|
| 388 |
+
| Parameter | Default | Description |
|
| 389 |
+
|-----------|---------|-------------|
|
| 390 |
+
| `eval_steps` | 10000 | Environment steps for evaluation |
|
| 391 |
+
| `eval_num_envs` | 32 | Parallel agents during evaluation (independent of `num_envs`) |
|
| 392 |
+
| `diffusion_steps_eval` | 10 | Denoising steps T used at evaluation time |
|
| 393 |
+
|
| 394 |
+
**Checkpointing**
|
| 395 |
+
|
| 396 |
+
| Parameter | Default | Description |
|
| 397 |
+
|-----------|---------|-------------|
|
| 398 |
+
| `save_policy` | `true` | Save final checkpoint at end of training and upload it as a W&B artifact |
|
| 399 |
+
|
| 400 |
+
**Resume**
|
| 401 |
+
|
| 402 |
+
| Parameter | Default | Description |
|
| 403 |
+
|-----------|---------|-------------|
|
| 404 |
+
| `resume_checkpoint_path` | `null` | Path to a completed checkpoint to resume from (accepts `wandb:` refs) |
|
| 405 |
+
| `resume_wandb_run_id` | `null` | W&B run ID to resume logging into (auto-read from checkpoint metadata) |
|
| 406 |
+
| `resume_step` | `null` | Update step the checkpoint was saved at (auto-read from checkpoint metadata) |
|
| 407 |
+
|
| 408 |
+
**Logging**
|
| 409 |
+
|
| 410 |
+
| Parameter | Default | Description |
|
| 411 |
+
|-----------|---------|-------------|
|
| 412 |
+
| `use_wandb` | `true` | Enable Weights & Biases logging |
|
| 413 |
+
| `wandb_project` | `remdm-craftax` | W&B project name |
|
| 414 |
+
| `wandb_entity` | `"mathis-weil-university-college-london-ucl-"` | W&B entity (team or username) |
|
| 415 |
+
| `wandb_download_dir` | `null` | Download directory for W&B artifacts; null = `./artifacts/` |
|
| 416 |
+
| `seed` | `null` | RNG seed (random if null) |
|
| 417 |
+
|
| 418 |
+
---
|
| 419 |
+
|
| 420 |
+
## Remasking Strategies
|
| 421 |
+
|
| 422 |
+
Controlled by `--remask_strategy`. All strategies operate on top of the three-phase loop controlled by `--use_loop`, `--t_on`, and `--t_off`.
|
| 423 |
+
|
| 424 |
+
| Strategy | Formula | Description |
|
| 425 |
+
|----------|---------|-------------|
|
| 426 |
+
| `rescale` | `sigma = eta * sigma_max` | Scales maximum remasking probability proportionally |
|
| 427 |
+
| `cap` | `sigma = min(eta, sigma_max)` | Caps remasking at a fixed rate |
|
| 428 |
+
| `conf` | `sigma = eta * sigma_max * (1 - confidence)` | High-confidence tokens are remasked less |
|
| 429 |
+
|
| 430 |
+
---
|
| 431 |
+
|
| 432 |
+
## Environment Wrappers
|
| 433 |
+
|
| 434 |
+
**From `Craftax_Baselines/wrappers.py`** (submodule):
|
| 435 |
+
|
| 436 |
+
| Wrapper | Purpose |
|
| 437 |
+
|---------|---------|
|
| 438 |
+
| `LogWrapper` | Tracks episode returns and lengths; adds stats to the info dict |
|
| 439 |
+
| `AutoResetEnvWrapper` | Automatically resets episodes on `done` |
|
| 440 |
+
| `BatchEnvWrapper` | Vmaps `reset` and `step` over `num_envs` environments |
|
| 441 |
+
| `OptimisticResetVecEnvWrapper` | Batched resets with reduced overhead; enable via `--use_optimistic_resets` |
|
| 442 |
+
|
| 443 |
+
**From `src/envs/wrappers.py`**:
|
| 444 |
+
|
| 445 |
+
| Wrapper | Purpose |
|
| 446 |
+
|---------|---------|
|
| 447 |
+
| `SequenceHistoryWrapper` | Maintains a sliding window of past observations and actions in the env state |
|
| 448 |
+
| `DiscreteTokenizationWrapper` | Quantizes continuous observations into discrete token indices |
|
| 449 |
+
| `PlannerWrapper` | Manages the plan/replan cycle for the diffusion planner |
|
| 450 |
+
| `OfflineTrajectoryWrapper` | Accumulates transitions into a fixed-size circular replay buffer |
|
| 451 |
+
|
| 452 |
+
**Wrapper stacks:**
|
| 453 |
+
|
| 454 |
+
```
|
| 455 |
+
Training: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
|
| 456 |
+
Inference: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
|
| 457 |
+
```
|
| 458 |
+
|
| 459 |
+
---
|
| 460 |
+
|
| 461 |
+
## Project Structure
|
| 462 |
+
|
| 463 |
+
```
|
| 464 |
+
craftax-ReMDM-planner/
|
| 465 |
+
├── Craftax_Baselines/ # Git submodule — PPO agents and standard wrappers
|
| 466 |
+
│ ├── wrappers.py # LogWrapper, BatchEnvWrapper, AutoResetEnvWrapper, etc.
|
| 467 |
+
│ ├── ppo_rnn.py # PPO-RNN training script
|
| 468 |
+
│ ├── ppo_rnd.py # PPO-RND training script
|
| 469 |
+
│ ├── ppo.py # PPO model definitions
|
| 470 |
+
│ └── models/
|
| 471 |
+
│ ├── actor_critic.py # ActorCritic variants
|
| 472 |
+
│ ├── rnd.py # RND network
|
| 473 |
+
│ └── icm.py # ICM encoder, forward, and inverse networks
|
| 474 |
+
├── configs/
|
| 475 |
+
│ ├── defaults.yaml # Base hyperparameters (CLI-overridable)
|
| 476 |
+
│ ├── classic_exp_a_beta_fix.yaml # Classic DAgger — beta decay fix only
|
| 477 |
+
│ ├── classic_exp_b_beta_big_model.yaml # Classic DAgger — beta fix + big model
|
| 478 |
+
│ ├── classic_exp_c_full_recipe.yaml # Classic DAgger — full recipe
|
| 479 |
+
│ ├── craftax_exp_a_beta_fix.yaml # Full Craftax DAgger — beta fix
|
| 480 |
+
│ ├── craftax_exp_b_beta_big_model.yaml # Full Craftax DAgger — beta + big model
|
| 481 |
+
│ ├── craftax_exp_c_full_recipe.yaml # Full Craftax DAgger — full recipe
|
| 482 |
+
│ ├── final_classic_ucl.yaml # Classic DAgger — UCL 3090 Ti, seed 42
|
| 483 |
+
│ ├── final_classic_qmul.yaml # Classic DAgger — QMUL H200, seed 43
|
| 484 |
+
│ ├── final_craftax_ucl.yaml # Full Craftax DAgger — UCL 4090, seed 42
|
| 485 |
+
│ └── final_craftax_qmul.yaml # Full Craftax DAgger — QMUL H200, seed 43
|
| 486 |
+
├── src/
|
| 487 |
+
│ ├── diffusion/
|
| 488 |
+
│ │ ├── forward.py # Forward masking process q(z_t | x_0)
|
| 489 |
+
│ │ ├── loss.py # Continuous-time MDLM ELBO loss
|
| 490 |
+
│ │ ├── sampling.py # Reverse diffusion with ReMDM remasking
|
| 491 |
+
│ │ └── schedules.py # Linear and cosine noise schedules
|
| 492 |
+
│ ├── models/
|
| 493 |
+
│ │ └── denoiser.py # DenoisingTransformer (obs encoder + transformer)
|
| 494 |
+
│ ├── envs/
|
| 495 |
+
│ │ └── wrappers.py # Sequence, tokenization, planner, and trajectory wrappers
|
| 496 |
+
│ └── planners/
|
| 497 |
+
│ ├── collect.py # --mode collect: PPO rollouts -> .npz
|
| 498 |
+
│ ├── common.py # Shared utilities
|
| 499 |
+
│ ├── env.py # Environment construction
|
| 500 |
+
│ ├── inference.py # --mode inference: MPC evaluation with inpainting
|
| 501 |
+
│ ├── logging.py # Centralised W&B logging utilities
|
| 502 |
+
│ ├── model.py # Diffusion model lifecycle
|
| 503 |
+
│ ├── offline.py # --mode offline: make_train (live PPO rollouts)
|
| 504 |
+
│ ├── online.py # --mode online: DAgger fine-tuning
|
| 505 |
+
│ └── ppo.py # PPO agent adapter and checkpoint loading utilities
|
| 506 |
+
├── experiments/
|
| 507 |
+
│ └── rl_finetuning/ # RL fine-tuning ablation suite (see experiments/README.md)
|
| 508 |
+
│ ├── run_ablations.py # CLI entry point
|
| 509 |
+
│ ├── ablations/ # Loss, optimizer, registry, and training modules
|
| 510 |
+
│ ├── diagnostics/ # Gradient, representation, and timestep diagnostics
|
| 511 |
+
│ ├── analysis/ # Plots, tables, and report generation
|
| 512 |
+
│ └── configs/ # ablations_default.yaml, ablations_fast.yaml,
|
| 513 |
+
│ # ablations_final_{classic,craftax}_{ucl,qmul}.yaml
|
| 514 |
+
├── main.py # CLI entry point
|
| 515 |
+
├── pyproject.toml # uv project — direct deps + tool config
|
| 516 |
+
└── uv.lock # Reproducible lockfile (commit this)
|
| 517 |
+
```
|
| 518 |
+
|
| 519 |
+
---
|
| 520 |
+
|
| 521 |
+
## Implementation Notes
|
| 522 |
+
|
| 523 |
+
**JAX functional purity**: training closures (`make_train`, `make_train_dagger`) are fully JIT-compatible. Environment construction and checkpoint I/O happen outside `jax.jit`.
|
| 524 |
+
|
| 525 |
+
**Offline training**: `--mode offline` rolls out the PPO agent live at each update step via `make_train`. Use `--mode collect` to save a trajectory `.npz` for inspection or analysis; re-feeding it to `--mode offline` is not supported — pass `--ppo_checkpoint_path` instead.
|
| 526 |
+
|
| 527 |
+
**Episode-boundary masking**: the offline sampler pre-computes a validity mask over all `(env, time)` positions. A window at `(e, t)` is valid only if `dones[e, t+1:t+H-1]` are all `False`.
|
| 528 |
+
|
| 529 |
+
**Return weighting**: valid windows are weighted by their cumulative reward, normalised by the batch mean and clipped to `[0.1, RETURN_WEIGHT_CAP]`. Weights are passed as per-sample multipliers into the MDLM loss before reduction, so they correctly scale each sample's gradient contribution.
|
| 530 |
+
|
| 531 |
+
**LR schedule**: cosine decay from `lr` to `lr * 0.1` over all gradient steps. Set `lr_warmup_frames > 0` (env-frame-invariant, PRIMARY) or `lr_warmup_steps > 0` (LEGACY) to prepend a linear warm-up phase.
|
| 532 |
+
|
| 533 |
+
**Env-frame-invariant hyperparameters**: the PRIMARY keys `offline_total_timesteps`, `online_total_timesteps`, `lr_warmup_frames`, `val_interval_frames`, `dagger_beta_final`, and `dagger_buffer_cycles` are denominated in env frames (or update cycles). At config load time, `resolve_scaled_hyperparams()` in `src/planners/common.py` converts them to the equivalent update-step-denominated quantities (`num_updates`, `lr_warmup_steps`, `val_interval`, `dagger_beta_decay`, `dagger_buffer_max`) using the current `num_envs * num_steps` frames-per-update. This lets the same config run on different hardware tiers without re-tuning.
|
| 534 |
+
|
| 535 |
+
**Loss weight clipping**: the MDLM SUBS weight `-alpha'(t) / (1 - alpha_t)` is clipped to 1000 to prevent numerical instability when `alpha_t ≈ 1`.
|
| 536 |
+
|
| 537 |
+
**Validation rollouts**: during offline training, a held-out rollout runs every `val_interval` steps. It uses the same sampling parameters as inference (`remask_strategy`, `eta`, `use_loop`, `t_on`, `t_off`, `temperature`, `top_p`) with `val_diffusion_steps` denoising steps and `val_replan_every` env steps per plan, for a total of `val_steps` environment steps.
|
| 538 |
+
|
| 539 |
+
**W&B logging**: all metric aggregation is centralised in `src/planners/logging.py`. Metric namespaces: `diffusion/` (loss, accuracy), `train/` (data quality, throughput), `env/` (episode returns, achievements), `val/` (validation rollouts, emitted every `val_interval` steps), `dagger/` (online DAgger training: beta, buffer fill, reward mean, valid fraction). `train/sps` (environment frames/sec) is only logged in modes that perform live environment interaction.
|
| 540 |
+
|
| 541 |
+
**DAgger dataset aggregation**: online training (`--mode online`) implements DAgger (Ross et al., 2011). A circular replay buffer accumulates `(obs, expert_plan)` pairs across all iterations. Each update samples uniformly from the full buffer, not just the latest batch. Training samples that cross episode boundaries (any `done` within the plan-horizon window) are marked invalid. The expert (PPO agent) receives correct `done` flags so its RNN hidden state resets on episode boundaries. Windows are extracted with a sliding stride (one per env-time position) rather than stepping the buffer in plan-horizon chunks, so every visited state contributes a label.
|
| 542 |
+
|
| 543 |
+
**Best-checkpoint tracking**: during online training, the parameters from the validation iteration with the highest validation return are preserved alongside the current live parameters. The final checkpoint and the best-validation checkpoint are both uploaded as separate W&B artifacts (`{env_name}-policy` and `{env_name}-policy-best`).
|
| 544 |
+
|
| 545 |
+
**Denoising step indexing**: the reverse scan runs from `step_idx = 0` to `T-1`, mapping to diffusion time `t = (T - step_idx) / T` (high noise to low noise).
|
| 546 |
+
|
| 547 |
+
**Submodule PPO agents**: PPO training lives entirely in `Craftax_Baselines/`. Planner scripts only consume pre-trained checkpoints via `--ppo_checkpoint_path`.
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775663434533263974, "commit_timestamp_nsecs": 1775663435644779625, "custom_metadata": {}}
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]}
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af
ADDED
|
Binary file (3.72 kB). View file
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt
ADDED
|
Binary file (117 Bytes). View file
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bbff61a18e9475d72fae302d4748615daf5fc6b87cc0e0a338c96b8a781d6c0f
|
| 3 |
+
size 101199872
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a
ADDED
|
Binary file (832 Bytes). View file
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84
ADDED
|
Binary file (171 Bytes). View file
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:66e7df58a5ad39030e5631943ffa5d45164b91f283a2b7b34d4265c6bbf08be4
|
| 3 |
+
size 448037
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt
ADDED
|
Binary file (259 Bytes). View file
|
|
|
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mode": "offline",
|
| 3 |
+
"update_step": 1525,
|
| 4 |
+
"total_gradient_steps_completed": 97600,
|
| 5 |
+
"wandb_run_id": "6opvce2t",
|
| 6 |
+
"config_snapshot": {
|
| 7 |
+
"ENV_NAME": "Craftax-Classic-Symbolic-v1",
|
| 8 |
+
"USE_OPTIMISTIC_RESETS": false,
|
| 9 |
+
"OPTIMISTIC_RESET_RATIO": 16,
|
| 10 |
+
"D_MODEL": 384,
|
| 11 |
+
"N_HEADS": 8,
|
| 12 |
+
"N_LAYERS": 6,
|
| 13 |
+
"D_FF": 768,
|
| 14 |
+
"OBS_ENCODER_LAYERS": 2,
|
| 15 |
+
"OBS_ENCODER_WIDTH": 768,
|
| 16 |
+
"DROPOUT_RATE": 0.1,
|
| 17 |
+
"PLAN_HORIZON": 32,
|
| 18 |
+
"DIFFUSION_SCHEDULE": "cosine",
|
| 19 |
+
"TRAIN_SIGMA": 0.0,
|
| 20 |
+
"LABEL_SMOOTHING": 0.0,
|
| 21 |
+
"DIFFUSION_STEPS": 15,
|
| 22 |
+
"DIFFUSION_STEPS_EVAL": 10,
|
| 23 |
+
"REMASK_STRATEGY": "rescale",
|
| 24 |
+
"ETA": 0.5,
|
| 25 |
+
"USE_LOOP": true,
|
| 26 |
+
"T_ON": 0.7,
|
| 27 |
+
"T_OFF": 0.3,
|
| 28 |
+
"TEMPERATURE": 0.5,
|
| 29 |
+
"TOP_P": 0.95,
|
| 30 |
+
"LR": 0.0003,
|
| 31 |
+
"MAX_GRAD_NORM": 1.0,
|
| 32 |
+
"LR_WARMUP_FRAMES": "1.048576e8",
|
| 33 |
+
"NUM_ENVS": 512,
|
| 34 |
+
"NUM_STEPS": 128,
|
| 35 |
+
"NUM_MINIBATCHES": 8,
|
| 36 |
+
"UPDATE_EPOCHS": 8,
|
| 37 |
+
"NUM_REPEATS": 1,
|
| 38 |
+
"OFFLINE_TOTAL_TIMESTEPS": 99942400,
|
| 39 |
+
"COLLECT_TEMPERATURE": 1.0,
|
| 40 |
+
"RETURN_WEIGHT_CAP": 5.0,
|
| 41 |
+
"ONLINE_TOTAL_TIMESTEPS": 100000000.0,
|
| 42 |
+
"DAGGER_BETA_INIT": 1.0,
|
| 43 |
+
"DAGGER_BETA_FINAL": 0.344,
|
| 44 |
+
"DAGGER_BUFFER_CYCLES": 1.90735,
|
| 45 |
+
"VAL_INTERVAL_FRAMES": 1000000.0,
|
| 46 |
+
"VAL_DIFFUSION_STEPS": 50,
|
| 47 |
+
"VAL_REPLAN_EVERY": 4,
|
| 48 |
+
"VAL_STEPS": 256,
|
| 49 |
+
"COLLECT_NUM_STEPS": 10000000,
|
| 50 |
+
"COLLECT_NUM_ENVS": 128,
|
| 51 |
+
"PPO_MODEL_TYPE": "ppo_rnn",
|
| 52 |
+
"LAYER_SIZE": 512,
|
| 53 |
+
"EVAL_STEPS": 10000,
|
| 54 |
+
"EVAL_NUM_ENVS": 32,
|
| 55 |
+
"SAVE_POLICY": true,
|
| 56 |
+
"SEED": 42,
|
| 57 |
+
"USE_WANDB": true,
|
| 58 |
+
"WANDB_PROJECT": "remdm-craftax",
|
| 59 |
+
"WANDB_ENTITY": "mathis-weil-university-college-london-ucl-",
|
| 60 |
+
"MODE": "offline",
|
| 61 |
+
"JIT": true,
|
| 62 |
+
"PPO_CHECKPOINT_PATH": "checkpoints/ppo_agents/policies/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M",
|
| 63 |
+
"NUM_UPDATES": 1525,
|
| 64 |
+
"LR_WARMUP_STEPS": 1600,
|
| 65 |
+
"VAL_INTERVAL": 15,
|
| 66 |
+
"MINIBATCH_SIZE": 6208
|
| 67 |
+
}
|
| 68 |
+
}
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775623858059636986, "commit_timestamp_nsecs": 1775623858516125466, "custom_metadata": {}}
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"array_metadatas": [{"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}]}
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86
ADDED
|
Binary file (2.22 kB). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt
ADDED
|
Binary file (117 Bytes). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043
ADDED
|
Binary file (171 Bytes). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30
ADDED
|
Binary file (3.77 kB). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc
ADDED
|
Binary file (832 Bytes). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620
ADDED
|
Binary file (214 Bytes). View file
|
|
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c27dc63cbdd625c2b62fac311fd37e14406b411ac848847ca4bd4e99f333419
|
| 3 |
+
size 34631680
|
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt
ADDED
|
Binary file (302 Bytes). View file
|
|
|
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1773173340517772966, "commit_timestamp_nsecs": 1773173340998852009, "custom_metadata": {}}
|
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
|
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
|