"""Generate reference sample rollouts for a fixed set of apebench scenarios.

For every scenario key in ``SCENARIOS`` the scenario is instantiated (3D,
32 points, 10 test samples by default), its reference sample trajectory is
obtained via ``get_ref_sample_data()`` and stored as
``<output_dir>/<scenario name>.npy``.
"""

from pathlib import Path

import jax.numpy as jnp

# Scenario keys looked up in ``apebench.scenarios.scenario_dict``.
SCENARIOS = (
    "phy_unbal_adv",
    "diff_burgers",  # Uses the three-channel version in 3D
    "diff_ks",
    "phy_gs",
    "phy_sh",
)

DEFAULT_OUTPUT_DIR = "ref_sample_rollouts"


def save_reference_rollouts(
    scenarios=SCENARIOS,
    output_dir=DEFAULT_OUTPUT_DIR,
    *,
    num_spatial_dims=3,
    num_points=32,
    num_test_samples=10,
    scenario_dict=None,
):
    """Compute and save the reference sample rollout of each scenario.

    Args:
        scenarios: Iterable of scenario keys to look up in ``scenario_dict``.
        output_dir: Directory the ``.npy`` files are written to. It is
            created (including missing parents) if it does not exist.
        num_spatial_dims: Passed to each scenario constructor.
        num_points: Passed to each scenario constructor.
        num_test_samples: Passed to each scenario constructor.
        scenario_dict: Mapping from scenario key to scenario constructor.
            Defaults to ``apebench.scenarios.scenario_dict``.

    Returns:
        List of ``pathlib.Path`` objects that were written, in scenario order.

    Raises:
        KeyError: If a scenario key is missing from a dict-like registry.
    """
    if scenario_dict is None:
        # Imported lazily so this module can be imported (and this function
        # exercised with an injected registry) without apebench installed.
        import apebench

        scenario_dict = apebench.scenarios.scenario_dict

    out_dir = Path(output_dir)
    # jnp.save (numpy.save) does not create parent directories; without this
    # the first save raises FileNotFoundError on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)

    written = []
    for scenario in scenarios:
        scene = scenario_dict[scenario](
            num_spatial_dims=num_spatial_dims,
            num_points=num_points,
            num_test_samples=num_test_samples,
        )
        ref_trj = scene.get_ref_sample_data()
        path = out_dir / f"{scene.get_scenario_name()}.npy"
        jnp.save(path, ref_trj)
        written.append(path)
    return written


if __name__ == "__main__":
    save_reference_rollouts()