"""Export 3D reference sample rollouts for a selection of APEBench scenarios.

For each scenario name in ``SCENARIOS``, the scenario class is looked up in
``apebench.scenarios.scenario_dict``. It is instantiated in 3D with 32 points
per dimension and 10 test samples. Its reference sample data is saved as
``<scenario name>.npy`` inside the output directory.
"""

from pathlib import Path

import apebench
import jax.numpy as jnp

# Scenario keys into ``apebench.scenarios.scenario_dict`` to export.
SCENARIOS = (
    "phy_unbal_adv",
    "diff_burgers",  # Uses the three-channel version in 3D
    "diff_ks",
    "phy_gs",
    "phy_sh",
)

# Directory (relative to the current working directory) receiving the .npy files.
OUTPUT_DIR = Path("ref_sample_rollouts")


def main(output_dir=OUTPUT_DIR, scenarios=SCENARIOS):
    """Generate and save the reference sample rollout of each scenario.

    Args:
        output_dir: Directory to write ``<scenario name>.npy`` files into;
            created (including parents) if it does not exist.
        scenarios: Iterable of keys into ``apebench.scenarios.scenario_dict``.

    Raises:
        KeyError: If a scenario name is not in ``scenario_dict``.
    """
    output_dir = Path(output_dir)
    # jnp.save (numpy.save) does not create missing parent directories, so
    # make sure the target exists before the first write.
    output_dir.mkdir(parents=True, exist_ok=True)
    for scenario in scenarios:
        scene = apebench.scenarios.scenario_dict[scenario](
            num_spatial_dims=3, num_points=32, num_test_samples=10
        )
        ref_trj = scene.get_ref_sample_data()
        jnp.save(output_dir / f"{scene.get_scenario_name()}.npy", ref_trj)


if __name__ == "__main__":
    main()