MathisW78 commited on
Commit
6140064
·
verified ·
1 Parent(s): 16ca0bc

Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +118 -0
  2. Craftax_Baselines/.gitignore +169 -0
  3. Craftax_Baselines/.pre-commit-config.yaml +6 -0
  4. Craftax_Baselines/Dockerfile +41 -0
  5. Craftax_Baselines/LICENSE +19 -0
  6. Craftax_Baselines/README.md +46 -0
  7. Craftax_Baselines/analysis/__init__.py +0 -0
  8. Craftax_Baselines/analysis/view_ppo_agent.py +151 -0
  9. Craftax_Baselines/build.sh +10 -0
  10. Craftax_Baselines/images/logo.png +0 -0
  11. Craftax_Baselines/logz/__init__.py +0 -0
  12. Craftax_Baselines/logz/batch_logging.py +115 -0
  13. Craftax_Baselines/models/__init__.py +0 -0
  14. Craftax_Baselines/models/actor_critic.py +256 -0
  15. Craftax_Baselines/models/icm.py +72 -0
  16. Craftax_Baselines/models/rnd.py +120 -0
  17. Craftax_Baselines/ppo.py +733 -0
  18. Craftax_Baselines/ppo_rnd.py +680 -0
  19. Craftax_Baselines/ppo_rnn.py +542 -0
  20. Craftax_Baselines/requirements.txt +16 -0
  21. Craftax_Baselines/run_docker.sh +24 -0
  22. Craftax_Baselines/wrappers.py +200 -0
  23. README.md +547 -0
  24. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA +1 -0
  25. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA +0 -0
  26. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding +1 -0
  27. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 +1 -0
  28. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af +0 -0
  29. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt +0 -0
  30. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 +3 -0
  31. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a +0 -0
  32. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 +0 -0
  33. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf +3 -0
  34. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt +0 -0
  35. checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json +68 -0
  36. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA +1 -0
  37. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA +0 -0
  38. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding +1 -0
  39. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 +1 -0
  40. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 +0 -0
  41. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt +0 -0
  42. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 +0 -0
  43. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 +0 -0
  44. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc +0 -0
  45. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 +0 -0
  46. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a +3 -0
  47. checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt +0 -0
  48. checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA +1 -0
  49. checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA +1 -0
  50. checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding +1 -0
.gitattributes CHANGED
@@ -33,3 +33,121 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/369457b7c6608f1adf28eb88024d6b91 filter=lfs diff=lfs merge=lfs -text
40
+ checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/6bb110a840637eede93e25f5356236f9 filter=lfs diff=lfs merge=lfs -text
41
+ checkpoints/ppo_agents/Craftax-Symbolic-v1-PPO_RNN-1000M/1000000000/default/ocdbt.process_0/d/e5a4020f50167115120fe5dac41c20fb filter=lfs diff=lfs merge=lfs -text
42
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_breakdown.png filter=lfs diff=lfs merge=lfs -text
43
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_action_diversity.png filter=lfs diff=lfs merge=lfs -text
44
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
45
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_attention_only.png filter=lfs diff=lfs merge=lfs -text
46
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
47
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_bc_wins.png filter=lfs diff=lfs merge=lfs -text
48
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
49
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ewc.png filter=lfs diff=lfs merge=lfs -text
50
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_ffn_only.png filter=lfs diff=lfs merge=lfs -text
51
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
52
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
53
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_head_only.png filter=lfs diff=lfs merge=lfs -text
54
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
55
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
56
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
57
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
58
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_llrd.png filter=lfs diff=lfs merge=lfs -text
59
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_lora.png filter=lfs diff=lfs merge=lfs -text
60
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_low_t.png filter=lfs diff=lfs merge=lfs -text
61
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
62
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
63
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
64
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_reward_model.png filter=lfs diff=lfs merge=lfs -text
65
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_running_stats.png filter=lfs diff=lfs merge=lfs -text
66
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
67
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/achievement_collapse_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
68
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/action_dist/js_divergence_comparison.png filter=lfs diff=lfs merge=lfs -text
69
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/cka_similarity.png filter=lfs diff=lfs merge=lfs -text
70
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_action_diversity.png filter=lfs diff=lfs merge=lfs -text
71
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
72
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_attention_only.png filter=lfs diff=lfs merge=lfs -text
73
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
74
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_bc_wins.png filter=lfs diff=lfs merge=lfs -text
75
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
76
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ewc.png filter=lfs diff=lfs merge=lfs -text
77
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_ffn_only.png filter=lfs diff=lfs merge=lfs -text
78
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
79
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
80
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_head_only.png filter=lfs diff=lfs merge=lfs -text
81
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
82
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
83
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
84
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
85
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_llrd.png filter=lfs diff=lfs merge=lfs -text
86
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_lora.png filter=lfs diff=lfs merge=lfs -text
87
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_low_t.png filter=lfs diff=lfs merge=lfs -text
88
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
89
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
90
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
91
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_reward_model.png filter=lfs diff=lfs merge=lfs -text
92
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_running_stats.png filter=lfs diff=lfs merge=lfs -text
93
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
94
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/curves_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
95
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/diagnosis_decision_tree.png filter=lfs diff=lfs merge=lfs -text
96
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/eval_scores_over_training.png filter=lfs diff=lfs merge=lfs -text
97
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/final_score_comparison.png filter=lfs diff=lfs merge=lfs -text
98
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_alignment.png filter=lfs diff=lfs merge=lfs -text
99
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/gradient_conflict_map.png filter=lfs diff=lfs merge=lfs -text
100
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_action_diversity.png filter=lfs diff=lfs merge=lfs -text
101
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
102
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_attention_only.png filter=lfs diff=lfs merge=lfs -text
103
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
104
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_bc_wins.png filter=lfs diff=lfs merge=lfs -text
105
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
106
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ewc.png filter=lfs diff=lfs merge=lfs -text
107
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_ffn_only.png filter=lfs diff=lfs merge=lfs -text
108
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
109
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
110
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_head_only.png filter=lfs diff=lfs merge=lfs -text
111
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
112
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
113
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
114
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
115
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_llrd.png filter=lfs diff=lfs merge=lfs -text
116
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_lora.png filter=lfs diff=lfs merge=lfs -text
117
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_low_t.png filter=lfs diff=lfs merge=lfs -text
118
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
119
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
120
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
121
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_reward_model.png filter=lfs diff=lfs merge=lfs -text
122
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_running_stats.png filter=lfs diff=lfs merge=lfs -text
123
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
124
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/per_layer_grad_heatmap_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
125
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/representation_drift.png filter=lfs diff=lfs merge=lfs -text
126
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/score_delta_over_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
127
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_action_diversity.png filter=lfs diff=lfs merge=lfs -text
128
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_advantage_clip.png filter=lfs diff=lfs merge=lfs -text
129
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_attention_only.png filter=lfs diff=lfs merge=lfs -text
130
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_baseline_rl.png filter=lfs diff=lfs merge=lfs -text
131
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_bc_wins.png filter=lfs diff=lfs merge=lfs -text
132
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_entropy_bonus.png filter=lfs diff=lfs merge=lfs -text
133
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ewc.png filter=lfs diff=lfs merge=lfs -text
134
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_ffn_only.png filter=lfs diff=lfs merge=lfs -text
135
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_frozen_backbone.png filter=lfs diff=lfs merge=lfs -text
136
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_gradient_surgery.png filter=lfs diff=lfs merge=lfs -text
137
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_head_only.png filter=lfs diff=lfs merge=lfs -text
138
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_kl_penalty.png filter=lfs diff=lfs merge=lfs -text
139
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top1.png filter=lfs diff=lfs merge=lfs -text
140
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top2.png filter=lfs diff=lfs merge=lfs -text
141
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_layer_ablation_top3.png filter=lfs diff=lfs merge=lfs -text
142
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_llrd.png filter=lfs diff=lfs merge=lfs -text
143
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_low_t.png filter=lfs diff=lfs merge=lfs -text
144
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_mixed_replay.png filter=lfs diff=lfs merge=lfs -text
145
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_normalized_adv.png filter=lfs diff=lfs merge=lfs -text
146
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_filtering.png filter=lfs diff=lfs merge=lfs -text
147
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_reward_model.png filter=lfs diff=lfs merge=lfs -text
148
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_running_stats.png filter=lfs diff=lfs merge=lfs -text
149
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_t_curriculum.png filter=lfs diff=lfs merge=lfs -text
150
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_grad_norms_trust_region_kl.png filter=lfs diff=lfs merge=lfs -text
151
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_bin_norms_heatmap.png filter=lfs diff=lfs merge=lfs -text
152
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/t_distribution_analysis.png filter=lfs diff=lfs merge=lfs -text
153
+ experiments/rl_finetuning/outputs/craftax_classic_final_results/analysis/figures/win_rate_and_effective_batch_size.png filter=lfs diff=lfs merge=lfs -text
Craftax_Baselines/.gitignore ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tmp/
2
+ wandb/
3
+ res/
4
+ runs/
5
+
6
+ play_data
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/#use-with-ide
117
+ .pdm.toml
118
+
119
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120
+ __pypackages__/
121
+
122
+ # Celery stuff
123
+ celerybeat-schedule
124
+ celerybeat.pid
125
+
126
+ # SageMath parsed files
127
+ *.sage.py
128
+
129
+ # Environments
130
+ .env
131
+ .venv
132
+ env/
133
+ venv/
134
+ ENV/
135
+ env.bak/
136
+ venv.bak/
137
+
138
+ # Spyder project settings
139
+ .spyderproject
140
+ .spyproject
141
+
142
+ # Rope project settings
143
+ .ropeproject
144
+
145
+ # mkdocs documentation
146
+ /site
147
+
148
+ # mypy
149
+ .mypy_cache/
150
+ .dmypy.json
151
+ dmypy.json
152
+
153
+ # Pyre type checker
154
+ .pyre/
155
+
156
+ # pytype static type analyzer
157
+ .pytype/
158
+
159
+ # Cython debug symbols
160
+ cython_debug/
161
+
162
+ # PyCharm
163
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
166
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167
+ .idea/
168
+ texture_cache.pbz2
169
+ texture_cache*.pbz2
Craftax_Baselines/.pre-commit-config.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/psf/black
3
+ rev: 22.3.0
4
+ hooks:
5
+ - id: black
6
+ language_version: python3
Craftax_Baselines/Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04
2
+
3
+ ENV CUDA_PATH /usr/local/cuda
4
+ ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
5
+ ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64
6
+
7
+ # Set timezone
8
+ ENV TZ=Europe/London DEBIAN_FRONTEND=noninteractive
9
+
10
+ # Add Python 3.8 to Ubuntu 22.04 and install dependencies
11
+ RUN apt update
12
+ RUN apt install -y software-properties-common && add-apt-repository ppa:deadsnakes/ppa
13
+ RUN apt install -y \
14
+ git \
15
+ python3.8 \
16
+ python3-pip \
17
+ python3.8-venv \
18
+ python3-setuptools \
19
+ python3-wheel
20
+
21
+ # Create local user
22
+ # https://jtreminio.com/blog/running-docker-containers-as-current-host-user/
23
+ ARG UID
24
+ ARG GID
25
+ RUN if [ ${UID:-0} -ne 0 ] && [ ${GID:-0} -ne 0 ]; then \
26
+ groupadd -g ${GID} duser &&\
27
+ useradd -l -u ${UID} -g duser duser &&\
28
+ install -d -m 0755 -o duser -g duser /home/duser &&\
29
+ chown --changes --silent --no-dereference --recursive ${UID}:${GID} /home/duser \
30
+ ;fi
31
+
32
+ USER duser
33
+ WORKDIR /home/duser
34
+
35
+ # Install Python packages
36
+ ENV PATH="/home/duser/.local/bin:$PATH"
37
+ RUN python3 -m pip install --upgrade pip
38
+ ARG REQS
39
+ RUN pip install $REQS -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
40
+
41
+ WORKDIR /home/duser/Craftax
Craftax_Baselines/LICENSE ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2024 Michael Matthews
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
Craftax_Baselines/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img width="80%" src="https://raw.githubusercontent.com/MichaelTMatthews/Craftax_Baselines/main/images/logo.png" />
3
+ </p>
4
+
5
+ # Craftax Baselines
6
+
7
+ This repository contains the code for running the baselines from the [Craftax paper](https://arxiv.org/abs/2402.16801).
8
+ For packaging reasons, this is separate to the [main repository](https://github.com/MichaelTMatthews/Craftax/).
9
+
10
+ # Installation
11
+ ```commandline
12
+ git clone https://github.com/MichaelTMatthews/Craftax_Baselines.git
13
+ cd Craftax_Baselines
14
+ pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
15
+ pre-commit install
16
+ ```
17
+
18
+ # Run Experiments
19
+
20
+ ### PPO
21
+ ```commandline
22
+ python ppo.py
23
+ ```
24
+
25
+ ### PPO-RNN
26
+ ```commandline
27
+ python ppo_rnn.py
28
+ ```
29
+
30
+ ### ICM
31
+ ```commandline
32
+ python ppo.py --train_icm
33
+ ```
34
+
35
+ ### E3B
36
+ ```commandline
37
+ python ppo.py --train_icm --use_e3b --icm_reward_coeff 0
38
+ ```
39
+
40
+ ### RND
41
+ ```commandline
42
+ python ppo_rnd.py
43
+ ```
44
+
45
+ # Visualisation
46
+ You can save trained policies with the `--save_policy` flag. These can then be viewed with the `view_ppo_agent` script (pass in the path up to the `files` directory).
Craftax_Baselines/analysis/__init__.py ADDED
File without changes
Craftax_Baselines/analysis/view_ppo_agent.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ import optax
9
+ import yaml
10
+ from craftax.environment_base.wrappers import AutoResetEnvWrapper
11
+ from flax.training.train_state import TrainState
12
+ import orbax.checkpoint as ocp
13
+
14
+ from ..models.actor_critic import ActorCriticConv, ActorCritic
15
+
16
+
17
+ def main(args):
18
+
19
+ with open(os.path.join(args.path, "config.yaml")) as f:
20
+ raw_config = yaml.load(f, Loader=yaml.Loader)
21
+
22
+ config = {}
23
+ for key, value in raw_config.items():
24
+ if isinstance(value, dict) and "value" in value:
25
+ config[key] = value["value"]
26
+
27
+ config["NUM_ENVS"] = 1
28
+
29
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
30
+ checkpoint_manager = ocp.CheckpointManager(
31
+ os.path.join(args.path, "policies"),
32
+ options=options
33
+ )
34
+
35
+ is_classic = False
36
+
37
+ if config["ENV_NAME"] == "Craftax-Symbolic-v1":
38
+ from craftax.craftax.envs.craftax_symbolic_env import CraftaxSymbolicEnv
39
+ from craftax.craftax.constants import Action
40
+
41
+ env = CraftaxSymbolicEnv(CraftaxSymbolicEnv.default_static_params())
42
+ network = ActorCritic(len(Action), config["LAYER_SIZE"])
43
+ elif config["ENV_NAME"] == "Craftax-Pixels-v1":
44
+ from craftax.craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
45
+ from craftax.craftax.constants import Action
46
+
47
+ env = CraftaxPixelsEnv(CraftaxPixelsEnv.default_static_params())
48
+ network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
49
+ elif config["ENV_NAME"] == "Craftax-Classic-Symbolic-v1":
50
+ from craftax.craftax_classic.envs.craftax_symbolic_env import (
51
+ CraftaxClassicSymbolicEnv,
52
+ )
53
+ from craftax.craftax_classic.constants import Action
54
+
55
+ env = CraftaxClassicSymbolicEnv(
56
+ CraftaxClassicSymbolicEnv.default_static_params()
57
+ )
58
+ network = ActorCritic(len(Action), config["LAYER_SIZE"])
59
+ is_classic = True
60
+ elif config["ENV_NAME"] == "Craftax-Classic-Pixels-v1":
61
+ from craftax.craftax_classic.envs.craftax_pixels_env import (
62
+ CraftaxClassicPixelsEnv,
63
+ )
64
+ from craftax.craftax_classic.constants import Action
65
+
66
+ env = CraftaxClassicPixelsEnv(CraftaxClassicPixelsEnv.default_static_params())
67
+ network = ActorCriticConv(len(Action), config["LAYER_SIZE"])
68
+ is_classic = True
69
+ else:
70
+ raise ValueError(f"Unknown env: {config['ENV_NAME']}")
71
+
72
+ env = AutoResetEnvWrapper(env)
73
+ env_params = env.default_params
74
+
75
+ init_x = jnp.zeros((config["NUM_ENVS"], *env.observation_space(env_params).shape))
76
+
77
+ rng = jax.random.PRNGKey(np.random.randint(2**31))
78
+ rng, _rng, __rng = jax.random.split(rng, 3)
79
+ network_params = network.init(_rng, init_x)
80
+
81
+ tx = optax.chain(
82
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
83
+ optax.adam(config["LR"], eps=1e-5),
84
+ )
85
+ train_state = TrainState.create(
86
+ apply_fn=network.apply,
87
+ params=network_params,
88
+ tx=tx,
89
+ )
90
+
91
+ abstract_train_state = jax.eval_shape(lambda: train_state)
92
+
93
+ train_state = checkpoint_manager.restore(
94
+ config["TOTAL_TIMESTEPS"],
95
+ args=ocp.args.StandardRestore(abstract_train_state)
96
+ )
97
+
98
+ obs, env_state = env.reset(key=__rng)
99
+ done = 0
100
+
101
+ if is_classic:
102
+ from craftax.craftax_classic.play_craftax_classic import CraftaxRenderer
103
+ from craftax.craftax_classic.constants import Achievement
104
+ else:
105
+ from craftax.craftax.play_craftax import CraftaxRenderer
106
+ from craftax.craftax.constants import Achievement
107
+
108
+ renderer = CraftaxRenderer(env, env_params, pixel_render_size=1)
109
+
110
+ while not renderer.is_quit_requested():
111
+ done = np.array([done], dtype=bool)
112
+ obs = jnp.expand_dims(obs, axis=0)
113
+
114
+ pi, value = network.apply(train_state.params, obs)
115
+ rng, _rng = jax.random.split(rng)
116
+ action = pi.sample(seed=_rng)[0]
117
+ # action = jnp.argmax(pi.probs[0, 0])
118
+
119
+ if action is not None:
120
+ rng, _rng = jax.random.split(rng)
121
+ old_achievements = env_state.achievements
122
+ obs, env_state, reward, done, info = env.step(
123
+ _rng, env_state, action, env_params
124
+ )
125
+ new_achievements = env_state.achievements
126
+ print_new_achievements(Achievement, old_achievements, new_achievements)
127
+ if done:
128
+ print("\n")
129
+ renderer.render(env_state)
130
+
131
+
132
+ def print_new_achievements(achievements_cls, old_achievements, new_achievements):
133
+ for i in range(len(old_achievements)):
134
+ if old_achievements[i] == 0 and new_achievements[i] == 1:
135
+ print(f"{achievements_cls(i).name} ({new_achievements.sum()}/{22})")
136
+
137
+
138
+ if __name__ == "__main__":
139
+ parser = argparse.ArgumentParser()
140
+ parser.add_argument("--path", type=str)
141
+ parser.add_argument("--debug", action="store_true")
142
+
143
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
144
+ if rest_args:
145
+ raise ValueError(f"Unknown args {rest_args}")
146
+
147
+ if args.debug:
148
+ with jax.disable_jit():
149
+ main(args)
150
+ else:
151
+ main(args)
Craftax_Baselines/build.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ echo 'Building Dockerfile with image name craftax'
4
+ docker build \
5
+ --build-arg UID=$(id -u ${USER}) \
6
+ --build-arg GID=1234 \
7
+ --build-arg REQS="$(cat requirements.txt)" \
8
+ -t craftax_baselines \
9
+ --no-cache \
10
+ .
Craftax_Baselines/images/logo.png ADDED
Craftax_Baselines/logz/__init__.py ADDED
File without changes
Craftax_Baselines/logz/batch_logging.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ import wandb
6
+
7
+ batch_logs = {}
8
+ log_times = []
9
+
10
+
11
+ def create_log_dict(info, config):
12
+ to_log = {
13
+ "episode_return": info["returned_episode_returns"],
14
+ "episode_length": info["returned_episode_lengths"],
15
+ }
16
+
17
+ diffusion_keys = [
18
+ "loss", "unweighted_loss", "accuracy", "mean_t",
19
+ "acc_t_low", "acc_t_mid", "acc_t_high", "grad_norm",
20
+ "action_entropy", "action_unique_frac"
21
+ ]
22
+ for k in diffusion_keys:
23
+ if k in info:
24
+ to_log[f"diffusion/{k}"] = info[k]
25
+
26
+ sum_achievements = 0.0
27
+ sum_val_achievements = 0.0
28
+ has_val = False
29
+
30
+ for k, v in info.items():
31
+ if k.startswith("val/"):
32
+ has_val = True
33
+ to_log[k] = v
34
+ if "achievements" in k.lower() and k != "val/achievements":
35
+ sum_val_achievements += v / 100.0
36
+ elif "achievements" in k.lower():
37
+ to_log[k] = v
38
+ if k != "achievements":
39
+ sum_achievements += v / 100.0
40
+
41
+ to_log["achievements"] = sum_achievements
42
+ if has_val:
43
+ to_log["val/achievements"] = sum_val_achievements
44
+
45
+ if config.get("TRAIN_ICM") or config.get("USE_RND"):
46
+ to_log["intrinsic_reward"] = info.get("reward_i", 0.0)
47
+ to_log["extrinsic_reward"] = info.get("reward_e", 0.0)
48
+
49
+ if config.get("TRAIN_ICM"):
50
+ to_log["icm_inverse_loss"] = info.get("icm_inverse_loss", 0.0)
51
+ to_log["icm_forward_loss"] = info.get("icm_forward_loss", 0.0)
52
+ elif config.get("USE_RND"):
53
+ to_log["rnd_loss"] = info.get("rnd_loss", 0.0)
54
+
55
+ return to_log
56
+
57
+
58
+ def batch_log(update_step, log, config):
59
+ update_step = int(update_step)
60
+ if update_step not in batch_logs:
61
+ batch_logs[update_step] = []
62
+
63
+ batch_logs[update_step].append(log)
64
+
65
+ if len(batch_logs[update_step]) == config.get("NUM_REPEATS", 1):
66
+ agg_logs = {}
67
+ for key in batch_logs[update_step][0]:
68
+ agg = []
69
+ if key in ["goal_heatmap"]:
70
+ agg = [batch_logs[update_step][0][key]]
71
+ else:
72
+ for i in range(config.get("NUM_REPEATS", 1)):
73
+ # Use .get() to prevent KeyErrors if repeats are out of sync
74
+ val = batch_logs[update_step][i].get(key, float("nan"))
75
+ if not jnp.isnan(val):
76
+ agg.append(val)
77
+
78
+ if len(agg) > 0:
79
+ if key in [
80
+ "episode_length",
81
+ "episode_return",
82
+ "exploration_bonus",
83
+ "e_mean",
84
+ "e_std",
85
+ "rnd_loss",
86
+ "diffusion/loss",
87
+ "diffusion/unweighted_loss",
88
+ "diffusion/accuracy",
89
+ "diffusion/acc_t_low",
90
+ "diffusion/acc_t_mid",
91
+ "diffusion/acc_t_high",
92
+ "diffusion/action_entropy",
93
+ "diffusion/grad_norm"
94
+ ] or key.startswith("val/") or "achievement" in key.lower():
95
+ agg_logs[key] = np.mean(agg)
96
+ else:
97
+ agg_logs[key] = np.array(agg)
98
+
99
+ log_times.append(time.time())
100
+
101
+ if config.get("DEBUG"):
102
+ if len(log_times) == 1:
103
+ print("Started logging")
104
+ elif len(log_times) > 1:
105
+ dt = log_times[-1] - log_times[-2]
106
+ steps_between_updates = (
107
+ config["NUM_STEPS"] * config["NUM_ENVS"] * config.get("NUM_REPEATS", 1)
108
+ )
109
+ sps = steps_between_updates / dt
110
+ agg_logs["sps"] = sps
111
+
112
+ wandb.log(agg_logs)
113
+
114
+ # Clear buffer to prevent memory leaks
115
+ del batch_logs[update_step]
Craftax_Baselines/models/__init__.py ADDED
File without changes
Craftax_Baselines/models/actor_critic.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ import flax.linen as nn
3
+ import numpy as np
4
+ from flax.linen.initializers import constant, orthogonal
5
+ from typing import Sequence
6
+
7
+ import distrax
8
+
9
+
10
+ class ActorCriticConvSymbolicCraftax(nn.Module):
11
+ action_dim: int
12
+ map_obs_shape: Sequence[int]
13
+ layer_width: int
14
+
15
+ @nn.compact
16
+ def __call__(self, obs):
17
+ # Split into map and flat obs
18
+ flat_map_obs_shape = (
19
+ self.map_obs_shape[0] * self.map_obs_shape[1] * self.map_obs_shape[2]
20
+ )
21
+ image_obs = obs[:, :flat_map_obs_shape]
22
+ image_dim = self.map_obs_shape
23
+ image_obs = image_obs.reshape((image_obs.shape[0], *image_dim))
24
+
25
+ flat_obs = obs[:, flat_map_obs_shape:]
26
+
27
+ # Convolutions on map
28
+ image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_obs)
29
+ image_embedding = nn.relu(image_embedding)
30
+ image_embedding = nn.max_pool(
31
+ image_embedding, window_shape=(2, 2), strides=(1, 1)
32
+ )
33
+ image_embedding = nn.Conv(features=32, kernel_size=(2, 2))(image_embedding)
34
+ image_embedding = nn.relu(image_embedding)
35
+ image_embedding = nn.max_pool(
36
+ image_embedding, window_shape=(2, 2), strides=(1, 1)
37
+ )
38
+ image_embedding = image_embedding.reshape(image_embedding.shape[0], -1)
39
+ # image_embedding = jnp.concatenate([image_embedding, obs[:, : CraftaxEnv.get_flat_map_obs_shape()]], axis=-1)
40
+
41
+ # Combine embeddings
42
+ embedding = jnp.concatenate([image_embedding, flat_obs], axis=-1)
43
+ embedding = nn.Dense(
44
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
45
+ )(embedding)
46
+ embedding = nn.relu(embedding)
47
+
48
+ actor_mean = nn.Dense(
49
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
50
+ )(embedding)
51
+ actor_mean = nn.relu(actor_mean)
52
+
53
+ actor_mean = nn.Dense(
54
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
55
+ )(actor_mean)
56
+ actor_mean = nn.relu(actor_mean)
57
+
58
+ actor_mean = nn.Dense(
59
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
60
+ )(actor_mean)
61
+
62
+ pi = distrax.Categorical(logits=actor_mean)
63
+
64
+ critic = nn.Dense(
65
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
66
+ )(embedding)
67
+ critic = nn.relu(critic)
68
+ critic = nn.Dense(
69
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
70
+ )(critic)
71
+ critic = nn.relu(critic)
72
+ critic = nn.Dense(
73
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
74
+ )(critic)
75
+ critic = nn.relu(critic)
76
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
77
+ critic
78
+ )
79
+
80
+ return pi, jnp.squeeze(critic, axis=-1)
81
+
82
+
83
+ class ActorCriticConv(nn.Module):
84
+ action_dim: int
85
+ layer_width: int
86
+ activation: str = "tanh"
87
+
88
+ @nn.compact
89
+ def __call__(self, obs):
90
+ x = nn.Conv(features=32, kernel_size=(5, 5))(obs)
91
+ x = nn.relu(x)
92
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
93
+ x = nn.Conv(features=32, kernel_size=(5, 5))(x)
94
+ x = nn.relu(x)
95
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
96
+ x = nn.Conv(features=32, kernel_size=(5, 5))(x)
97
+ x = nn.relu(x)
98
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(3, 3))
99
+
100
+ embedding = x.reshape(x.shape[0], -1)
101
+
102
+ actor_mean = nn.Dense(
103
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
104
+ )(embedding)
105
+ actor_mean = nn.relu(actor_mean)
106
+
107
+ actor_mean = nn.Dense(
108
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
109
+ )(actor_mean)
110
+ actor_mean = nn.relu(actor_mean)
111
+
112
+ actor_mean = nn.Dense(
113
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
114
+ )(actor_mean)
115
+
116
+ pi = distrax.Categorical(logits=actor_mean)
117
+
118
+ critic = nn.Dense(
119
+ self.layer_width, kernel_init=orthogonal(2), bias_init=constant(0.0)
120
+ )(embedding)
121
+ critic = nn.relu(critic)
122
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
123
+ critic
124
+ )
125
+
126
+ return pi, jnp.squeeze(critic, axis=-1)
127
+
128
+
129
+ class ActorCritic(nn.Module):
130
+ action_dim: int
131
+ layer_width: int
132
+ activation: str = "tanh"
133
+
134
+ @nn.compact
135
+ def __call__(self, x):
136
+ if self.activation == "relu":
137
+ activation = nn.relu
138
+ else:
139
+ activation = nn.tanh
140
+
141
+ actor_mean = nn.Dense(
142
+ self.layer_width,
143
+ kernel_init=orthogonal(np.sqrt(2)),
144
+ bias_init=constant(0.0),
145
+ )(x)
146
+ actor_mean = activation(actor_mean)
147
+
148
+ actor_mean = nn.Dense(
149
+ self.layer_width,
150
+ kernel_init=orthogonal(np.sqrt(2)),
151
+ bias_init=constant(0.0),
152
+ )(actor_mean)
153
+ actor_mean = activation(actor_mean)
154
+
155
+ actor_mean = nn.Dense(
156
+ self.layer_width,
157
+ kernel_init=orthogonal(np.sqrt(2)),
158
+ bias_init=constant(0.0),
159
+ )(actor_mean)
160
+ actor_mean = activation(actor_mean)
161
+
162
+ actor_mean = nn.Dense(
163
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
164
+ )(actor_mean)
165
+ pi = distrax.Categorical(logits=actor_mean)
166
+
167
+ critic = nn.Dense(
168
+ self.layer_width,
169
+ kernel_init=orthogonal(np.sqrt(2)),
170
+ bias_init=constant(0.0),
171
+ )(x)
172
+ critic = activation(critic)
173
+
174
+ critic = nn.Dense(
175
+ self.layer_width,
176
+ kernel_init=orthogonal(np.sqrt(2)),
177
+ bias_init=constant(0.0),
178
+ )(critic)
179
+ critic = activation(critic)
180
+
181
+ critic = nn.Dense(
182
+ self.layer_width,
183
+ kernel_init=orthogonal(np.sqrt(2)),
184
+ bias_init=constant(0.0),
185
+ )(critic)
186
+ critic = activation(critic)
187
+
188
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
189
+ critic
190
+ )
191
+
192
+ return pi, jnp.squeeze(critic, axis=-1)
193
+
194
+
195
+ class ActorCriticWithEmbedding(nn.Module):
196
+ action_dim: int
197
+ layer_width: int
198
+ activation: str = "tanh"
199
+
200
+ @nn.compact
201
+ def __call__(self, x):
202
+ if self.activation == "relu":
203
+ activation = nn.relu
204
+ else:
205
+ activation = nn.tanh
206
+
207
+ actor_emb = nn.Dense(
208
+ self.layer_width,
209
+ kernel_init=orthogonal(np.sqrt(2)),
210
+ bias_init=constant(0.0),
211
+ )(x)
212
+ actor_emb = activation(actor_emb)
213
+
214
+ actor_emb = nn.Dense(
215
+ self.layer_width,
216
+ kernel_init=orthogonal(np.sqrt(2)),
217
+ bias_init=constant(0.0),
218
+ )(actor_emb)
219
+ actor_emb = activation(actor_emb)
220
+
221
+ actor_emb = nn.Dense(
222
+ 128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
223
+ )(actor_emb)
224
+ actor_emb = activation(actor_emb)
225
+
226
+ actor_mean = nn.Dense(
227
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
228
+ )(actor_emb)
229
+ pi = distrax.Categorical(logits=actor_mean)
230
+
231
+ critic = nn.Dense(
232
+ self.layer_width,
233
+ kernel_init=orthogonal(np.sqrt(2)),
234
+ bias_init=constant(0.0),
235
+ )(x)
236
+ critic = activation(critic)
237
+
238
+ critic = nn.Dense(
239
+ self.layer_width,
240
+ kernel_init=orthogonal(np.sqrt(2)),
241
+ bias_init=constant(0.0),
242
+ )(critic)
243
+ critic = activation(critic)
244
+
245
+ critic = nn.Dense(
246
+ self.layer_width,
247
+ kernel_init=orthogonal(np.sqrt(2)),
248
+ bias_init=constant(0.0),
249
+ )(critic)
250
+ critic = activation(critic)
251
+
252
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
253
+ critic
254
+ )
255
+
256
+ return pi, jnp.squeeze(critic, axis=-1), actor_emb
Craftax_Baselines/models/icm.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+
5
+
6
+ class ICMEncoder(nn.Module):
7
+ layer_size: int
8
+ output_dim: int
9
+ num_layers: int
10
+
11
+ @nn.compact
12
+ def __call__(self, obs):
13
+ activation = nn.relu
14
+
15
+ # TODO Look at weight inits
16
+
17
+ emb = obs
18
+ for _ in range(self.num_layers):
19
+ emb = nn.Dense(
20
+ self.layer_size,
21
+ )(emb)
22
+ emb = activation(emb)
23
+
24
+ emb = nn.Dense(self.output_dim)(emb)
25
+
26
+ return emb
27
+
28
+
29
+ class ICMForward(nn.Module):
30
+ layer_size: int
31
+ output_dim: int
32
+ num_layers: int
33
+ num_actions: int
34
+
35
+ @nn.compact
36
+ def __call__(self, latent, action):
37
+ activation = nn.relu
38
+
39
+ action1h = jax.nn.one_hot(action, num_classes=self.num_actions)
40
+ emb = jnp.concatenate((latent, action1h), axis=-1)
41
+ for _ in range(self.num_layers):
42
+ emb = nn.Dense(
43
+ self.layer_size,
44
+ )(emb)
45
+ emb = activation(emb)
46
+
47
+ emb = nn.Dense(self.output_dim)(emb)
48
+
49
+ return emb
50
+
51
+
52
+ class ICMInverse(nn.Module):
53
+ layer_size: int
54
+ output_dim: int
55
+ num_layers: int
56
+
57
+ @nn.compact
58
+ def __call__(self, latent, next_latent):
59
+ activation = nn.relu
60
+
61
+ emb = jnp.concatenate((latent, next_latent), axis=-1)
62
+ for _ in range(self.num_layers):
63
+ emb = nn.Dense(
64
+ self.layer_size,
65
+ )(emb)
66
+ emb = activation(emb)
67
+
68
+ action_raw = nn.Dense(self.output_dim)(emb)
69
+
70
+ action_logits = jax.nn.log_softmax(action_raw)
71
+
72
+ return action_logits
Craftax_Baselines/models/rnd.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ import flax.linen as nn
3
+ import numpy as np
4
+ from flax.linen.initializers import constant, orthogonal
5
+
6
+ import distrax
7
+
8
+
9
+ class RNDNetwork(nn.Module):
10
+ layer_size: int
11
+ output_dim: int
12
+ num_layers: int
13
+
14
+ @nn.compact
15
+ def __call__(self, x):
16
+ activation = nn.relu
17
+
18
+ emb = x
19
+ for _ in range(self.num_layers):
20
+ emb = nn.Dense(
21
+ self.layer_size,
22
+ )(emb)
23
+ emb = activation(emb)
24
+
25
+ emb = nn.Dense(self.output_dim)(emb)
26
+
27
+ return emb
28
+
29
+
30
+ class ActorCriticRND(nn.Module):
31
+ action_dim: int
32
+ layer_width: int
33
+ activation: str = "tanh"
34
+
35
+ @nn.compact
36
+ def __call__(self, x):
37
+ if self.activation == "relu":
38
+ activation = nn.relu
39
+ else:
40
+ activation = nn.tanh
41
+
42
+ actor_mean = nn.Dense(
43
+ self.layer_width,
44
+ kernel_init=orthogonal(np.sqrt(2)),
45
+ bias_init=constant(0.0),
46
+ )(x)
47
+ actor_mean = activation(actor_mean)
48
+
49
+ actor_mean = nn.Dense(
50
+ self.layer_width,
51
+ kernel_init=orthogonal(np.sqrt(2)),
52
+ bias_init=constant(0.0),
53
+ )(actor_mean)
54
+ actor_mean = activation(actor_mean)
55
+
56
+ actor_mean = nn.Dense(
57
+ self.layer_width,
58
+ kernel_init=orthogonal(np.sqrt(2)),
59
+ bias_init=constant(0.0),
60
+ )(actor_mean)
61
+ actor_mean = activation(actor_mean)
62
+
63
+ actor_mean = nn.Dense(
64
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
65
+ )(actor_mean)
66
+ pi = distrax.Categorical(logits=actor_mean)
67
+
68
+ # Extrinsic reward
69
+ critic_e = nn.Dense(
70
+ self.layer_width,
71
+ kernel_init=orthogonal(np.sqrt(2)),
72
+ bias_init=constant(0.0),
73
+ )(x)
74
+ critic_e = activation(critic_e)
75
+
76
+ critic_e = nn.Dense(
77
+ self.layer_width,
78
+ kernel_init=orthogonal(np.sqrt(2)),
79
+ bias_init=constant(0.0),
80
+ )(critic_e)
81
+ critic_e = activation(critic_e)
82
+
83
+ critic_e = nn.Dense(
84
+ self.layer_width,
85
+ kernel_init=orthogonal(np.sqrt(2)),
86
+ bias_init=constant(0.0),
87
+ )(critic_e)
88
+ critic_e = activation(critic_e)
89
+
90
+ critic_e = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
91
+ critic_e
92
+ )
93
+
94
+ # Intrinsic reward
95
+ critic_i = nn.Dense(
96
+ self.layer_width,
97
+ kernel_init=orthogonal(np.sqrt(2)),
98
+ bias_init=constant(0.0),
99
+ )(x)
100
+ critic_i = activation(critic_i)
101
+
102
+ critic_i = nn.Dense(
103
+ self.layer_width,
104
+ kernel_init=orthogonal(np.sqrt(2)),
105
+ bias_init=constant(0.0),
106
+ )(critic_i)
107
+ critic_i = activation(critic_i)
108
+
109
+ critic_i = nn.Dense(
110
+ self.layer_width,
111
+ kernel_init=orthogonal(np.sqrt(2)),
112
+ bias_init=constant(0.0),
113
+ )(critic_i)
114
+ critic_i = activation(critic_i)
115
+
116
+ critic_i = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
117
+ critic_i
118
+ )
119
+
120
+ return pi, jnp.squeeze(critic_e, axis=-1), jnp.squeeze(critic_i, axis=-1)
Craftax_Baselines/ppo.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import optax
10
+ from craftax.craftax_env import make_craftax_env_from_name
11
+
12
+ import wandb
13
+ from typing import NamedTuple
14
+
15
+ from flax.training.train_state import TrainState
16
+ import orbax.checkpoint as ocp
17
+
18
+ from logz.batch_logging import batch_log, create_log_dict
19
+ from models.actor_critic import (
20
+ ActorCritic,
21
+ ActorCriticConv,
22
+ )
23
+ from models.icm import ICMEncoder, ICMForward, ICMInverse
24
+ from wrappers import (
25
+ LogWrapper,
26
+ OptimisticResetVecEnvWrapper,
27
+ BatchEnvWrapper,
28
+ AutoResetEnvWrapper,
29
+ )
30
+
31
+ # Code adapted from the original implementation made by Chris Lu
32
+ # Original code located at https://github.com/luchris429/purejaxrl
33
+
34
+
35
+ class Transition(NamedTuple):
36
+ done: jnp.ndarray
37
+ action: jnp.ndarray
38
+ value: jnp.ndarray
39
+ reward_e: jnp.ndarray
40
+ reward_i: jnp.ndarray
41
+ reward: jnp.ndarray
42
+ log_prob: jnp.ndarray
43
+ obs: jnp.ndarray
44
+ next_obs: jnp.ndarray
45
+ info: jnp.ndarray
46
+
47
+
48
+ def make_train(config):
49
+ config["NUM_UPDATES"] = (
50
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
51
+ )
52
+ config["MINIBATCH_SIZE"] = (
53
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
54
+ )
55
+
56
+ env = make_craftax_env_from_name(
57
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
58
+ )
59
+ env_params = env.default_params
60
+
61
+ env = LogWrapper(env)
62
+ if config["USE_OPTIMISTIC_RESETS"]:
63
+ env = OptimisticResetVecEnvWrapper(
64
+ env,
65
+ num_envs=config["NUM_ENVS"],
66
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
67
+ )
68
+ else:
69
+ env = AutoResetEnvWrapper(env)
70
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
71
+
72
+ def linear_schedule(count):
73
+ frac = (
74
+ 1.0
75
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
76
+ / config["NUM_UPDATES"]
77
+ )
78
+ return config["LR"] * frac
79
+
80
+ def train(rng):
81
+ # INIT NETWORK
82
+ if "Symbolic" in config["ENV_NAME"]:
83
+ network = ActorCritic(env.action_space(env_params).n, config["LAYER_SIZE"])
84
+ else:
85
+ network = ActorCriticConv(
86
+ env.action_space(env_params).n, config["LAYER_SIZE"]
87
+ )
88
+
89
+ rng, _rng = jax.random.split(rng)
90
+ init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
91
+ network_params = network.init(_rng, init_x)
92
+ if config["ANNEAL_LR"]:
93
+ tx = optax.chain(
94
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
95
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
96
+ )
97
+ else:
98
+ tx = optax.chain(
99
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
100
+ optax.adam(config["LR"], eps=1e-5),
101
+ )
102
+ train_state = TrainState.create(
103
+ apply_fn=network.apply,
104
+ params=network_params,
105
+ tx=tx,
106
+ )
107
+
108
+ # Exploration state
109
+ ex_state = {
110
+ "icm_encoder": None,
111
+ "icm_forward": None,
112
+ "icm_inverse": None,
113
+ "e3b_matrix": None,
114
+ }
115
+
116
+ if config["TRAIN_ICM"]:
117
+ obs_shape = env.observation_space(env_params).shape
118
+ assert len(obs_shape) == 1, "Only configured for 1D observations"
119
+ obs_shape = obs_shape[0]
120
+
121
+ # Encoder
122
+ icm_encoder_network = ICMEncoder(
123
+ num_layers=3,
124
+ output_dim=config["ICM_LATENT_SIZE"],
125
+ layer_size=config["ICM_LAYER_SIZE"],
126
+ )
127
+ rng, _rng = jax.random.split(rng)
128
+ icm_encoder_network_params = icm_encoder_network.init(
129
+ _rng, jnp.zeros((1, obs_shape))
130
+ )
131
+ tx = optax.chain(
132
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
133
+ optax.adam(config["ICM_LR"], eps=1e-5),
134
+ )
135
+ ex_state["icm_encoder"] = TrainState.create(
136
+ apply_fn=icm_encoder_network.apply,
137
+ params=icm_encoder_network_params,
138
+ tx=tx,
139
+ )
140
+
141
+ # Forward
142
+ icm_forward_network = ICMForward(
143
+ num_layers=3,
144
+ output_dim=config["ICM_LATENT_SIZE"],
145
+ layer_size=config["ICM_LAYER_SIZE"],
146
+ num_actions=env.num_actions,
147
+ )
148
+ rng, _rng = jax.random.split(rng)
149
+ icm_forward_network_params = icm_forward_network.init(
150
+ _rng, jnp.zeros((1, config["ICM_LATENT_SIZE"])), jnp.zeros((1,))
151
+ )
152
+ tx = optax.chain(
153
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
154
+ optax.adam(config["ICM_LR"], eps=1e-5),
155
+ )
156
+ ex_state["icm_forward"] = TrainState.create(
157
+ apply_fn=icm_forward_network.apply,
158
+ params=icm_forward_network_params,
159
+ tx=tx,
160
+ )
161
+
162
+ # Inverse
163
+ icm_inverse_network = ICMInverse(
164
+ num_layers=3,
165
+ output_dim=env.num_actions,
166
+ layer_size=config["ICM_LAYER_SIZE"],
167
+ )
168
+ rng, _rng = jax.random.split(rng)
169
+ icm_inverse_network_params = icm_inverse_network.init(
170
+ _rng,
171
+ jnp.zeros((1, config["ICM_LATENT_SIZE"])),
172
+ jnp.zeros((1, config["ICM_LATENT_SIZE"])),
173
+ )
174
+ tx = optax.chain(
175
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
176
+ optax.adam(config["ICM_LR"], eps=1e-5),
177
+ )
178
+ ex_state["icm_inverse"] = TrainState.create(
179
+ apply_fn=icm_inverse_network.apply,
180
+ params=icm_inverse_network_params,
181
+ tx=tx,
182
+ )
183
+
184
+ if config["USE_E3B"]:
185
+ ex_state["e3b_matrix"] = (
186
+ jnp.repeat(
187
+ jnp.expand_dims(
188
+ jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
189
+ ),
190
+ config["NUM_ENVS"],
191
+ axis=0,
192
+ )
193
+ / config["E3B_LAMBDA"]
194
+ )
195
+
196
+ # INIT ENV
197
+ rng, _rng = jax.random.split(rng)
198
+ obsv, env_state = env.reset(_rng, env_params)
199
+
200
+ # TRAIN LOOP
201
+ def _update_step(runner_state, unused):
202
+ # COLLECT TRAJECTORIES
203
+ def _env_step(runner_state, unused):
204
+ (
205
+ train_state,
206
+ env_state,
207
+ last_obs,
208
+ ex_state,
209
+ rng,
210
+ update_step,
211
+ ) = runner_state
212
+
213
+ # SELECT ACTION
214
+ rng, _rng = jax.random.split(rng)
215
+ pi, value = network.apply(train_state.params, last_obs)
216
+ action = pi.sample(seed=_rng)
217
+ log_prob = pi.log_prob(action)
218
+
219
+ # STEP ENV
220
+ rng, _rng = jax.random.split(rng)
221
+ obsv, env_state, reward_e, done, info = env.step(
222
+ _rng, env_state, action, env_params
223
+ )
224
+
225
+ reward_i = jnp.zeros(config["NUM_ENVS"])
226
+
227
+ if config["TRAIN_ICM"]:
228
+ latent_obs = ex_state["icm_encoder"].apply_fn(
229
+ ex_state["icm_encoder"].params, last_obs
230
+ )
231
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
232
+ ex_state["icm_encoder"].params, obsv
233
+ )
234
+
235
+ latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
236
+ ex_state["icm_forward"].params, latent_obs, action
237
+ )
238
+ error = (latent_next_obs - latent_next_obs_pred) * (
239
+ 1 - done[:, None]
240
+ )
241
+ mse = jnp.square(error).mean(axis=-1)
242
+
243
+ reward_i = mse * config["ICM_REWARD_COEFF"]
244
+
245
+ if config["USE_E3B"]:
246
+ # Embedding is (NUM_ENVS, 128)
247
+ # e3b_matrix is (NUM_ENVS, 128, 128)
248
+ us = jax.vmap(jnp.matmul)(ex_state["e3b_matrix"], latent_obs)
249
+ bs = jax.vmap(jnp.dot)(latent_obs, us)
250
+
251
+ def update_c(c, b, u):
252
+ return c - (1.0 / (1 + b)) * jnp.outer(u, u)
253
+
254
+ updated_cs = jax.vmap(update_c)(ex_state["e3b_matrix"], bs, us)
255
+ new_cs = (
256
+ jnp.repeat(
257
+ jnp.expand_dims(
258
+ jnp.identity(config["ICM_LATENT_SIZE"]), axis=0
259
+ ),
260
+ config["NUM_ENVS"],
261
+ axis=0,
262
+ )
263
+ / config["E3B_LAMBDA"]
264
+ )
265
+ ex_state["e3b_matrix"] = jnp.where(
266
+ done[:, None, None], new_cs, updated_cs
267
+ )
268
+
269
+ e3b_bonus = jnp.where(
270
+ done, jnp.zeros((config["NUM_ENVS"],)), bs
271
+ )
272
+
273
+ reward_i = e3b_bonus * config["E3B_REWARD_COEFF"]
274
+
275
+ reward = reward_e + reward_i
276
+
277
+ transition = Transition(
278
+ done=done,
279
+ action=action,
280
+ value=value,
281
+ reward=reward,
282
+ reward_i=reward_i,
283
+ reward_e=reward_e,
284
+ log_prob=log_prob,
285
+ obs=last_obs,
286
+ next_obs=obsv,
287
+ info=info,
288
+ )
289
+ runner_state = (
290
+ train_state,
291
+ env_state,
292
+ obsv,
293
+ ex_state,
294
+ rng,
295
+ update_step,
296
+ )
297
+ return runner_state, transition
298
+
299
+ runner_state, traj_batch = jax.lax.scan(
300
+ _env_step, runner_state, None, config["NUM_STEPS"]
301
+ )
302
+
303
+ # CALCULATE ADVANTAGE
304
+ (
305
+ train_state,
306
+ env_state,
307
+ last_obs,
308
+ ex_state,
309
+ rng,
310
+ update_step,
311
+ ) = runner_state
312
+ _, last_val = network.apply(train_state.params, last_obs)
313
+
314
+ def _calculate_gae(traj_batch, last_val):
315
+ def _get_advantages(gae_and_next_value, transition):
316
+ gae, next_value = gae_and_next_value
317
+ done, value, reward = (
318
+ transition.done,
319
+ transition.value,
320
+ transition.reward,
321
+ )
322
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
323
+ gae = (
324
+ delta
325
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
326
+ )
327
+ return (gae, value), gae
328
+
329
+ _, advantages = jax.lax.scan(
330
+ _get_advantages,
331
+ (jnp.zeros_like(last_val), last_val),
332
+ traj_batch,
333
+ reverse=True,
334
+ unroll=16,
335
+ )
336
+ return advantages, advantages + traj_batch.value
337
+
338
+ advantages, targets = _calculate_gae(traj_batch, last_val)
339
+
340
+ # UPDATE NETWORK
341
+ def _update_epoch(update_state, unused):
342
+ def _update_minbatch(train_state, batch_info):
343
+ traj_batch, advantages, targets = batch_info
344
+
345
+ # Policy/value network
346
+ def _loss_fn(params, traj_batch, gae, targets):
347
+ # RERUN NETWORK
348
+ pi, value = network.apply(params, traj_batch.obs)
349
+ log_prob = pi.log_prob(traj_batch.action)
350
+
351
+ # CALCULATE VALUE LOSS
352
+ value_pred_clipped = traj_batch.value + (
353
+ value - traj_batch.value
354
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
355
+ value_losses = jnp.square(value - targets)
356
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
357
+ value_loss = (
358
+ 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
359
+ )
360
+
361
+ # CALCULATE ACTOR LOSS
362
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
363
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
364
+ loss_actor1 = ratio * gae
365
+ loss_actor2 = (
366
+ jnp.clip(
367
+ ratio,
368
+ 1.0 - config["CLIP_EPS"],
369
+ 1.0 + config["CLIP_EPS"],
370
+ )
371
+ * gae
372
+ )
373
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
374
+ loss_actor = loss_actor.mean()
375
+ entropy = pi.entropy().mean()
376
+
377
+ total_loss = (
378
+ loss_actor
379
+ + config["VF_COEF"] * value_loss
380
+ - config["ENT_COEF"] * entropy
381
+ )
382
+ return total_loss, (value_loss, loss_actor, entropy)
383
+
384
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
385
+ total_loss, grads = grad_fn(
386
+ train_state.params, traj_batch, advantages, targets
387
+ )
388
+ train_state = train_state.apply_gradients(grads=grads)
389
+
390
+ losses = (total_loss, 0)
391
+ return train_state, losses
392
+
393
+ (
394
+ train_state,
395
+ traj_batch,
396
+ advantages,
397
+ targets,
398
+ rng,
399
+ ) = update_state
400
+ rng, _rng = jax.random.split(rng)
401
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
402
+ assert (
403
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
404
+ ), "batch size must be equal to number of steps * number of envs"
405
+ permutation = jax.random.permutation(_rng, batch_size)
406
+ batch = (traj_batch, advantages, targets)
407
+ batch = jax.tree.map(
408
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
409
+ )
410
+ shuffled_batch = jax.tree.map(
411
+ lambda x: jnp.take(x, permutation, axis=0), batch
412
+ )
413
+ minibatches = jax.tree.map(
414
+ lambda x: jnp.reshape(
415
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
416
+ ),
417
+ shuffled_batch,
418
+ )
419
+ train_state, losses = jax.lax.scan(
420
+ _update_minbatch, train_state, minibatches
421
+ )
422
+ update_state = (
423
+ train_state,
424
+ traj_batch,
425
+ advantages,
426
+ targets,
427
+ rng,
428
+ )
429
+ return update_state, losses
430
+
431
+ update_state = (
432
+ train_state,
433
+ traj_batch,
434
+ advantages,
435
+ targets,
436
+ rng,
437
+ )
438
+ update_state, loss_info = jax.lax.scan(
439
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
440
+ )
441
+
442
+ train_state = update_state[0]
443
+ metric = jax.tree.map(
444
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
445
+ / traj_batch.info["returned_episode"].sum(),
446
+ traj_batch.info,
447
+ )
448
+
449
+ rng = update_state[-1]
450
+
451
+ # UPDATE EXPLORATION STATE
452
+ def _update_ex_epoch(update_state, unused):
453
+ def _update_ex_minbatch(ex_state, traj_batch):
454
+ def _inverse_loss_fn(
455
+ icm_encoder_params, icm_inverse_params, traj_batch
456
+ ):
457
+ latent_obs = ex_state["icm_encoder"].apply_fn(
458
+ icm_encoder_params, traj_batch.obs
459
+ )
460
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
461
+ icm_encoder_params, traj_batch.next_obs
462
+ )
463
+
464
+ action_pred_logits = ex_state["icm_inverse"].apply_fn(
465
+ icm_inverse_params, latent_obs, latent_next_obs
466
+ )
467
+ true_action = jax.nn.one_hot(
468
+ traj_batch.action, num_classes=action_pred_logits.shape[-1]
469
+ )
470
+
471
+ bce = -jnp.mean(
472
+ jnp.sum(
473
+ action_pred_logits
474
+ * true_action
475
+ * (1 - traj_batch.done[:, None]),
476
+ axis=1,
477
+ )
478
+ )
479
+
480
+ return bce * config["ICM_INVERSE_LOSS_COEF"]
481
+
482
+ inverse_grad_fn = jax.value_and_grad(
483
+ _inverse_loss_fn,
484
+ has_aux=False,
485
+ argnums=(
486
+ 0,
487
+ 1,
488
+ ),
489
+ )
490
+ inverse_loss, grads = inverse_grad_fn(
491
+ ex_state["icm_encoder"].params,
492
+ ex_state["icm_inverse"].params,
493
+ traj_batch,
494
+ )
495
+ icm_encoder_grad, icm_inverse_grad = grads
496
+ ex_state["icm_encoder"] = ex_state["icm_encoder"].apply_gradients(
497
+ grads=icm_encoder_grad
498
+ )
499
+ ex_state["icm_inverse"] = ex_state["icm_inverse"].apply_gradients(
500
+ grads=icm_inverse_grad
501
+ )
502
+
503
+ def _forward_loss_fn(icm_forward_params, traj_batch):
504
+ latent_obs = ex_state["icm_encoder"].apply_fn(
505
+ ex_state["icm_encoder"].params, traj_batch.obs
506
+ )
507
+ latent_next_obs = ex_state["icm_encoder"].apply_fn(
508
+ ex_state["icm_encoder"].params, traj_batch.next_obs
509
+ )
510
+
511
+ latent_next_obs_pred = ex_state["icm_forward"].apply_fn(
512
+ icm_forward_params, latent_obs, traj_batch.action
513
+ )
514
+
515
+ error = (latent_next_obs - latent_next_obs_pred) * (
516
+ 1 - traj_batch.done[:, None]
517
+ )
518
+ return (
519
+ jnp.square(error).mean() * config["ICM_FORWARD_LOSS_COEF"]
520
+ )
521
+
522
+ forward_grad_fn = jax.value_and_grad(
523
+ _forward_loss_fn, has_aux=False
524
+ )
525
+ forward_loss, icm_forward_grad = forward_grad_fn(
526
+ ex_state["icm_forward"].params, traj_batch
527
+ )
528
+ ex_state["icm_forward"] = ex_state["icm_forward"].apply_gradients(
529
+ grads=icm_forward_grad
530
+ )
531
+
532
+ losses = (inverse_loss, forward_loss)
533
+ return ex_state, losses
534
+
535
+ (ex_state, traj_batch, rng) = update_state
536
+ rng, _rng = jax.random.split(rng)
537
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
538
+ assert (
539
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
540
+ ), "batch size must be equal to number of steps * number of envs"
541
+ permutation = jax.random.permutation(_rng, batch_size)
542
+ batch = jax.tree.map(
543
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
544
+ )
545
+ shuffled_batch = jax.tree.map(
546
+ lambda x: jnp.take(x, permutation, axis=0), batch
547
+ )
548
+ minibatches = jax.tree.map(
549
+ lambda x: jnp.reshape(
550
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
551
+ ),
552
+ shuffled_batch,
553
+ )
554
+ ex_state, losses = jax.lax.scan(
555
+ _update_ex_minbatch, ex_state, minibatches
556
+ )
557
+ update_state = (ex_state, traj_batch, rng)
558
+ return update_state, losses
559
+
560
+ if config["TRAIN_ICM"]:
561
+ ex_update_state = (ex_state, traj_batch, rng)
562
+ ex_update_state, ex_loss = jax.lax.scan(
563
+ _update_ex_epoch,
564
+ ex_update_state,
565
+ None,
566
+ config["EXPLORATION_UPDATE_EPOCHS"],
567
+ )
568
+ metric["icm_inverse_loss"] = ex_loss[0].mean()
569
+ metric["icm_forward_loss"] = ex_loss[1].mean()
570
+ metric["reward_i"] = traj_batch.reward_i.mean()
571
+ metric["reward_e"] = traj_batch.reward_e.mean()
572
+
573
+ ex_state = ex_update_state[0]
574
+ rng = ex_update_state[-1]
575
+
576
+ # wandb logging
577
+ if config["DEBUG"] and config["USE_WANDB"]:
578
+
579
+ def callback(metric, update_step):
580
+ to_log = create_log_dict(metric, config)
581
+ batch_log(update_step, to_log, config)
582
+
583
+ jax.debug.callback(
584
+ callback,
585
+ metric,
586
+ update_step,
587
+ )
588
+
589
+ runner_state = (
590
+ train_state,
591
+ env_state,
592
+ last_obs,
593
+ ex_state,
594
+ rng,
595
+ update_step + 1,
596
+ )
597
+ return runner_state, metric
598
+
599
+ rng, _rng = jax.random.split(rng)
600
+ runner_state = (
601
+ train_state,
602
+ env_state,
603
+ obsv,
604
+ ex_state,
605
+ _rng,
606
+ 0,
607
+ )
608
+ runner_state, metric = jax.lax.scan(
609
+ _update_step, runner_state, None, config["NUM_UPDATES"]
610
+ )
611
+ return {"runner_state": runner_state} # , "info": metric}
612
+
613
+ return train
614
+
615
+
616
+ def run_ppo(config):
617
+ config = {k.upper(): v for k, v in config.__dict__.items()}
618
+
619
+ if config["USE_WANDB"]:
620
+ wandb.init(
621
+ project=config["WANDB_PROJECT"],
622
+ entity=config["WANDB_ENTITY"],
623
+ config=config,
624
+ name=config["ENV_NAME"]
625
+ + "-"
626
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
627
+ + "M",
628
+ )
629
+
630
+ rng = jax.random.PRNGKey(config["SEED"])
631
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
632
+
633
+ train_jit = jax.jit(make_train(config))
634
+ train_vmap = jax.vmap(train_jit)
635
+
636
+ t0 = time.time()
637
+ out = train_vmap(rngs)
638
+ t1 = time.time()
639
+ print("Time to run experiment", t1 - t0)
640
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
641
+
642
+ if config["USE_WANDB"]:
643
+
644
+ def _save_network(rs_index, dir_name):
645
+ train_states = out["runner_state"][rs_index]
646
+ train_state = jax.tree.map(lambda x: x[0], train_states)
647
+
648
+ path = os.path.join(wandb.run.dir, dir_name)
649
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
650
+
651
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
652
+ checkpoint_manager.save(
653
+ int(config["TOTAL_TIMESTEPS"]),
654
+ args=ocp.args.StandardSave(train_state)
655
+ )
656
+
657
+ print(f"saved runner state to {path}")
658
+
659
+ if config["SAVE_POLICY"]:
660
+ _save_network(0, "policies")
661
+
662
+
663
+ if __name__ == "__main__":
664
+ parser = argparse.ArgumentParser()
665
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
666
+ parser.add_argument(
667
+ "--num_envs",
668
+ type=int,
669
+ default=1024,
670
+ )
671
+ parser.add_argument(
672
+ "--total_timesteps", type=lambda x: int(float(x)), default=1e9
673
+ ) # Allow scientific notation
674
+ parser.add_argument("--lr", type=float, default=2e-4)
675
+ parser.add_argument("--num_steps", type=int, default=64)
676
+ parser.add_argument("--update_epochs", type=int, default=4)
677
+ parser.add_argument("--num_minibatches", type=int, default=8)
678
+ parser.add_argument("--gamma", type=float, default=0.99)
679
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
680
+ parser.add_argument("--clip_eps", type=float, default=0.2)
681
+ parser.add_argument("--ent_coef", type=float, default=0.01)
682
+ parser.add_argument("--vf_coef", type=float, default=0.5)
683
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
684
+ parser.add_argument("--activation", type=str, default="tanh")
685
+ parser.add_argument(
686
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
687
+ )
688
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
689
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
690
+ parser.add_argument("--seed", type=int)
691
+ parser.add_argument(
692
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
693
+ )
694
+ parser.add_argument("--save_policy", action="store_true")
695
+ parser.add_argument("--num_repeats", type=int, default=1)
696
+ parser.add_argument("--layer_size", type=int, default=512)
697
+ parser.add_argument("--wandb_project", type=str)
698
+ parser.add_argument("--wandb_entity", type=str)
699
+ parser.add_argument(
700
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
701
+ )
702
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
703
+
704
+ # EXPLORATION
705
+ parser.add_argument("--exploration_update_epochs", type=int, default=4)
706
+ # ICM
707
+ parser.add_argument("--icm_reward_coeff", type=float, default=1.0)
708
+ parser.add_argument("--train_icm", action="store_true")
709
+ parser.add_argument("--icm_lr", type=float, default=3e-4)
710
+ parser.add_argument("--icm_forward_loss_coef", type=float, default=1.0)
711
+ parser.add_argument("--icm_inverse_loss_coef", type=float, default=1.0)
712
+ parser.add_argument("--icm_layer_size", type=int, default=256)
713
+ parser.add_argument("--icm_latent_size", type=int, default=32)
714
+ # E3B
715
+ parser.add_argument("--e3b_reward_coeff", type=float, default=1.0)
716
+ parser.add_argument("--use_e3b", action="store_true")
717
+ parser.add_argument("--e3b_lambda", type=float, default=0.1)
718
+
719
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
720
+ if rest_args:
721
+ raise ValueError(f"Unknown args {rest_args}")
722
+
723
+ if args.use_e3b:
724
+ assert args.train_icm
725
+ assert args.icm_reward_coeff == 0
726
+ if args.seed is None:
727
+ args.seed = np.random.randint(2**31)
728
+
729
+ if args.jit:
730
+ run_ppo(args)
731
+ else:
732
+ with jax.disable_jit():
733
+ run_ppo(args)
Craftax_Baselines/ppo_rnd.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import optax
10
+ from craftax.craftax_env import make_craftax_env_from_name
11
+
12
+ import wandb
13
+ from typing import NamedTuple
14
+
15
+ from flax.training.train_state import TrainState
16
+ import orbax.checkpoint as ocp
17
+
18
+ from logz.batch_logging import batch_log, create_log_dict
19
+ from wrappers import (
20
+ LogWrapper,
21
+ OptimisticResetVecEnvWrapper,
22
+ AutoResetEnvWrapper,
23
+ BatchEnvWrapper,
24
+ )
25
+ from models.rnd import RNDNetwork, ActorCriticRND
26
+
27
+ # Code adapted from the original implementation made by Chris Lu
28
+ # Original code located at https://github.com/luchris429/purejaxrl
29
+
30
+
31
+ class Transition(NamedTuple):
32
+ done: jnp.ndarray
33
+ action: jnp.ndarray
34
+ value_e: jnp.ndarray
35
+ value_i: jnp.ndarray
36
+ reward_e: jnp.ndarray
37
+ reward_i: jnp.ndarray
38
+ reward: jnp.ndarray
39
+ log_prob: jnp.ndarray
40
+ obs: jnp.ndarray
41
+ next_obs: jnp.ndarray
42
+ info: jnp.ndarray
43
+
44
+
45
+ def make_train(config):
46
+ config["NUM_UPDATES"] = (
47
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
48
+ )
49
+ config["MINIBATCH_SIZE"] = (
50
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
51
+ )
52
+
53
+ env = make_craftax_env_from_name(
54
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
55
+ )
56
+ env_params = env.default_params
57
+
58
+ env = LogWrapper(env)
59
+ if config["USE_OPTIMISTIC_RESETS"]:
60
+ env = OptimisticResetVecEnvWrapper(
61
+ env,
62
+ num_envs=config["NUM_ENVS"],
63
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
64
+ )
65
+ else:
66
+ env = AutoResetEnvWrapper(env)
67
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
68
+
69
+ def linear_schedule(count):
70
+ frac = (
71
+ 1.0
72
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
73
+ / config["NUM_UPDATES"]
74
+ )
75
+ return config["LR"] * frac
76
+
77
+ def train(rng):
78
+ # INIT NETWORK
79
+ if "Symbolic" in config["ENV_NAME"]:
80
+ network = ActorCriticRND(
81
+ env.action_space(env_params).n, config["LAYER_SIZE"]
82
+ )
83
+ else:
84
+ raise ValueError
85
+ # network = ActorCriticConv(
86
+ # env.action_space(env_params).n, config["LAYER_SIZE"]
87
+ # )
88
+
89
+ rng, _rng = jax.random.split(rng)
90
+ init_x = jnp.zeros((1, *env.observation_space(env_params).shape))
91
+ network_params = network.init(_rng, init_x)
92
+ if config["ANNEAL_LR"]:
93
+ tx = optax.chain(
94
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
95
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
96
+ )
97
+ else:
98
+ tx = optax.chain(
99
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
100
+ optax.adam(config["LR"], eps=1e-5),
101
+ )
102
+ train_state = TrainState.create(
103
+ apply_fn=network.apply,
104
+ params=network_params,
105
+ tx=tx,
106
+ )
107
+
108
+ # Exploration state
109
+ ex_state = {
110
+ "rnd_model": None,
111
+ }
112
+
113
+ if config["USE_RND"]:
114
+ obs_shape = env.observation_space(env_params).shape
115
+ assert len(obs_shape) == 1, "Only configured for 1D observations"
116
+ obs_shape = obs_shape[0]
117
+
118
+ # Random network
119
+ rnd_random_network = RNDNetwork(
120
+ num_layers=3,
121
+ output_dim=config["RND_OUTPUT_SIZE"],
122
+ layer_size=config["RND_LAYER_SIZE"],
123
+ )
124
+ rng, _rng = jax.random.split(rng)
125
+ rnd_random_network_params = rnd_random_network.init(
126
+ _rng, jnp.zeros((1, obs_shape))
127
+ )
128
+
129
+ # Distillation Network
130
+ rnd_distillation_network = RNDNetwork(
131
+ num_layers=3,
132
+ output_dim=config["RND_OUTPUT_SIZE"],
133
+ layer_size=config["RND_LAYER_SIZE"],
134
+ )
135
+ rng, _rng = jax.random.split(rng)
136
+ rnd_distillation_network_params = rnd_distillation_network.init(
137
+ _rng, jnp.zeros((1, obs_shape))
138
+ )
139
+ tx = optax.chain(
140
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
141
+ optax.adam(config["RND_LR"], eps=1e-5),
142
+ )
143
+ ex_state["rnd_distillation_network"] = TrainState.create(
144
+ apply_fn=rnd_distillation_network.apply,
145
+ params=rnd_distillation_network_params,
146
+ tx=tx,
147
+ )
148
+
149
+ # INIT ENV
150
+ rng, _rng = jax.random.split(rng)
151
+ obsv, env_state = env.reset(_rng, env_params)
152
+
153
+ # TRAIN LOOP
154
+ def _update_step(runner_state, unused):
155
+ # COLLECT TRAJECTORIES
156
+ def _env_step(runner_state, unused):
157
+ (
158
+ train_state,
159
+ env_state,
160
+ last_obs,
161
+ ex_state,
162
+ rng,
163
+ update_step,
164
+ ) = runner_state
165
+
166
+ # SELECT ACTION
167
+ rng, _rng = jax.random.split(rng)
168
+ pi, value_e, value_i = network.apply(train_state.params, last_obs)
169
+ action = pi.sample(seed=_rng)
170
+ log_prob = pi.log_prob(action)
171
+
172
+ # STEP ENV
173
+ rng, _rng = jax.random.split(rng)
174
+ obsv, env_state, reward_e, done, info = env.step(
175
+ _rng, env_state, action, env_params
176
+ )
177
+
178
+ reward_i = jnp.zeros(config["NUM_ENVS"])
179
+
180
+ if config["USE_RND"]:
181
+ random_pred = rnd_random_network.apply(
182
+ rnd_random_network_params, obsv
183
+ )
184
+
185
+ distill_pred = ex_state["rnd_distillation_network"].apply_fn(
186
+ ex_state["rnd_distillation_network"].params, obsv
187
+ )
188
+ error = (random_pred - distill_pred) * (1 - done[:, None])
189
+ mse = jnp.square(error).mean(axis=-1)
190
+
191
+ reward_i = mse * config["RND_REWARD_COEFF"]
192
+
193
+ reward = reward_e + reward_i
194
+
195
+ transition = Transition(
196
+ done=done,
197
+ action=action,
198
+ value_e=value_e,
199
+ value_i=value_i,
200
+ reward=reward,
201
+ reward_i=reward_i,
202
+ reward_e=reward_e,
203
+ log_prob=log_prob,
204
+ obs=last_obs,
205
+ next_obs=obsv,
206
+ info=info,
207
+ )
208
+ runner_state = (
209
+ train_state,
210
+ env_state,
211
+ obsv,
212
+ ex_state,
213
+ rng,
214
+ update_step,
215
+ )
216
+ return runner_state, transition
217
+
218
+ runner_state, traj_batch = jax.lax.scan(
219
+ _env_step, runner_state, None, config["NUM_STEPS"]
220
+ )
221
+
222
+ # CALCULATE ADVANTAGE
223
+ (
224
+ train_state,
225
+ env_state,
226
+ last_obs,
227
+ ex_state,
228
+ rng,
229
+ update_step,
230
+ ) = runner_state
231
+ _, last_val_e, last_val_i = network.apply(train_state.params, last_obs)
232
+
233
+ def _calculate_gae(traj_batch, last_val, is_extrinsic):
234
+ def _get_advantages(gae_and_next_value, transition):
235
+ gae, next_value, is_extrinsic = gae_and_next_value
236
+ done, value, reward = (
237
+ transition.done,
238
+ jax.lax.select(
239
+ is_extrinsic, transition.value_e, transition.value_i
240
+ ),
241
+ jax.lax.select(
242
+ is_extrinsic, transition.reward_e, transition.reward_i
243
+ ),
244
+ )
245
+ done = jnp.logical_and(
246
+ done, jnp.logical_or(config["RND_IS_EPISODIC"], is_extrinsic)
247
+ )
248
+
249
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
250
+ gae = (
251
+ delta
252
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
253
+ )
254
+ return (gae, value, is_extrinsic), gae
255
+
256
+ _, advantages = jax.lax.scan(
257
+ _get_advantages,
258
+ (jnp.zeros_like(last_val), last_val, is_extrinsic),
259
+ traj_batch,
260
+ reverse=True,
261
+ unroll=16,
262
+ )
263
+ return advantages, advantages + jax.lax.select(
264
+ is_extrinsic, traj_batch.value_e, traj_batch.value_i
265
+ )
266
+
267
+ advantages_e, targets_e = _calculate_gae(traj_batch, last_val_e, True)
268
+ advantages_i, targets_i = _calculate_gae(traj_batch, last_val_i, False)
269
+
270
+ # UPDATE NETWORK
271
+ def _update_epoch(update_state, unused):
272
+ def _update_minbatch(train_state, batch_info):
273
+ (
274
+ traj_batch,
275
+ advantages_e,
276
+ targets_e,
277
+ advantages_i,
278
+ targets_i,
279
+ ) = batch_info
280
+
281
+ # Policy/value network
282
+ def _loss_fn(
283
+ params, traj_batch, gae_e, targets_e, gae_i, targets_i
284
+ ):
285
+ # RERUN NETWORK
286
+ pi, value_e, value_i = network.apply(params, traj_batch.obs)
287
+ log_prob = pi.log_prob(traj_batch.action)
288
+
289
+ # CALCULATE EXTRINSIC VALUE LOSS
290
+ value_pred_clipped_e = traj_batch.value_e + (
291
+ value_e - traj_batch.value_e
292
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
293
+ value_losses_e = jnp.square(value_e - targets_e)
294
+ value_losses_clipped_e = jnp.square(
295
+ value_pred_clipped_e - targets_e
296
+ )
297
+ value_loss_e = (
298
+ 0.5
299
+ * jnp.maximum(value_losses_e, value_losses_clipped_e).mean()
300
+ )
301
+
302
+ # CALCULATE INTRINSIC VALUE LOSS
303
+ value_pred_clipped_i = traj_batch.value_i + (
304
+ value_i - traj_batch.value_i
305
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
306
+ value_losses_i = jnp.square(value_i - targets_i)
307
+ value_losses_clipped_i = jnp.square(
308
+ value_pred_clipped_i - targets_i
309
+ )
310
+ value_loss_i = (
311
+ 0.5
312
+ * jnp.maximum(value_losses_i, value_losses_clipped_i).mean()
313
+ )
314
+
315
+ # CALCULATE ACTOR LOSS
316
+ gae = gae_e
317
+ if config["USE_RND"]:
318
+ gae += gae_i * config["RND_GAE_COEFF"]
319
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
320
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
321
+ loss_actor1 = ratio * gae
322
+ loss_actor2 = (
323
+ jnp.clip(
324
+ ratio,
325
+ 1.0 - config["CLIP_EPS"],
326
+ 1.0 + config["CLIP_EPS"],
327
+ )
328
+ * gae
329
+ )
330
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
331
+ loss_actor = loss_actor.mean()
332
+ entropy = pi.entropy().mean()
333
+
334
+ value_loss = value_loss_e
335
+ if config["USE_RND"]:
336
+ value_loss += value_loss_i
337
+
338
+ total_loss = (
339
+ loss_actor
340
+ + config["VF_COEF"] * value_loss
341
+ - config["ENT_COEF"] * entropy
342
+ )
343
+ return total_loss, (
344
+ value_loss_e,
345
+ value_loss_i,
346
+ loss_actor,
347
+ entropy,
348
+ )
349
+
350
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
351
+ total_loss, grads = grad_fn(
352
+ train_state.params,
353
+ traj_batch,
354
+ advantages_e,
355
+ targets_e,
356
+ advantages_i,
357
+ targets_i,
358
+ )
359
+ train_state = train_state.apply_gradients(grads=grads)
360
+
361
+ losses = (total_loss, 0)
362
+ return train_state, losses
363
+
364
+ (
365
+ train_state,
366
+ traj_batch,
367
+ advantages_e,
368
+ targets_e,
369
+ advantages_i,
370
+ targets_i,
371
+ rng,
372
+ ) = update_state
373
+ rng, _rng = jax.random.split(rng)
374
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
375
+ assert (
376
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
377
+ ), "batch size must be equal to number of steps * number of envs"
378
+ permutation = jax.random.permutation(_rng, batch_size)
379
+ batch = (
380
+ traj_batch,
381
+ advantages_e,
382
+ targets_e,
383
+ advantages_i,
384
+ targets_i,
385
+ )
386
+ batch = jax.tree.map(
387
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
388
+ )
389
+ shuffled_batch = jax.tree.map(
390
+ lambda x: jnp.take(x, permutation, axis=0), batch
391
+ )
392
+ minibatches = jax.tree.map(
393
+ lambda x: jnp.reshape(
394
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
395
+ ),
396
+ shuffled_batch,
397
+ )
398
+ train_state, losses = jax.lax.scan(
399
+ _update_minbatch, train_state, minibatches
400
+ )
401
+ update_state = (
402
+ train_state,
403
+ traj_batch,
404
+ advantages_e,
405
+ targets_e,
406
+ advantages_i,
407
+ targets_i,
408
+ rng,
409
+ )
410
+ return update_state, losses
411
+
412
+ update_state = (
413
+ train_state,
414
+ traj_batch,
415
+ advantages_e,
416
+ targets_e,
417
+ advantages_i,
418
+ targets_i,
419
+ rng,
420
+ )
421
+ update_state, loss_info = jax.lax.scan(
422
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
423
+ )
424
+
425
+ train_state = update_state[0]
426
+ metric = jax.tree.map(
427
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
428
+ / traj_batch.info["returned_episode"].sum(),
429
+ traj_batch.info,
430
+ )
431
+
432
+ rng = update_state[-1]
433
+
434
+ # UPDATE EXPLORATION STATE
435
+ def _update_ex_epoch(update_state, unused):
436
+ def _update_ex_minbatch(ex_state, traj_batch):
437
+ rnd_loss = 0
438
+
439
+ if config["USE_RND"]:
440
+
441
+ def _rnd_loss_fn(rnd_distillation_params, traj_batch):
442
+ random_network_out = rnd_random_network.apply(
443
+ rnd_random_network_params, traj_batch.next_obs
444
+ )
445
+
446
+ distillation_network_out = ex_state[
447
+ "rnd_distillation_network"
448
+ ].apply_fn(rnd_distillation_params, traj_batch.next_obs)
449
+
450
+ error = (random_network_out - distillation_network_out) * (
451
+ 1 - traj_batch.done[:, None]
452
+ )
453
+ return jnp.square(error).mean() * config["RND_LOSS_COEFF"]
454
+
455
+ rnd_grad_fn = jax.value_and_grad(_rnd_loss_fn, has_aux=False)
456
+ rnd_loss, rnd_grad = rnd_grad_fn(
457
+ ex_state["rnd_distillation_network"].params, traj_batch
458
+ )
459
+ ex_state["rnd_distillation_network"] = ex_state[
460
+ "rnd_distillation_network"
461
+ ].apply_gradients(grads=rnd_grad)
462
+
463
+ losses = (rnd_loss,)
464
+ return ex_state, losses
465
+
466
+ (ex_state, traj_batch, rng) = update_state
467
+ rng, _rng = jax.random.split(rng)
468
+ batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
469
+ assert (
470
+ batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
471
+ ), "batch size must be equal to number of steps * number of envs"
472
+ permutation = jax.random.permutation(_rng, batch_size)
473
+ batch = jax.tree.map(
474
+ lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
475
+ )
476
+ shuffled_batch = jax.tree.map(
477
+ lambda x: jnp.take(x, permutation, axis=0), batch
478
+ )
479
+ minibatches = jax.tree.map(
480
+ lambda x: jnp.reshape(
481
+ x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
482
+ ),
483
+ shuffled_batch,
484
+ )
485
+ ex_state, losses = jax.lax.scan(
486
+ _update_ex_minbatch, ex_state, minibatches
487
+ )
488
+ update_state = (ex_state, traj_batch, rng)
489
+ return update_state, losses
490
+
491
+ if config["USE_RND"]:
492
+ ex_update_state = (ex_state, traj_batch, rng)
493
+ ex_update_state, ex_loss = jax.lax.scan(
494
+ _update_ex_epoch,
495
+ ex_update_state,
496
+ None,
497
+ config["EXPLORATION_UPDATE_EPOCHS"],
498
+ )
499
+ metric["rnd_loss"] = ex_loss[0].mean()
500
+ metric["reward_i"] = traj_batch.reward_i.mean()
501
+ metric["reward_e"] = traj_batch.reward_e.mean()
502
+
503
+ ex_state = ex_update_state[0]
504
+ rng = ex_update_state[-1]
505
+
506
+ # wandb logging
507
+ if config["DEBUG"] and config["USE_WANDB"]:
508
+
509
+ def callback(
510
+ metric, update_step
511
+ ): # , loss_info, traj_batch, ex_state, advantages_i, targets_i):
512
+ to_log = create_log_dict(metric, config)
513
+ batch_log(update_step, to_log, config)
514
+
515
+ jax.debug.callback(
516
+ callback,
517
+ metric,
518
+ update_step,
519
+ # loss_info, traj_batch, ex_state, advantages_i, targets_i
520
+ )
521
+
522
+ runner_state = (
523
+ train_state,
524
+ env_state,
525
+ last_obs,
526
+ ex_state,
527
+ rng,
528
+ update_step + 1,
529
+ )
530
+ return runner_state, metric
531
+
532
+ rng, _rng = jax.random.split(rng)
533
+ runner_state = (
534
+ train_state,
535
+ env_state,
536
+ obsv,
537
+ ex_state,
538
+ _rng,
539
+ 0,
540
+ )
541
+ runner_state, metric = jax.lax.scan(
542
+ _update_step, runner_state, None, config["NUM_UPDATES"]
543
+ )
544
+ return {"runner_state": runner_state} # , "info": metric}
545
+
546
+ return train
547
+
548
+
549
+ def run_ppo(config):
550
+ config = {k.upper(): v for k, v in config.__dict__.items()}
551
+
552
+ if config["USE_WANDB"]:
553
+ wandb.init(
554
+ project=config["WANDB_PROJECT"],
555
+ entity=config["WANDB_ENTITY"],
556
+ config=config,
557
+ name=config["ENV_NAME"]
558
+ + "-PPO_RND-"
559
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
560
+ + "M",
561
+ )
562
+
563
+ rng = jax.random.PRNGKey(config["SEED"])
564
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
565
+
566
+ train_jit = jax.jit(make_train(config))
567
+ train_vmap = jax.vmap(train_jit)
568
+
569
+ t0 = time.time()
570
+ out = train_vmap(rngs)
571
+ t1 = time.time()
572
+ print("Time to run experiment", t1 - t0)
573
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
574
+ # t1 = time.time()
575
+ # out = train_vmap(rngs)
576
+ # t2 = time.time()
577
+ # print("t2", t2 - t1)
578
+ # print("SPS2: ", config["TOTAL_TIMESTEPS"] / (t2 - t1))
579
+
580
+ if config["USE_WANDB"]:
581
+ # if config["DEBUG"] == "end":
582
+ # info = out["info"]
583
+ # for update in range(info["timestep"].shape[1]):
584
+ # if update % 10 == 0:
585
+ # for repeat in range(info["timestep"].shape[0]):
586
+ # update_info = jax.tree.map(lambda x: x[repeat, update], info)
587
+ # to_log = create_log_dict(update_info)
588
+ # batch_log(update, to_log, config)
589
+ #
590
+ # t2 = time.time()
591
+ # print("Time to log to wandb", t2 - t1)
592
+
593
+ def _save_network(rs_index, dir_name):
594
+ train_states = out["runner_state"][rs_index]
595
+ train_state = jax.tree.map(lambda x: x[0], train_states)
596
+
597
+ path = os.path.join(wandb.run.dir, dir_name)
598
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
599
+
600
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
601
+ checkpoint_manager.save(
602
+ int(config["TOTAL_TIMESTEPS"]),
603
+ args=ocp.args.StandardSave(train_state)
604
+ )
605
+
606
+ print(f"saved runner state to {path}")
607
+
608
+ if config["SAVE_POLICY"]:
609
+ _save_network(0, "policies")
610
+
611
+
612
+ if __name__ == "__main__":
613
+ parser = argparse.ArgumentParser()
614
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
615
+ parser.add_argument(
616
+ "--num_envs",
617
+ type=int,
618
+ default=1024,
619
+ )
620
+ parser.add_argument(
621
+ "--total_timesteps", type=lambda x: int(float(x)), default=1e9
622
+ ) # Allow scientific notation
623
+ parser.add_argument("--lr", type=float, default=2e-4)
624
+ parser.add_argument("--num_steps", type=int, default=64)
625
+ parser.add_argument("--update_epochs", type=int, default=4)
626
+ parser.add_argument("--num_minibatches", type=int, default=8)
627
+ parser.add_argument("--gamma", type=float, default=0.99)
628
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
629
+ parser.add_argument("--clip_eps", type=float, default=0.2)
630
+ parser.add_argument("--ent_coef", type=float, default=0.01)
631
+ parser.add_argument("--vf_coef", type=float, default=0.5)
632
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
633
+ parser.add_argument("--activation", type=str, default="tanh")
634
+ parser.add_argument(
635
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
636
+ )
637
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
638
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
639
+ parser.add_argument("--seed", type=int)
640
+ parser.add_argument(
641
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
642
+ )
643
+ parser.add_argument("--save_policy", action="store_true")
644
+ parser.add_argument("--num_repeats", type=int, default=1)
645
+ parser.add_argument("--layer_size", type=int, default=512)
646
+ parser.add_argument("--wandb_project", type=str)
647
+ parser.add_argument("--wandb_entity", type=str)
648
+ parser.add_argument(
649
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
650
+ )
651
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
652
+
653
+ # EXPLORATION
654
+ parser.add_argument("--exploration_update_epochs", type=int, default=1)
655
+ # RND
656
+ parser.add_argument(
657
+ "--use_rnd", action=argparse.BooleanOptionalAction, default=True
658
+ )
659
+ parser.add_argument("--rnd_layer_size", type=int, default=256)
660
+ parser.add_argument("--rnd_output_size", type=int, default=512)
661
+ parser.add_argument("--rnd_lr", type=float, default=3e-4)
662
+ parser.add_argument("--rnd_reward_coeff", type=float, default=1.0)
663
+ parser.add_argument("--rnd_loss_coeff", type=float, default=0.01)
664
+ parser.add_argument("--rnd_gae_coeff", type=float, default=0.01)
665
+ parser.add_argument(
666
+ "--rnd_is_episodic", action=argparse.BooleanOptionalAction, default=False
667
+ )
668
+
669
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
670
+ if rest_args:
671
+ raise ValueError(f"Unknown args {rest_args}")
672
+
673
+ if args.seed is None:
674
+ args.seed = np.random.randint(2**31)
675
+
676
+ if args.jit:
677
+ run_ppo(args)
678
+ else:
679
+ with jax.disable_jit():
680
+ run_ppo(args)
Craftax_Baselines/ppo_rnn.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax.linen as nn
8
+ import numpy as np
9
+ import optax
10
+ import time
11
+
12
+ import orbax.checkpoint as ocp
13
+
14
+ import wandb
15
+ from flax.linen.initializers import constant, orthogonal
16
+ from typing import NamedTuple, Dict
17
+ from flax.training.train_state import TrainState
18
+ import distrax
19
+ import functools
20
+
21
+ from wrappers import (
22
+ LogWrapper,
23
+ OptimisticResetVecEnvWrapper,
24
+ BatchEnvWrapper,
25
+ AutoResetEnvWrapper,
26
+ )
27
+ from logz.batch_logging import create_log_dict, batch_log
28
+
29
+ from craftax.craftax_env import make_craftax_env_from_name
30
+
31
+ # Code adapted from the original implementation made by Chris Lu
32
+ # Original code located at https://github.com/luchris429/purejaxrl
33
+
34
+
35
+ class ScannedRNN(nn.Module):
36
+ @functools.partial(
37
+ nn.scan,
38
+ variable_broadcast="params",
39
+ in_axes=0,
40
+ out_axes=0,
41
+ split_rngs={"params": False},
42
+ )
43
+ @nn.compact
44
+ def __call__(self, carry, x):
45
+ """Applies the module."""
46
+ rnn_state = carry
47
+ ins, resets = x
48
+ rnn_state = jnp.where(
49
+ resets[:, np.newaxis],
50
+ self.initialize_carry(ins.shape[0], ins.shape[1]),
51
+ rnn_state,
52
+ )
53
+ new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
54
+ return new_rnn_state, y
55
+
56
+ @staticmethod
57
+ def initialize_carry(batch_size, hidden_size):
58
+ # Use a dummy key since the default state init fn is just zeros.
59
+ cell = nn.GRUCell(features=hidden_size)
60
+ return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
61
+
62
+
63
+ class ActorCriticRNN(nn.Module):
64
+ action_dim: int
65
+ config: Dict
66
+
67
+ @nn.compact
68
+ def __call__(self, hidden, x):
69
+ obs, dones = x
70
+ embedding = nn.Dense(
71
+ self.config["LAYER_SIZE"],
72
+ kernel_init=orthogonal(np.sqrt(2)),
73
+ bias_init=constant(0.0),
74
+ )(obs)
75
+ embedding = nn.relu(embedding)
76
+
77
+ rnn_in = (embedding, dones)
78
+ hidden, embedding = ScannedRNN()(hidden, rnn_in)
79
+
80
+ actor_mean = nn.Dense(
81
+ self.config["LAYER_SIZE"],
82
+ kernel_init=orthogonal(2),
83
+ bias_init=constant(0.0),
84
+ )(embedding)
85
+ actor_mean = nn.relu(actor_mean)
86
+ actor_mean = nn.Dense(
87
+ self.config["LAYER_SIZE"],
88
+ kernel_init=orthogonal(2),
89
+ bias_init=constant(0.0),
90
+ )(actor_mean)
91
+ actor_mean = nn.relu(actor_mean)
92
+ actor_mean = nn.Dense(
93
+ self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
94
+ )(actor_mean)
95
+
96
+ pi = distrax.Categorical(logits=actor_mean)
97
+
98
+ critic = nn.Dense(
99
+ self.config["LAYER_SIZE"],
100
+ kernel_init=orthogonal(2),
101
+ bias_init=constant(0.0),
102
+ )(embedding)
103
+ critic = nn.relu(critic)
104
+ critic = nn.Dense(
105
+ self.config["LAYER_SIZE"],
106
+ kernel_init=orthogonal(2),
107
+ bias_init=constant(0.0),
108
+ )(critic)
109
+ critic = nn.relu(critic)
110
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
111
+ critic
112
+ )
113
+
114
+ return hidden, pi, jnp.squeeze(critic, axis=-1)
115
+
116
+
117
+ class Transition(NamedTuple):
118
+ done: jnp.ndarray
119
+ action: jnp.ndarray
120
+ value: jnp.ndarray
121
+ reward: jnp.ndarray
122
+ log_prob: jnp.ndarray
123
+ obs: jnp.ndarray
124
+ info: jnp.ndarray
125
+
126
+
127
+ def make_train(config):
128
+ config["NUM_UPDATES"] = (
129
+ config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
130
+ )
131
+ config["MINIBATCH_SIZE"] = (
132
+ config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
133
+ )
134
+
135
+ # Create environment
136
+ env = make_craftax_env_from_name(
137
+ config["ENV_NAME"], not config["USE_OPTIMISTIC_RESETS"]
138
+ )
139
+ env_params = env.default_params
140
+
141
+ # Wrap with some extra logging
142
+ env = LogWrapper(env)
143
+
144
+ # Wrap with a batcher, maybe using optimistic resets
145
+ if config["USE_OPTIMISTIC_RESETS"]:
146
+ env = OptimisticResetVecEnvWrapper(
147
+ env,
148
+ num_envs=config["NUM_ENVS"],
149
+ reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
150
+ )
151
+ else:
152
+ env = AutoResetEnvWrapper(env)
153
+ env = BatchEnvWrapper(env, num_envs=config["NUM_ENVS"])
154
+
155
+ def linear_schedule(count):
156
+ frac = (
157
+ 1.0
158
+ - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
159
+ / config["NUM_UPDATES"]
160
+ )
161
+ return config["LR"] * frac
162
+
163
+ def train(rng):
164
+ # INIT NETWORK
165
+ network = ActorCriticRNN(env.action_space(env_params).n, config=config)
166
+ rng, _rng = jax.random.split(rng)
167
+ init_x = (
168
+ jnp.zeros(
169
+ (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
170
+ ),
171
+ jnp.zeros((1, config["NUM_ENVS"])),
172
+ )
173
+ init_hstate = ScannedRNN.initialize_carry(
174
+ config["NUM_ENVS"], config["LAYER_SIZE"]
175
+ )
176
+ network_params = network.init(_rng, init_hstate, init_x)
177
+ if config["ANNEAL_LR"]:
178
+ tx = optax.chain(
179
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
180
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
181
+ )
182
+ else:
183
+ tx = optax.chain(
184
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
185
+ optax.adam(config["LR"], eps=1e-5),
186
+ )
187
+ train_state = TrainState.create(
188
+ apply_fn=network.apply,
189
+ params=network_params,
190
+ tx=tx,
191
+ )
192
+
193
+ # INIT ENV
194
+ rng, _rng = jax.random.split(rng)
195
+ obsv, env_state = env.reset(_rng, env_params)
196
+ init_hstate = ScannedRNN.initialize_carry(
197
+ config["NUM_ENVS"], config["LAYER_SIZE"]
198
+ )
199
+
200
+ # TRAIN LOOP
201
+ def _update_step(runner_state, unused):
202
+ # COLLECT TRAJECTORIES
203
+ def _env_step(runner_state, unused):
204
+ (
205
+ train_state,
206
+ env_state,
207
+ last_obs,
208
+ last_done,
209
+ hstate,
210
+ rng,
211
+ update_step,
212
+ ) = runner_state
213
+ rng, _rng = jax.random.split(rng)
214
+
215
+ # SELECT ACTION
216
+ ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
217
+ hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
218
+ action = pi.sample(seed=_rng)
219
+ log_prob = pi.log_prob(action)
220
+ value, action, log_prob = (
221
+ value.squeeze(0),
222
+ action.squeeze(0),
223
+ log_prob.squeeze(0),
224
+ )
225
+
226
+ # STEP ENV
227
+ rng, _rng = jax.random.split(rng)
228
+ obsv, env_state, reward, done, info = env.step(
229
+ _rng, env_state, action, env_params
230
+ )
231
+ transition = Transition(
232
+ last_done, action, value, reward, log_prob, last_obs, info
233
+ )
234
+ runner_state = (
235
+ train_state,
236
+ env_state,
237
+ obsv,
238
+ done,
239
+ hstate,
240
+ rng,
241
+ update_step,
242
+ )
243
+ return runner_state, transition
244
+
245
+ initial_hstate = runner_state[-3]
246
+ runner_state, traj_batch = jax.lax.scan(
247
+ _env_step, runner_state, None, config["NUM_STEPS"]
248
+ )
249
+
250
+ # CALCULATE ADVANTAGE
251
+ (
252
+ train_state,
253
+ env_state,
254
+ last_obs,
255
+ last_done,
256
+ hstate,
257
+ rng,
258
+ update_step,
259
+ ) = runner_state
260
+ ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
261
+ _, _, last_val = network.apply(train_state.params, hstate, ac_in)
262
+ last_val = last_val.squeeze(0)
263
+
264
+ def _calculate_gae(traj_batch, last_val, last_done):
265
+ def _get_advantages(carry, transition):
266
+ gae, next_value, next_done = carry
267
+ done, value, reward = (
268
+ transition.done,
269
+ transition.value,
270
+ transition.reward,
271
+ )
272
+ delta = (
273
+ reward + config["GAMMA"] * next_value * (1 - next_done) - value
274
+ )
275
+ gae = (
276
+ delta
277
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
278
+ )
279
+ return (gae, value, done), gae
280
+
281
+ _, advantages = jax.lax.scan(
282
+ _get_advantages,
283
+ (jnp.zeros_like(last_val), last_val, last_done),
284
+ traj_batch,
285
+ reverse=True,
286
+ unroll=16,
287
+ )
288
+ return advantages, advantages + traj_batch.value
289
+
290
+ advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
291
+
292
+ # UPDATE NETWORK
293
+ def _update_epoch(update_state, unused):
294
+ def _update_minbatch(train_state, batch_info):
295
+ init_hstate, traj_batch, advantages, targets = batch_info
296
+
297
+ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
298
+ # RERUN NETWORK
299
+ _, pi, value = network.apply(
300
+ params, init_hstate[0], (traj_batch.obs, traj_batch.done)
301
+ )
302
+ log_prob = pi.log_prob(traj_batch.action)
303
+
304
+ # CALCULATE VALUE LOSS
305
+ value_pred_clipped = traj_batch.value + (
306
+ value - traj_batch.value
307
+ ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
308
+ value_losses = jnp.square(value - targets)
309
+ value_losses_clipped = jnp.square(value_pred_clipped - targets)
310
+ value_loss = (
311
+ 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
312
+ )
313
+
314
+ # CALCULATE ACTOR LOSS
315
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
316
+ gae = (gae - gae.mean()) / (gae.std() + 1e-8)
317
+ loss_actor1 = ratio * gae
318
+ loss_actor2 = (
319
+ jnp.clip(
320
+ ratio,
321
+ 1.0 - config["CLIP_EPS"],
322
+ 1.0 + config["CLIP_EPS"],
323
+ )
324
+ * gae
325
+ )
326
+ loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
327
+ loss_actor = loss_actor.mean()
328
+ entropy = pi.entropy().mean()
329
+
330
+ total_loss = (
331
+ loss_actor
332
+ + config["VF_COEF"] * value_loss
333
+ - config["ENT_COEF"] * entropy
334
+ )
335
+ return total_loss, (value_loss, loss_actor, entropy)
336
+
337
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
338
+ total_loss, grads = grad_fn(
339
+ train_state.params, init_hstate, traj_batch, advantages, targets
340
+ )
341
+ train_state = train_state.apply_gradients(grads=grads)
342
+ return train_state, total_loss
343
+
344
+ (
345
+ train_state,
346
+ init_hstate,
347
+ traj_batch,
348
+ advantages,
349
+ targets,
350
+ rng,
351
+ ) = update_state
352
+
353
+ rng, _rng = jax.random.split(rng)
354
+ permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
355
+ batch = (init_hstate, traj_batch, advantages, targets)
356
+
357
+ shuffled_batch = jax.tree.map(
358
+ lambda x: jnp.take(x, permutation, axis=1), batch
359
+ )
360
+
361
+ minibatches = jax.tree.map(
362
+ lambda x: jnp.swapaxes(
363
+ jnp.reshape(
364
+ x,
365
+ [x.shape[0], config["NUM_MINIBATCHES"], -1]
366
+ + list(x.shape[2:]),
367
+ ),
368
+ 1,
369
+ 0,
370
+ ),
371
+ shuffled_batch,
372
+ )
373
+
374
+ train_state, total_loss = jax.lax.scan(
375
+ _update_minbatch, train_state, minibatches
376
+ )
377
+ update_state = (
378
+ train_state,
379
+ init_hstate,
380
+ traj_batch,
381
+ advantages,
382
+ targets,
383
+ rng,
384
+ )
385
+ return update_state, total_loss
386
+
387
+ init_hstate = initial_hstate[None, :] # TBH
388
+ update_state = (
389
+ train_state,
390
+ init_hstate,
391
+ traj_batch,
392
+ advantages,
393
+ targets,
394
+ rng,
395
+ )
396
+ update_state, loss_info = jax.lax.scan(
397
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
398
+ )
399
+ train_state = update_state[0]
400
+ metric = jax.tree.map(
401
+ lambda x: (x * traj_batch.info["returned_episode"]).sum()
402
+ / traj_batch.info["returned_episode"].sum(),
403
+ traj_batch.info,
404
+ )
405
+ rng = update_state[-1]
406
+ if config["DEBUG"] and config["USE_WANDB"]:
407
+
408
+ def callback(metric, update_step):
409
+ to_log = create_log_dict(metric, config)
410
+ batch_log(update_step, to_log, config)
411
+
412
+ jax.debug.callback(callback, metric, update_step)
413
+
414
+ runner_state = (
415
+ train_state,
416
+ env_state,
417
+ last_obs,
418
+ last_done,
419
+ hstate,
420
+ rng,
421
+ update_step + 1,
422
+ )
423
+ return runner_state, metric
424
+
425
+ rng, _rng = jax.random.split(rng)
426
+ runner_state = (
427
+ train_state,
428
+ env_state,
429
+ obsv,
430
+ jnp.zeros((config["NUM_ENVS"]), dtype=bool),
431
+ init_hstate,
432
+ _rng,
433
+ 0,
434
+ )
435
+ runner_state, metric = jax.lax.scan(
436
+ _update_step, runner_state, None, config["NUM_UPDATES"]
437
+ )
438
+ return {"runner_state": runner_state, "metric": metric}
439
+
440
+ return train
441
+
442
+
443
+ def run_ppo(config):
444
+ config = {k.upper(): v for k, v in config.__dict__.items()}
445
+
446
+ if config["USE_WANDB"]:
447
+ wandb.init(
448
+ project=config["WANDB_PROJECT"],
449
+ entity=config["WANDB_ENTITY"],
450
+ config=config,
451
+ name=config["ENV_NAME"]
452
+ + "-PPO_RNN-"
453
+ + str(int(config["TOTAL_TIMESTEPS"] // 1e6))
454
+ + "M",
455
+ )
456
+
457
+ rng = jax.random.PRNGKey(config["SEED"])
458
+ rngs = jax.random.split(rng, config["NUM_REPEATS"])
459
+
460
+ train_jit = jax.jit(make_train(config))
461
+ train_vmap = jax.vmap(train_jit)
462
+
463
+ t0 = time.time()
464
+ out = train_vmap(rngs)
465
+ t1 = time.time()
466
+ print("Time to run experiment", t1 - t0)
467
+ print("SPS: ", config["TOTAL_TIMESTEPS"] / (t1 - t0))
468
+
469
+ if config["USE_WANDB"]:
470
+
471
+ def _save_network(rs_index, dir_name):
472
+ train_states = out["runner_state"][rs_index]
473
+ train_state = jax.tree.map(lambda x: x[0], train_states)
474
+
475
+ path = os.path.join(wandb.run.dir, dir_name)
476
+ options = ocp.CheckpointManagerOptions(max_to_keep=1)
477
+
478
+ with ocp.CheckpointManager(path, options=options) as checkpoint_manager:
479
+ checkpoint_manager.save(
480
+ int(config["TOTAL_TIMESTEPS"]),
481
+ args=ocp.args.StandardSave(train_state)
482
+ )
483
+
484
+ print(f"saved runner state to {path}")
485
+
486
+ if config["SAVE_POLICY"]:
487
+ _save_network(0, "policies")
488
+
489
+
490
+ if __name__ == "__main__":
491
+ parser = argparse.ArgumentParser()
492
+ parser.add_argument("--env_name", type=str, default="Craftax-Symbolic-v1")
493
+ parser.add_argument(
494
+ "--num_envs",
495
+ type=int,
496
+ default=1024,
497
+ )
498
+ parser.add_argument("--total_timesteps", type=lambda x: int(float(x)), default=1e9)
499
+ parser.add_argument("--lr", type=float, default=2e-4)
500
+ parser.add_argument("--num_steps", type=int, default=64)
501
+ parser.add_argument("--update_epochs", type=int, default=4)
502
+ parser.add_argument("--num_minibatches", type=int, default=8)
503
+ parser.add_argument("--gamma", type=float, default=0.99)
504
+ parser.add_argument("--gae_lambda", type=float, default=0.8)
505
+ parser.add_argument("--clip_eps", type=float, default=0.2)
506
+ parser.add_argument("--ent_coef", type=float, default=0.01)
507
+ parser.add_argument("--vf_coef", type=float, default=0.5)
508
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
509
+ parser.add_argument("--activation", type=str, default="tanh")
510
+ parser.add_argument(
511
+ "--anneal_lr", action=argparse.BooleanOptionalAction, default=True
512
+ )
513
+ parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=True)
514
+ parser.add_argument("--jit", action=argparse.BooleanOptionalAction, default=True)
515
+ parser.add_argument("--seed", type=int, default=np.random.randint(2**31))
516
+ parser.add_argument(
517
+ "--use_wandb", action=argparse.BooleanOptionalAction, default=True
518
+ )
519
+ parser.add_argument(
520
+ "--save_policy", action=argparse.BooleanOptionalAction, default=False
521
+ )
522
+ parser.add_argument("--num_repeats", type=int, default=1)
523
+ parser.add_argument("--layer_size", type=int, default=512)
524
+ parser.add_argument("--wandb_project", type=str)
525
+ parser.add_argument("--wandb_entity", type=str)
526
+ parser.add_argument(
527
+ "--use_optimistic_resets", action=argparse.BooleanOptionalAction, default=True
528
+ )
529
+ parser.add_argument("--optimistic_reset_ratio", type=int, default=16)
530
+
531
+ args, rest_args = parser.parse_known_args(sys.argv[1:])
532
+ if rest_args:
533
+ raise ValueError(f"Unknown args {rest_args}")
534
+
535
+ if args.seed is None:
536
+ args.seed = np.random.randint(2**31)
537
+
538
+ if args.jit:
539
+ run_ppo(args)
540
+ else:
541
+ with jax.disable_jit():
542
+ run_ppo(args)
Craftax_Baselines/requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ jax[cuda12_pip]
2
+ distrax
3
+ optax
4
+ flax
5
+ numpy
6
+ black
7
+ pre-commit
8
+ argparse
9
+ wandb
10
+ orbax-checkpoint==0.5.0
11
+ pygame
12
+ gymnax
13
+ chex
14
+ matplotlib
15
+ imageio
16
+ craftax
Craftax_Baselines/run_docker.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ WANDB_API_KEY=$(cat ./wandb_key)
3
+ # git pull
4
+
5
+ script_and_args="${@:2}"
6
+ if [ $1 == "all" ]; then
7
+ gpus="0 1 2 3 4 5 6 7"
8
+ else
9
+ gpus=$1
10
+ fi
11
+
12
+ for gpu in $gpus; do
13
+ echo "Launching container craftax_$gpu on GPU $gpu"
14
+ docker run \
15
+ --gpus device=$gpu \
16
+ -e WANDB_API_KEY=$WANDB_API_KEY \
17
+ -v $(pwd):/home/duser/Craftax \
18
+ --name craftax_$gpu \
19
+ --user $(id -u) \
20
+ --rm \
21
+ -d \
22
+ -t craftax_baselines \
23
+ /bin/bash -c "$script_and_args"
24
+ done
Craftax_Baselines/wrappers.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import chex
4
+ import numpy as np
5
+ from flax import struct
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union, Any
8
+
9
+
10
+ class GymnaxWrapper(object):
11
+ """Base class for Gymnax wrappers."""
12
+
13
+ def __init__(self, env):
14
+ self._env = env
15
+
16
+ # provide proxy access to regular attributes of wrapped object
17
+ def __getattr__(self, name):
18
+ return getattr(self._env, name)
19
+
20
+
21
+ class BatchEnvWrapper(GymnaxWrapper):
22
+ """Batches reset and step functions"""
23
+
24
+ def __init__(self, env, num_envs: int):
25
+ super().__init__(env)
26
+
27
+ self.num_envs = num_envs
28
+
29
+ self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
30
+ self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
31
+
32
+ @partial(jax.jit, static_argnums=(0, 2))
33
+ def reset(self, rng, params=None):
34
+ rng, _rng = jax.random.split(rng)
35
+ rngs = jax.random.split(_rng, self.num_envs)
36
+ obs, env_state = self.reset_fn(rngs, params)
37
+ return obs, env_state
38
+
39
+ @partial(jax.jit, static_argnums=(0, 4))
40
+ def step(self, rng, state, action, params=None):
41
+ rng, _rng = jax.random.split(rng)
42
+ rngs = jax.random.split(_rng, self.num_envs)
43
+ obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
44
+
45
+ return obs, state, reward, done, info
46
+
47
+
48
+ class AutoResetEnvWrapper(GymnaxWrapper):
49
+ """Provides standard auto-reset functionality, providing the same behaviour as Gymnax-default."""
50
+
51
+ def __init__(self, env):
52
+ super().__init__(env)
53
+
54
+ @partial(jax.jit, static_argnums=(0, 2))
55
+ def reset(self, key, params=None):
56
+ return self._env.reset(key, params)
57
+
58
+ @partial(jax.jit, static_argnums=(0, 4))
59
+ def step(self, rng, state, action, params=None):
60
+
61
+ rng, _rng = jax.random.split(rng)
62
+ obs_st, state_st, reward, done, info = self._env.step(
63
+ _rng, state, action, params
64
+ )
65
+
66
+ rng, _rng = jax.random.split(rng)
67
+ obs_re, state_re = self._env.reset(_rng, params)
68
+
69
+ # Auto-reset environment based on termination
70
+ def auto_reset(done, state_re, state_st, obs_re, obs_st):
71
+ state = jax.tree.map(
72
+ lambda x, y: jax.lax.select(done, x, y), state_re, state_st
73
+ )
74
+ obs = jax.lax.select(done, obs_re, obs_st)
75
+
76
+ return obs, state
77
+
78
+ obs, state = auto_reset(done, state_re, state_st, obs_re, obs_st)
79
+
80
+ return obs, state, reward, done, info
81
+
82
+
83
+ class OptimisticResetVecEnvWrapper(GymnaxWrapper):
84
+ """
85
+ Provides efficient 'optimistic' resets.
86
+ The wrapper also necessarily handles the batching of environment steps and resetting.
87
+ reset_ratio: the number of environment workers per environment reset. Higher means more efficient but a higher
88
+ chance of duplicate resets.
89
+ """
90
+
91
+ def __init__(self, env, num_envs: int, reset_ratio: int):
92
+ super().__init__(env)
93
+
94
+ self.num_envs = num_envs
95
+ self.reset_ratio = reset_ratio
96
+ assert (
97
+ num_envs % reset_ratio == 0
98
+ ), "Reset ratio must perfectly divide num envs."
99
+ self.num_resets = self.num_envs // reset_ratio
100
+
101
+ self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
102
+ self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
103
+
104
+ @partial(jax.jit, static_argnums=(0, 2))
105
+ def reset(self, rng, params=None):
106
+ rng, _rng = jax.random.split(rng)
107
+ rngs = jax.random.split(_rng, self.num_envs)
108
+ obs, env_state = self.reset_fn(rngs, params)
109
+ return obs, env_state
110
+
111
+ @partial(jax.jit, static_argnums=(0, 4))
112
+ def step(self, rng, state, action, params=None):
113
+
114
+ rng, _rng = jax.random.split(rng)
115
+ rngs = jax.random.split(_rng, self.num_envs)
116
+ obs_st, state_st, reward, done, info = self.step_fn(rngs, state, action, params)
117
+
118
+ rng, _rng = jax.random.split(rng)
119
+ rngs = jax.random.split(_rng, self.num_resets)
120
+ obs_re, state_re = self.reset_fn(rngs, params)
121
+
122
+ rng, _rng = jax.random.split(rng)
123
+ reset_indexes = jnp.arange(self.num_resets).repeat(self.reset_ratio)
124
+
125
+ being_reset = jax.random.choice(
126
+ _rng,
127
+ jnp.arange(self.num_envs),
128
+ shape=(self.num_resets,),
129
+ p=done,
130
+ replace=False,
131
+ )
132
+ reset_indexes = reset_indexes.at[being_reset].set(jnp.arange(self.num_resets))
133
+
134
+ obs_re = obs_re[reset_indexes]
135
+ state_re = jax.tree.map(lambda x: x[reset_indexes], state_re)
136
+
137
+ # Auto-reset environment based on termination
138
+ def auto_reset(done, state_re, state_st, obs_re, obs_st):
139
+ state = jax.tree.map(
140
+ lambda x, y: jax.lax.select(done, x, y), state_re, state_st
141
+ )
142
+ obs = jax.lax.select(done, obs_re, obs_st)
143
+
144
+ return state, obs
145
+
146
+ state, obs = jax.vmap(auto_reset)(done, state_re, state_st, obs_re, obs_st)
147
+
148
+ return obs, state, reward, done, info
149
+
150
+
151
+ @struct.dataclass
152
+ class LogEnvState:
153
+ env_state: Any
154
+ episode_returns: float
155
+ episode_lengths: int
156
+ returned_episode_returns: float
157
+ returned_episode_lengths: int
158
+ timestep: int
159
+
160
+
161
+ class LogWrapper(GymnaxWrapper):
162
+ """Log the episode returns and lengths."""
163
+
164
+ def __init__(self, env):
165
+ super().__init__(env)
166
+
167
+ @partial(jax.jit, static_argnums=(0, 2))
168
+ def reset(self, key: chex.PRNGKey, params=None):
169
+ obs, env_state = self._env.reset(key, params)
170
+ state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
171
+ return obs, state
172
+
173
+ @partial(jax.jit, static_argnums=(0, 4))
174
+ def step(
175
+ self,
176
+ key: chex.PRNGKey,
177
+ state,
178
+ action: Union[int, float],
179
+ params=None,
180
+ ):
181
+ obs, env_state, reward, done, info = self._env.step(
182
+ key, state.env_state, action, params
183
+ )
184
+ new_episode_return = state.episode_returns + reward
185
+ new_episode_length = state.episode_lengths + 1
186
+ state = LogEnvState(
187
+ env_state=env_state,
188
+ episode_returns=new_episode_return * (1 - done),
189
+ episode_lengths=new_episode_length * (1 - done),
190
+ returned_episode_returns=state.returned_episode_returns * (1 - done)
191
+ + new_episode_return * done,
192
+ returned_episode_lengths=state.returned_episode_lengths * (1 - done)
193
+ + new_episode_length * done,
194
+ timestep=state.timestep + 1,
195
+ )
196
+ info["returned_episode_returns"] = state.returned_episode_returns
197
+ info["returned_episode_lengths"] = state.returned_episode_lengths
198
+ info["timestep"] = state.timestep
199
+ info["returned_episode"] = done
200
+ return obs, state, reward, done, info
README.md ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReMDM Planner — Discrete Diffusion Planning on Craftax
2
+
3
+ A JAX implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in the [Craftax](https://github.com/MichaelTMatthews/Craftax) environment. A bidirectional transformer learns to generate action plans by iteratively denoising masked token sequences, conditioned on the current environment observation.
4
+
5
+ ---
6
+
7
+ ## Description
8
+
9
+ The planner starts from a fully-masked action sequence and iteratively unmasks tokens over `T` denoising steps, producing a `plan_horizon`-length plan. The ReMDM framework extends standard Masked Discrete Language Modelling (MDLM) with remasking strategies that allow committed tokens to be re-predicted, improving plan coherence.
10
+
11
+ Two independent training pipelines are available — **Offline BC** and **Online DAgger** — both supervised by a pre-trained PPO expert but otherwise separate. Neither depends on the other; the paper compares them head-to-head.
12
+
13
+ ```
14
+ [Shared] Train PPO agent Craftax_Baselines/ppo_rnn.py | ppo_rnd.py
15
+ |
16
+ v checkpoint
17
+ ┌───────┴────────┐
18
+ │ │
19
+ [Offline BC] [Online DAgger]
20
+ main.py main.py
21
+ --mode offline --mode online
22
+ (train on live (train from scratch;
23
+ PPO rollouts) mixed policy + expert
24
+ │ labels into replay buffer)
25
+ v v
26
+ checkpoint checkpoint
27
+ │ │
28
+ └───────┬────────┘
29
+ v
30
+ [Evaluate] main.py --mode inference --checkpoint_path ...
31
+
32
+ Optional: an offline BC checkpoint can warm-start DAgger
33
+ via --offline_checkpoint_path (not used in the paper).
34
+
35
+ [Offline BC] ──checkpoint──> [Online DAgger]
36
+ ```
37
+
38
+ **Optional utility modes:**
39
+ ```
40
+ [Collect] Save PPO rollouts to disk main.py --mode collect
41
+ [Smoke test] Quick end-to-end check main.py --mode smoke
42
+ ```
43
+
44
+ ---
45
+
46
+ ## Installation
47
+
48
+ ### Prerequisites (system-level)
49
+
50
+ `uv` manages Python packages only. The following must be installed at the OS level before
51
+ running on a GPU node — they are **not** in `pyproject.toml`:
52
+
53
+ - **CUDA 13** driver and toolkit (`libcuda.so`, `libcudnn`)
54
+
55
+ On HPC clusters these are typically loaded via `module load cuda/13.x`.
56
+
57
+ ### 1. Create the virtual environment
58
+
59
+ ```bash
60
+ # CPU-only (local development / macOS)
61
+ uv sync
62
+
63
+ # NVIDIA CUDA 13 (GPU node — Linux only)
64
+ uv sync --extra cuda
65
+
66
+ # Activate
67
+ source .venv/bin/activate
68
+ ```
69
+
70
+ `uv sync` reads `pyproject.toml`, resolves a fully-reproducible lockfile (`uv.lock`),
71
+ and installs into `.venv/`. Commit `uv.lock` to pin the exact dependency graph.
72
+
73
+ ### 2. Initialise the submodule
74
+
75
+ ```bash
76
+ git submodule update --init --recursive
77
+ ```
78
+
79
+ ---
80
+
81
+ ## Dependencies
82
+
83
+ | Package | Version | Role |
84
+ |---------|---------|------|
85
+ | `jax` | >=0.9.2 | JIT compilation and functional arrays |
86
+ | `flax` | >=0.12.6 | Neural network definitions |
87
+ | `optax` | >=0.2.8 | Adam optimiser and gradient clipping |
88
+ | `craftax` | >=1.5.0 | Procedurally-generated Minecraft-like environment |
89
+ | `chex` | >=0.1.91 | JAX testing and assertion utilities |
90
+ | `distrax` | >=0.1.7 | Probability distributions |
91
+ | `orbax` | >=0.1.9 | Model checkpointing |
92
+ | `wandb` | >=0.25.1 | Experiment logging |
93
+ | `numpy` | >=2.4.4 | Array operations |
94
+ | `matplotlib` | >=3.10.8 | Plotting |
95
+ | `polars` | >=1.39.3 | DataFrame analysis |
96
+ | `orjson` | >=3.11.8 | Fast JSON serialisation |
97
+ | `pyyaml` | >=6.0.3 | Config file parsing |
98
+
99
+ Full specification in `pyproject.toml`. Exact transitive pins are in `uv.lock`.
100
+
101
+ ---
102
+
103
+ ## Usage
104
+
105
+ All modes share the same entry point. Defaults are loaded from `configs/defaults.yaml`; any value can be overridden on the command line.
106
+
107
+ ```bash
108
+ python main.py --mode <MODE> [--config PATH] [OVERRIDES...]
109
+ ```
110
+
111
+ Pass `--no-jit` to disable JIT compilation (useful for debugging):
112
+
113
+ ```bash
114
+ python main.py --mode offline --no-jit --num_envs 4
115
+ ```
116
+
117
+ ### Stage 1 — Train a PPO agent
118
+
119
+ PPO training is handled by the `Craftax_Baselines` submodule and produces the checkpoint consumed by all downstream stages.
120
+
121
+ ```bash
122
+ cd Craftax_Baselines
123
+
124
+ # PPO with GRU hidden state (recommended)
125
+ python ppo_rnn.py \
126
+ --env_name Craftax-Classic-Symbolic-v1 \
127
+ --total_timesteps 500000000 \
128
+ --save_policy --use_wandb
129
+
130
+ # PPO with Random Network Distillation
131
+ python ppo_rnd.py \
132
+ --env_name Craftax-Classic-Symbolic-v1 \
133
+ --total_timesteps 500000000 \
134
+ --save_policy --use_wandb
135
+
136
+ cd ..
137
+ ```
138
+
139
+ ### Stage 2a — Collect trajectories to disk
140
+
141
+ Roll out the PPO checkpoint and save `(obs, actions, rewards, dones)` as a `.npz` file for reuse across multiple diffusion training runs.
142
+
143
+ ```bash
144
+ python main.py --mode collect \
145
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
146
+ --offline_data_path data/trajectories.npz \
147
+ --collect_num_steps 1000000 \
148
+ --collect_num_envs 128
149
+ ```
150
+
151
+ The file stores arrays shaped `[num_envs, num_iters, ...]`, preserving per-environment contiguity so episode boundaries are respected during window sampling.
152
+
153
+ ### Stage 2b — Train offline from live PPO rollouts
154
+
155
+ Roll out the PPO agent live at each update step and train the diffusion model on the collected windows. Windows that cross episode boundaries are masked out; windows with higher cumulative reward receive proportionally larger gradient contributions (clipped to `[0.1, return_weight_cap]`).
156
+
157
+ ```bash
158
+ python main.py --mode offline \
159
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
160
+ --offline_total_timesteps 100000000 \
161
+ --save_policy
162
+ ```
163
+
164
+ ### Online DAgger Training
165
+
166
+ The diffusion model is trained **from scratch** via DAgger (Dataset Aggregation). At each iteration a mixed policy blends the PPO expert and the diffusion learner (controlled by an exponentially decaying `beta`). The mixed policy rolls out trajectories; the expert labels every visited state with the action it would take. These `(obs, expert_plan)` pairs are appended to a growing circular replay buffer, and the diffusion model is trained on the full buffer with the standard MDLM ELBO loss (pure behavioural cloning — no reward weighting).
167
+
168
+ ```bash
169
+ # From scratch (requires PPO expert checkpoint)
170
+ python main.py --mode online \
171
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
172
+ --online_num_updates 1000 \
173
+ --save_policy
174
+
175
+ # Optional: warm-start from a pre-trained offline checkpoint
176
+ # (not used in the paper — both methods are compared independently)
177
+ python main.py --mode online \
178
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
179
+ --offline_checkpoint_path /path/to/offline_checkpoint \
180
+ --online_num_updates 1000 \
181
+ --save_policy
182
+ ```
183
+
184
+ When `save_policy=true`, online training uploads **two** W&B artifacts: `{env_name}-policy` (final weights) and `{env_name}-policy-best` (weights from the validation iteration with the highest return). Either artifact can be consumed downstream by `--checkpoint_path wandb:…`.
185
+
186
+ ### Stage 4 — Evaluate
187
+
188
+ ```bash
189
+ python main.py --mode inference \
190
+ --checkpoint_path /path/to/checkpoint \
191
+ --eval_steps 10000 \
192
+ --eval_num_envs 32
193
+ ```
194
+
195
+ Prints mean episode return, completed episodes, steps per second, and per-achievement unlock counts. Uses historical inpainting: the first `hist_len` plan positions are locked to observed history.
196
+
197
+ ### Loading checkpoints from W&B artifacts
198
+
199
+ Any checkpoint path argument (`--checkpoint_path`, `--offline_checkpoint_path`, `--ppo_checkpoint_path`) accepts a W&B artifact reference prefixed with `wandb:`. The artifact is downloaded automatically before training or evaluation begins.
200
+
201
+ ```bash
202
+ # Fully qualified: entity/project/artifact_name:version_or_alias
203
+ python main.py --mode inference \
204
+ --checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:latest
205
+
206
+ # Online fine-tuning from a W&B offline checkpoint
207
+ python main.py --mode online \
208
+ --offline_checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:v3
209
+
210
+ # PPO checkpoint from W&B
211
+ python main.py --mode offline \
212
+ --ppo_checkpoint_path wandb:my-team/ppo-craftax/ppo-rnn-policy:best
213
+ ```
214
+
215
+ Control the download location with `--wandb_download_dir` (defaults to `./artifacts/`).
216
+
217
+ ### Resuming a Training Run
218
+
219
+ A completed training checkpoint can be used as the starting point for a new run that continues where the previous one left off. This is useful when extending the training budget or when a preempted job needs to be restarted.
220
+
221
+ **Offline resume:**
222
+
223
+ ```bash
224
+ # Auto-detect step and wandb run ID from checkpoint metadata
225
+ python main.py --mode offline \
226
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
227
+ --resume_checkpoint_path /path/to/completed_offline_checkpoint \
228
+ --offline_total_timesteps 200000000 \
229
+ --save_policy
230
+
231
+ # Explicit step and wandb run ID override
232
+ python main.py --mode offline \
233
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
234
+ --resume_checkpoint_path /path/to/completed_offline_checkpoint \
235
+ --resume_step 1525 \
236
+ --resume_wandb_run_id abc123xyz \
237
+ --offline_total_timesteps 200000000 \
238
+ --save_policy
239
+
240
+ # Resume from a W&B artifact
241
+ python main.py --mode offline \
242
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
243
+ --resume_checkpoint_path wandb:my-team/remdm-craftax/policy:latest \
244
+ --offline_total_timesteps 200000000 \
245
+ --save_policy
246
+ ```
247
+
248
+ **Online resume:**
249
+
250
+ ```bash
251
+ python main.py --mode online \
252
+ --ppo_checkpoint_path /path/to/ppo_checkpoint \
253
+ --resume_checkpoint_path /path/to/completed_online_checkpoint \
254
+ --online_num_updates 2000 \
255
+ --save_policy
256
+ ```
257
+
258
+ **Notes:**
259
+ - The DAgger replay buffer is **not** persisted across resumes. It starts empty and refills within the first few iterations.
260
+ - JIT compilation is fully preserved. Resume only affects initialisation outside `jax.jit` (loading checkpoint, setting the optimizer step counter, adjusting scan length).
261
+ - The cosine LR schedule is constructed for the full `num_updates` range. The optimizer step counter is set to the resume offset so the learning rate picks up exactly where the previous run stopped.
262
+ - When `resume_checkpoint_path` points to a checkpoint with a metadata sidecar, `resume_step` and `resume_wandb_run_id` are auto-detected. Explicit CLI flags override the metadata values.
263
+ - Checkpoints without a metadata sidecar (created before this feature) still load; provide `--resume_step` explicitly.
264
+
265
+
266
+ ---
267
+
268
+ ## Configuration
269
+
270
+ All hyperparameters are in `configs/defaults.yaml`. Override any value on the command line:
271
+
272
+ ```bash
273
+ python main.py --mode offline --lr 1e-4 --plan_horizon 64 --num_minibatches 16
274
+ ```
275
+
276
+ Point to a custom config file:
277
+
278
+ ```bash
279
+ python main.py --mode online --config configs/my_experiment.yaml
280
+ ```
281
+
282
+ Preset configs for larger runs are provided in `configs/`:
283
+
284
+ | File | Purpose |
285
+ |------|---------|
286
+ | `configs/defaults.yaml` | Base defaults for all modes |
287
+ | `configs/classic_exp_a_beta_fix.yaml` | Craftax Classic DAgger — beta decay fix only (isolates data quality) |
288
+ | `configs/classic_exp_b_beta_big_model.yaml` | Craftax Classic DAgger — beta fix + 3.5× larger transformer |
289
+ | `configs/classic_exp_c_full_recipe.yaml` | Craftax Classic DAgger — beta + big model + training dynamics |
290
+ | `configs/craftax_exp_a_beta_fix.yaml` | Full Craftax DAgger — beta decay fix only |
291
+ | `configs/craftax_exp_b_beta_big_model.yaml` | Full Craftax DAgger — beta fix + larger transformer |
292
+ | `configs/craftax_exp_c_full_recipe.yaml` | Full Craftax DAgger — full recipe |
293
+ | `configs/final_classic_ucl.yaml` | Final Craftax Classic DAgger — UCL 3090 Ti, seed 42 (produces the Classic checkpoint consumed by the ablation suite) |
294
+ | `configs/final_classic_qmul.yaml` | Env-frame-matched second seed of `final_classic_ucl.yaml` (QMUL H200, seed 43) |
295
+ | `configs/final_craftax_ucl.yaml` | Final Full Craftax DAgger — UCL 4090, seed 42 (produces the Full Craftax checkpoint consumed by the ablation suite) |
296
+ | `configs/final_craftax_qmul.yaml` | Env-frame-matched second seed of `final_craftax_ucl.yaml` (QMUL H200, seed 43) |
297
+
298
+ RL fine-tuning ablation hyperparameters live under `experiments/rl_finetuning/configs/` and are loaded by `run_ablations.py`, not by `main.py`. See `experiments/README.md`.
299
+
300
+ The `final_*_qmul.yaml` presets differ from their UCL counterparts only in `num_envs` (smaller partition) and `seed`. All fairness-critical hyperparameters are denominated in env frames or update cycles and automatically rescaled by `resolve_scaled_hyperparams()` at load time, so no manual derivation is needed when moving between hardware tiers.
301
+
302
+ ### Key hyperparameters
303
+
304
+ **Environment**
305
+
306
+ | Parameter | Default | Description |
307
+ |-----------|---------|-------------|
308
+ | `env_name` | `Craftax-Classic-Symbolic-v1` | Craftax environment ID. Use `Craftax-Symbolic-v1` for Full Craftax. |
309
+ | `use_optimistic_resets` | `false` | Use `OptimisticResetVecEnvWrapper` instead of `AutoResetEnvWrapper` |
310
+ | `optimistic_reset_ratio` | 16 | Fraction of envs reset per step when optimistic resets are enabled |
311
+
312
+ **Diffusion model**
313
+
314
+ | Parameter | Default | Description |
315
+ |-----------|---------|-------------|
316
+ | `plan_horizon` | 32 | Action plan length H |
317
+ | `diffusion_steps` | 15 | Denoising steps T at inference |
318
+ | `diffusion_schedule` | `cosine` | Noise schedule: `cosine` or `linear` |
319
+ | `remask_strategy` | `rescale` | Remasking strategy: `rescale`, `cap`, or `conf` |
320
+ | `train_sigma` | 0.0 | Per-token remasking correction during training (0 = standard MDLM) |
321
+ | `label_smoothing` | 0.0 | Cross-entropy label smoothing epsilon (0 = exact ELBO) |
322
+ | `eta` | 0.5 | Remasking strength |
323
+ | `use_loop` | `true` | Three-phase loop remasking (Algorithm 3) |
324
+ | `t_on` / `t_off` | 0.7 / 0.3 | Time window boundaries for loop remasking |
325
+ | `temperature` | 0.5 | Softmax temperature for token sampling |
326
+ | `top_p` | 0.95 | Nucleus sampling threshold |
327
+
328
+ **Transformer architecture**
329
+
330
+ | Parameter | Default | Description |
331
+ |-----------|---------|-------------|
332
+ | `d_model` | 256 | Hidden dimension |
333
+ | `n_heads` | 4 | Attention heads |
334
+ | `n_layers` | 4 | Transformer blocks |
335
+ | `d_ff` | 512 | FFN inner dimension |
336
+ | `obs_encoder_layers` | 2 | MLP layers in the observation encoder |
337
+ | `obs_encoder_width` | 512 | Observation encoder hidden width |
338
+ | `dropout_rate` | 0.1 | Dropout rate (disabled at inference) |
339
+
340
+ **Offline training**
341
+
342
+ | Parameter | Default | Description |
343
+ |-----------|---------|-------------|
344
+ | `offline_total_timesteps` | 1e8 | **PRIMARY** env-frame budget for live-PPO data collection. Derives `num_updates` as `offline_total_timesteps // (num_envs * num_steps)`, making the run hardware-portable across `num_envs` changes. |
345
+ | `offline_num_updates` | `null` | **LEGACY** outer update count; used only when `offline_total_timesteps` is unset. |
346
+ | `num_envs` | 1024 | Parallel environments |
347
+ | `num_steps` | 64 | Environment steps collected per update |
348
+ | `num_minibatches` | 8 | Gradient minibatches per epoch |
349
+ | `update_epochs` | 4 | SGD epochs per update step |
350
+ | `num_repeats` | 1 | Independent training seeds (vmapped) |
351
+ | `lr` | 3e-4 | Adam learning rate (cosine-decayed to 10% over all gradient steps) |
352
+ | `lr_warmup_frames` | `null` | **PRIMARY** env-frame warm-up budget. Derives `lr_warmup_steps` as `lr_warmup_frames // (num_envs * num_steps)`. |
353
+ | `lr_warmup_steps` | 0 | **LEGACY** linear warm-up steps before cosine decay (used when `lr_warmup_frames` is unset; 0 = disabled). |
354
+ | `max_grad_norm` | 1.0 | Global gradient clipping norm |
355
+ | `return_weight_cap` | 5.0 | Clip ceiling for per-window return weights (lower clip is fixed at 0.1) |
356
+ | `collect_temperature` | 1.0 | Softmax temperature on PPO logits during live data collection |
357
+ | `val_interval_frames` | `null` | **PRIMARY** env-frames between validation rollouts. Overrides `val_interval` via `val_interval = val_interval_frames // (num_envs * num_steps)`. |
358
+ | `val_interval` | 50 | **LEGACY** validation frequency in update steps (used when `val_interval_frames` is unset). |
359
+ | `val_diffusion_steps` | 50 | Denoising steps used during validation rollouts |
360
+ | `val_replan_every` | 4 | Environment steps executed per diffusion plan during validation |
361
+ | `val_steps` | 128 | Total environment steps per validation rollout |
362
+
363
+ **Online DAgger training**
364
+
365
+ | Parameter | Default | Description |
366
+ |-----------|---------|-------------|
367
+ | `online_total_timesteps` | `null` | **PRIMARY** env-frame budget for online DAgger (hardware-portable). Derives `num_updates` as `online_total_timesteps // (num_envs * num_steps)`. |
368
+ | `online_num_updates` | 1000 | **LEGACY** outer DAgger iterations (used when `online_total_timesteps` is unset). |
369
+ | `dagger_beta_init` | 1.0 | Initial expert mixing probability `beta_1` (1.0 = pure expert on the first iteration). |
370
+ | `dagger_beta_final` | `null` | **PRIMARY** target mixing ratio at the end of training. Overrides `dagger_beta_decay` via `decay = (beta_final / beta_init) ** (1 / num_updates)`. |
371
+ | `dagger_beta_decay` | 0.95 | **LEGACY** per-update decay: `beta_i = beta_init * decay^i` (used when `dagger_beta_final` is unset). |
372
+ | `dagger_buffer_cycles` | `null` | **PRIMARY** buffer capacity denominated in update cycles of history (1 cycle = `num_envs * num_steps` frames). Overrides `dagger_buffer_max` via `buffer_max = cycles * (num_envs * num_steps)`. |
373
+ | `dagger_buffer_max` | 100000 | **LEGACY** max samples in the DAgger replay buffer (circular eviction when full). |
374
+ | `dagger_train_passes` | `null` | Passes per update over the aggregated buffer. `null` = 1 pass (matches offline BC per-update gradient work exactly for fair compute comparison). Raise to >1 to trade BC fairness for wider per-update buffer coverage. |
375
+ | `dagger_expert_deterministic` | `true` | If `true`, the PPO expert takes the argmax action (fixed `s → a*` map); if `false`, it samples categorically. Deterministic removes label noise from the aggregated dataset. |
376
+
377
+ **Data collection**
378
+
379
+ | Parameter | Default | Description |
380
+ |-----------|---------|-------------|
381
+ | `collect_num_steps` | 10000000 | Total environment steps to collect |
382
+ | `collect_num_envs` | 128 | Parallel environments during collection |
383
+ | `ppo_model_type` | `ppo_rnn` | PPO architecture: `ppo`, `ppo_rnn`, or `ppo_rnd` |
384
+ | `layer_size` | 512 | PPO actor-critic hidden layer width |
385
+
386
+ **Inference**
387
+
388
+ | Parameter | Default | Description |
389
+ |-----------|---------|-------------|
390
+ | `eval_steps` | 10000 | Environment steps for evaluation |
391
+ | `eval_num_envs` | 32 | Parallel agents during evaluation (independent of `num_envs`) |
392
+ | `diffusion_steps_eval` | 10 | Denoising steps T used at evaluation time |
393
+
394
+ **Checkpointing**
395
+
396
+ | Parameter | Default | Description |
397
+ |-----------|---------|-------------|
398
+ | `save_policy` | `true` | Save final checkpoint at end of training and upload it as a W&B artifact |
399
+
400
+ **Resume**
401
+
402
+ | Parameter | Default | Description |
403
+ |-----------|---------|-------------|
404
+ | `resume_checkpoint_path` | `null` | Path to a completed checkpoint to resume from (accepts `wandb:` refs) |
405
+ | `resume_wandb_run_id` | `null` | W&B run ID to resume logging into (auto-read from checkpoint metadata) |
406
+ | `resume_step` | `null` | Update step the checkpoint was saved at (auto-read from checkpoint metadata) |
407
+
408
+ **Logging**
409
+
410
+ | Parameter | Default | Description |
411
+ |-----------|---------|-------------|
412
+ | `use_wandb` | `true` | Enable Weights & Biases logging |
413
+ | `wandb_project` | `remdm-craftax` | W&B project name |
414
+ | `wandb_entity` | `"mathis-weil-university-college-london-ucl-"` | W&B entity (team or username) |
415
+ | `wandb_download_dir` | `null` | Download directory for W&B artifacts; null = `./artifacts/` |
416
+ | `seed` | `null` | RNG seed (random if null) |
417
+
418
+ ---
419
+
420
+ ## Remasking Strategies
421
+
422
+ Controlled by `--remask_strategy`. All strategies operate on top of the three-phase loop controlled by `--use_loop`, `--t_on`, and `--t_off`.
423
+
424
+ | Strategy | Formula | Description |
425
+ |----------|---------|-------------|
426
+ | `rescale` | `sigma = eta * sigma_max` | Scales maximum remasking probability proportionally |
427
+ | `cap` | `sigma = min(eta, sigma_max)` | Caps remasking at a fixed rate |
428
+ | `conf` | `sigma = eta * sigma_max * (1 - confidence)` | High-confidence tokens are remasked less |
429
+
430
+ ---
431
+
432
+ ## Environment Wrappers
433
+
434
+ **From `Craftax_Baselines/wrappers.py`** (submodule):
435
+
436
+ | Wrapper | Purpose |
437
+ |---------|---------|
438
+ | `LogWrapper` | Tracks episode returns and lengths; adds stats to the info dict |
439
+ | `AutoResetEnvWrapper` | Automatically resets episodes on `done` |
440
+ | `BatchEnvWrapper` | Vmaps `reset` and `step` over `num_envs` environments |
441
+ | `OptimisticResetVecEnvWrapper` | Batched resets with reduced overhead; enable via `--use_optimistic_resets` |
442
+
443
+ **From `src/envs/wrappers.py`**:
444
+
445
+ | Wrapper | Purpose |
446
+ |---------|---------|
447
+ | `SequenceHistoryWrapper` | Maintains a sliding window of past observations and actions in the env state |
448
+ | `DiscreteTokenizationWrapper` | Quantizes continuous observations into discrete token indices |
449
+ | `PlannerWrapper` | Manages the plan/replan cycle for the diffusion planner |
450
+ | `OfflineTrajectoryWrapper` | Accumulates transitions into a fixed-size circular replay buffer |
451
+
452
+ **Wrapper stacks:**
453
+
454
+ ```
455
+ Training: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
456
+ Inference: env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
457
+ ```
458
+
459
+ ---
460
+
461
+ ## Project Structure
462
+
463
+ ```
464
+ craftax-ReMDM-planner/
465
+ ├── Craftax_Baselines/ # Git submodule — PPO agents and standard wrappers
466
+ │ ├── wrappers.py # LogWrapper, BatchEnvWrapper, AutoResetEnvWrapper, etc.
467
+ │ ├── ppo_rnn.py # PPO-RNN training script
468
+ │ ├── ppo_rnd.py # PPO-RND training script
469
+ │ ├── ppo.py # PPO model definitions
470
+ │ └── models/
471
+ │ ├── actor_critic.py # ActorCritic variants
472
+ │ ├── rnd.py # RND network
473
+ │ └── icm.py # ICM encoder, forward, and inverse networks
474
+ ├── configs/
475
+ │ ├── defaults.yaml # Base hyperparameters (CLI-overridable)
476
+ │ ├── classic_exp_a_beta_fix.yaml # Classic DAgger — beta decay fix only
477
+ │ ├── classic_exp_b_beta_big_model.yaml # Classic DAgger — beta fix + big model
478
+ │ ├── classic_exp_c_full_recipe.yaml # Classic DAgger — full recipe
479
+ │ ├── craftax_exp_a_beta_fix.yaml # Full Craftax DAgger — beta fix
480
+ │ ├── craftax_exp_b_beta_big_model.yaml # Full Craftax DAgger — beta + big model
481
+ │ ├── craftax_exp_c_full_recipe.yaml # Full Craftax DAgger — full recipe
482
+ │ ├── final_classic_ucl.yaml # Classic DAgger — UCL 3090 Ti, seed 42
483
+ │ ├── final_classic_qmul.yaml # Classic DAgger — QMUL H200, seed 43
484
+ │ ├── final_craftax_ucl.yaml # Full Craftax DAgger — UCL 4090, seed 42
485
+ │ └── final_craftax_qmul.yaml # Full Craftax DAgger — QMUL H200, seed 43
486
+ ├── src/
487
+ │ ├── diffusion/
488
+ │ │ ├── forward.py # Forward masking process q(z_t | x_0)
489
+ │ │ ├── loss.py # Continuous-time MDLM ELBO loss
490
+ │ │ ├── sampling.py # Reverse diffusion with ReMDM remasking
491
+ │ │ └── schedules.py # Linear and cosine noise schedules
492
+ │ ├── models/
493
+ │ │ └── denoiser.py # DenoisingTransformer (obs encoder + transformer)
494
+ │ ├── envs/
495
+ │ │ └── wrappers.py # Sequence, tokenization, planner, and trajectory wrappers
496
+ │ └── planners/
497
+ │ ├── collect.py # --mode collect: PPO rollouts -> .npz
498
+ │ ├── common.py # Shared utilities
499
+ │ ├── env.py # Environment construction
500
+ │ ├── inference.py # --mode inference: MPC evaluation with inpainting
501
+ │ ├── logging.py # Centralised W&B logging utilities
502
+ │ ├── model.py # Diffusion model lifecycle
503
+ │ ├── offline.py # --mode offline: make_train (live PPO rollouts)
504
+ │ ├── online.py # --mode online: DAgger fine-tuning
505
+ │ └── ppo.py # PPO agent adapter and checkpoint loading utilities
506
+ ├── experiments/
507
+ │ └── rl_finetuning/ # RL fine-tuning ablation suite (see experiments/README.md)
508
+ │ ├── run_ablations.py # CLI entry point
509
+ │ ├── ablations/ # Loss, optimizer, registry, and training modules
510
+ │ ├── diagnostics/ # Gradient, representation, and timestep diagnostics
511
+ │ ├── analysis/ # Plots, tables, and report generation
512
+ │ └── configs/ # ablations_default.yaml, ablations_fast.yaml,
513
+ │ # ablations_final_{classic,craftax}_{ucl,qmul}.yaml
514
+ ├── main.py # CLI entry point
515
+ ├── pyproject.toml # uv project — direct deps + tool config
516
+ └── uv.lock # Reproducible lockfile (commit this)
517
+ ```
518
+
519
+ ---
520
+
521
+ ## Implementation Notes
522
+
523
+ **JAX functional purity**: training closures (`make_train`, `make_train_dagger`) are fully JIT-compatible. Environment construction and checkpoint I/O happen outside `jax.jit`.
524
+
525
+ **Offline training**: `--mode offline` rolls out the PPO agent live at each update step via `make_train`. Use `--mode collect` to save a trajectory `.npz` for inspection or analysis; re-feeding it to `--mode offline` is not supported — pass `--ppo_checkpoint_path` instead.
526
+
527
+ **Episode-boundary masking**: the offline sampler pre-computes a validity mask over all `(env, time)` positions. A window at `(e, t)` is valid only if `dones[e, t+1:t+H-1]` are all `False`.
528
+
529
+ **Return weighting**: valid windows are weighted by their cumulative reward, normalised by the batch mean and clipped to `[0.1, RETURN_WEIGHT_CAP]`. Weights are passed as per-sample multipliers into the MDLM loss before reduction, so they correctly scale each sample's gradient contribution.
530
+
531
+ **LR schedule**: cosine decay from `lr` to `lr * 0.1` over all gradient steps. Set `lr_warmup_frames > 0` (env-frame-invariant, PRIMARY) or `lr_warmup_steps > 0` (LEGACY) to prepend a linear warm-up phase.
532
+
533
+ **Env-frame-invariant hyperparameters**: the PRIMARY keys `offline_total_timesteps`, `online_total_timesteps`, `lr_warmup_frames`, `val_interval_frames`, `dagger_beta_final`, and `dagger_buffer_cycles` are denominated in env frames (or update cycles). At config load time, `resolve_scaled_hyperparams()` in `src/planners/common.py` converts them to the equivalent update-step-denominated quantities (`num_updates`, `lr_warmup_steps`, `val_interval`, `dagger_beta_decay`, `dagger_buffer_max`) using the current `num_envs * num_steps` frames-per-update. This lets the same config run on different hardware tiers without re-tuning.
534
+
535
+ **Loss weight clipping**: the MDLM SUBS weight `-alpha'(t) / (1 - alpha_t)` is clipped to 1000 to prevent numerical instability when `alpha_t ≈ 1`.
536
+
537
+ **Validation rollouts**: during offline training, a held-out rollout runs every `val_interval` steps. It uses the same sampling parameters as inference (`remask_strategy`, `eta`, `use_loop`, `t_on`, `t_off`, `temperature`, `top_p`) with `val_diffusion_steps` denoising steps and `val_replan_every` env steps per plan, for a total of `val_steps` environment steps.
538
+
539
+ **W&B logging**: all metric aggregation is centralised in `src/planners/logging.py`. Metric namespaces: `diffusion/` (loss, accuracy), `train/` (data quality, throughput), `env/` (episode returns, achievements), `val/` (validation rollouts, emitted every `val_interval` steps), `dagger/` (online DAgger training: beta, buffer fill, reward mean, valid fraction). `train/sps` (environment frames/sec) is only logged in modes that perform live environment interaction.
540
+
541
+ **DAgger dataset aggregation**: online training (`--mode online`) implements DAgger (Ross et al., 2011). A circular replay buffer accumulates `(obs, expert_plan)` pairs across all iterations. Each update samples uniformly from the full buffer, not just the latest batch. Training samples that cross episode boundaries (any `done` within the plan-horizon window) are marked invalid. The expert (PPO agent) receives correct `done` flags so its RNN hidden state resets on episode boundaries. Windows are extracted with a sliding stride (one per env-time position) rather than stepping the buffer in plan-horizon chunks, so every visited state contributes a label.
542
+
543
+ **Best-checkpoint tracking**: during online training, the parameters from the validation iteration with the highest validation return are preserved alongside the current live parameters. The final checkpoint and the best-validation checkpoint are both uploaded as separate W&B artifacts (`{env_name}-policy` and `{env_name}-policy-best`).
544
+
545
+ **Denoising step indexing**: the reverse scan runs from `step_idx = 0` to `T-1`, mapping to diffusion time `t = (T - step_idx) / T` (high noise to low noise).
546
+
547
+ **Submodule PPO agents**: PPO training lives entirely in `Craftax_Baselines/`. Planner scripts only consume pre-trained checkpoints via `--ppo_checkpoint_path`.
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/_CHECKPOINT_METADATA ADDED
@@ -0,0 +1 @@
 
 
1
+ {"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775663434533263974, "commit_timestamp_nsecs": 1775663435644779625, "custom_metadata": {}}
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_METADATA ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/_sharding ADDED
@@ -0,0 +1 @@
 
 
1
+ {"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/array_metadatas/process_0 ADDED
@@ -0,0 +1 @@
 
 
1
+ {"array_metadatas": [{"array_metadata": {"param_name": "step", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.1.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}]}
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/d/021af9ba431a3072f4819480f91b83af ADDED
Binary file (3.72 kB). View file
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/manifest.ocdbt ADDED
Binary file (117 Bytes). View file
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/0cf9a08a9722f9b8a0b7f007da7c1e92 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbff61a18e9475d72fae302d4748615daf5fc6b87cc0e0a338c96b8a781d6c0f
3
+ size 101199872
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/1968eb861d84503c0e805cffdd77528a ADDED
Binary file (832 Bytes). View file
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/aec20934c03229d1bd9651c955e59d84 ADDED
Binary file (171 Bytes). View file
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/d/d32b0ea672fe7a9b86b8e62e7c20dbaf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66e7df58a5ad39030e5631943ffa5d45164b91f283a2b7b34d4265c6bbf08be4
3
+ size 448037
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/100000000/default/ocdbt.process_0/manifest.ocdbt ADDED
Binary file (259 Bytes). View file
 
checkpoints/offline/Craftax-Classic-Symbolic-v1-OfflineDiffusion-BC-100M/resume_metadata.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "offline",
3
+ "update_step": 1525,
4
+ "total_gradient_steps_completed": 97600,
5
+ "wandb_run_id": "6opvce2t",
6
+ "config_snapshot": {
7
+ "ENV_NAME": "Craftax-Classic-Symbolic-v1",
8
+ "USE_OPTIMISTIC_RESETS": false,
9
+ "OPTIMISTIC_RESET_RATIO": 16,
10
+ "D_MODEL": 384,
11
+ "N_HEADS": 8,
12
+ "N_LAYERS": 6,
13
+ "D_FF": 768,
14
+ "OBS_ENCODER_LAYERS": 2,
15
+ "OBS_ENCODER_WIDTH": 768,
16
+ "DROPOUT_RATE": 0.1,
17
+ "PLAN_HORIZON": 32,
18
+ "DIFFUSION_SCHEDULE": "cosine",
19
+ "TRAIN_SIGMA": 0.0,
20
+ "LABEL_SMOOTHING": 0.0,
21
+ "DIFFUSION_STEPS": 15,
22
+ "DIFFUSION_STEPS_EVAL": 10,
23
+ "REMASK_STRATEGY": "rescale",
24
+ "ETA": 0.5,
25
+ "USE_LOOP": true,
26
+ "T_ON": 0.7,
27
+ "T_OFF": 0.3,
28
+ "TEMPERATURE": 0.5,
29
+ "TOP_P": 0.95,
30
+ "LR": 0.0003,
31
+ "MAX_GRAD_NORM": 1.0,
32
+ "LR_WARMUP_FRAMES": "1.048576e8",
33
+ "NUM_ENVS": 512,
34
+ "NUM_STEPS": 128,
35
+ "NUM_MINIBATCHES": 8,
36
+ "UPDATE_EPOCHS": 8,
37
+ "NUM_REPEATS": 1,
38
+ "OFFLINE_TOTAL_TIMESTEPS": 99942400,
39
+ "COLLECT_TEMPERATURE": 1.0,
40
+ "RETURN_WEIGHT_CAP": 5.0,
41
+ "ONLINE_TOTAL_TIMESTEPS": 100000000.0,
42
+ "DAGGER_BETA_INIT": 1.0,
43
+ "DAGGER_BETA_FINAL": 0.344,
44
+ "DAGGER_BUFFER_CYCLES": 1.90735,
45
+ "VAL_INTERVAL_FRAMES": 1000000.0,
46
+ "VAL_DIFFUSION_STEPS": 50,
47
+ "VAL_REPLAN_EVERY": 4,
48
+ "VAL_STEPS": 256,
49
+ "COLLECT_NUM_STEPS": 10000000,
50
+ "COLLECT_NUM_ENVS": 128,
51
+ "PPO_MODEL_TYPE": "ppo_rnn",
52
+ "LAYER_SIZE": 512,
53
+ "EVAL_STEPS": 10000,
54
+ "EVAL_NUM_ENVS": 32,
55
+ "SAVE_POLICY": true,
56
+ "SEED": 42,
57
+ "USE_WANDB": true,
58
+ "WANDB_PROJECT": "remdm-craftax",
59
+ "WANDB_ENTITY": "mathis-weil-university-college-london-ucl-",
60
+ "MODE": "offline",
61
+ "JIT": true,
62
+ "PPO_CHECKPOINT_PATH": "checkpoints/ppo_agents/policies/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M",
63
+ "NUM_UPDATES": 1525,
64
+ "LR_WARMUP_STEPS": 1600,
65
+ "VAL_INTERVAL": 15,
66
+ "MINIBATCH_SIZE": 6208
67
+ }
68
+ }
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/_CHECKPOINT_METADATA ADDED
@@ -0,0 +1 @@
 
 
1
+ {"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1775623858059636986, "commit_timestamp_nsecs": 1775623858516125466, "custom_metadata": {}}
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_METADATA ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/_sharding ADDED
@@ -0,0 +1 @@
 
 
1
+ {"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRW1iZWRfMC5lbWJlZGRpbmc=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja180LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181Lk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja181LkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18wLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18xLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18yLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC52YWx1ZS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5rZXkua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5vdXQua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLk11bHRpSGVhZERvdFByb2R1Y3RBdHRlbnRpb25fMC5xdWVyeS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzAua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkRlbnNlXzEua2VybmVs":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8wLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuVHJhbnNmb3JtZXJCbG9ja18zLkxheWVyTm9ybV8xLnNjYWxl":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5FbWJlZF8wLmVtYmVkZGluZw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMC5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5MYXllck5vcm1fMS5zY2FsZQ==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzAuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzEuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzIuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzMuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzQuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzAuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuYmlhcw==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTGF5ZXJOb3JtXzEuc2NhbGU=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLm91dC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLmtleS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnF1ZXJ5Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5UcmFuc2Zvcm1lckJsb2NrXzUuTXVsdGlIZWFkRG90UHJvZHVjdEF0dGVudGlvbl8wLnZhbHVlLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/array_metadatas/process_0 ADDED
@@ -0,0 +1 @@
 
 
1
+ {"array_metadatas": [{"array_metadata": {"param_name": "params.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "params.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.count", "write_shape": [], "chunk_shape": [], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.mu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_0.kernel", "write_shape": [1345, 768], "chunk_shape": [1345, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_1.kernel", "write_shape": [768, 768], "chunk_shape": [768, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_2.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_3.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_4.kernel", "write_shape": [384, 384], "chunk_shape": [384, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.bias", "write_shape": [17], "chunk_shape": [17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Dense_5.kernel", "write_shape": [384, 17], "chunk_shape": [384, 17], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.Embed_0.embedding", "write_shape": [18, 384], "chunk_shape": [18, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_0.scale", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_0.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_1.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_2.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_3.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_4.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.bias", "write_shape": [768], "chunk_shape": [768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_0.kernel", "write_shape": [384, 768], "chunk_shape": [384, 768], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.Dense_1.kernel", "write_shape": [768, 384], "chunk_shape": [768, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_0.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.LayerNorm_1.scale", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.key.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.bias", "write_shape": [384], "chunk_shape": [384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.out.kernel", "write_shape": [8, 48, 384], "chunk_shape": [8, 48, 384], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.query.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.bias", "write_shape": [8, 48], "chunk_shape": [8, 48], "ext_metadata": null}}, {"array_metadata": {"param_name": "opt_state.1.0.nu.params.TransformerBlock_5.MultiHeadDotProductAttention_0.value.kernel", "write_shape": [384, 8, 48], "chunk_shape": [384, 8, 48], "ext_metadata": null}}]}
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/d/63ff4b6b75238977cfc360704c224d86 ADDED
Binary file (2.22 kB). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/manifest.ocdbt ADDED
Binary file (117 Bytes). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/63a3ac9c870d5f7eb2b74967758ee043 ADDED
Binary file (171 Bytes). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/c3b086748e0ea04233c3638a3994fa30 ADDED
Binary file (3.77 kB). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/db1925df3ee2d3c92bea0a9878efa5fc ADDED
Binary file (832 Bytes). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/e25dea5d414404e637e55db20175c620 ADDED
Binary file (214 Bytes). View file
 
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/d/fd4cdc0c7be44d4f518c6ccca7ad654a ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c27dc63cbdd625c2b62fac311fd37e14406b411ac848847ca4bd4e99f333419
3
+ size 34631680
checkpoints/online/Craftax-Classic-Symbolic-v1-OnlineDiffusion-DAgger-50M/50000000/default/ocdbt.process_0/manifest.ocdbt ADDED
Binary file (302 Bytes). View file
 
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/_CHECKPOINT_METADATA ADDED
@@ -0,0 +1 @@
 
 
1
+ {"item_handlers": {"default": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler"}, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1773173340517772966, "commit_timestamp_nsecs": 1773173340998852009, "custom_metadata": {}}
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_METADATA ADDED
@@ -0,0 +1 @@
 
 
1
+ {"tree_metadata": {"('step',)": {"key_metadata": [{"key": "step", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('params', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('params', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('params', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('params', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('params', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('params', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "params", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '0')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "0", "key_type": 1}], "value_metadata": {"value_type": "None", "skip_deserialize": true}}, "('opt_state', '1', '0', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'mu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'mu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "mu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_0', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_0", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1345, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_1', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_1", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_2', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_2", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_3', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_3", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 17]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_4', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_4", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_5', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_5", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [1]}}, "('opt_state', '1', '0', 'nu', 'params', 'Dense_6', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "Dense_6", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 1]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hn', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hn", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hr', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hr", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'hz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "hz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'in', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "in", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'ir', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "ir", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'bias')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "bias", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512]}}, "('opt_state', '1', '0', 'nu', 'params', 'ScannedRNN_0', 'GRUCell_1', 'iz', 'kernel')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "0", "key_type": 1}, {"key": "nu", "key_type": 2}, {"key": "params", "key_type": 2}, {"key": "ScannedRNN_0", "key_type": 2}, {"key": "GRUCell_1", "key_type": 2}, {"key": "iz", "key_type": 2}, {"key": "kernel", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": [512, 512]}}, "('opt_state', '1', '1', 'count')": {"key_metadata": [{"key": "opt_state", "key_type": 2}, {"key": "1", "key_type": 1}, {"key": "1", "key_type": 1}, {"key": "count", "key_type": 2}], "value_metadata": {"value_type": "jax.Array", "skip_deserialize": false, "write_shape": []}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
checkpoints/ppo_agents/Craftax-Classic-Symbolic-v1-PPO_RNN-1000M/1000000000/default/_sharding ADDED
@@ -0,0 +1 @@
 
 
1
+ {"b3B0X3N0YXRlLjEuMC5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5tdS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfMy5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNC5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNS5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuRGVuc2VfNi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5obi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5oei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pbi5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pci5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5iaWFz":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMC5udS5wYXJhbXMuU2Nhbm5lZFJOTl8wLkdSVUNlbGxfMS5pei5rZXJuZWw=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","b3B0X3N0YXRlLjEuMS5jb3VudA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","c3RlcA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV80Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV81Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV82Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8wLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8xLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8yLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5EZW5zZV8zLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmh6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhuLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmhyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6LmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLml6Lmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmluLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmJpYXM=":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}","cGFyYW1zLnBhcmFtcy5TY2FubmVkUk5OXzAuR1JVQ2VsbF8xLmlyLmtlcm5lbA==":"{\"sharding_type\": \"SingleDeviceSharding\", \"device_str\": \"cuda:0\"}"}