"""Generate and save reference sample rollouts for selected APEBench scenarios.

For each scenario key, the script:

1. Instantiates the scenario from ``apebench.scenarios.scenario_dict``.
2. Obtains its reference trajectory via ``get_ref_sample_data()``.
3. Writes the trajectory to ``<output_dir>/<scene.get_scenario_name()>.npy``.
"""

from pathlib import Path

import apebench
import jax.numpy as jnp

# Keys into ``apebench.scenarios.scenario_dict`` whose reference rollouts are saved.
SCENARIOS = (
    "diff_disp",
    "diff_burgers",
    "diff_kdv",
    "diff_ks_cons",
    "diff_ks",
)

# Directory (relative to the current working directory) receiving the .npy files.
OUTPUT_DIR = "ref_sample_rollouts"


def main(scenarios=SCENARIOS, output_dir=OUTPUT_DIR):
    """Save the reference sample trajectory of each scenario as a ``.npy`` file.

    Args:
        scenarios: Iterable of keys into ``apebench.scenarios.scenario_dict``.
        output_dir: Directory to write into; created if it does not exist.

    Raises:
        KeyError: If a scenario key is not present in ``scenario_dict``.
    """
    out_dir = Path(output_dir)
    # jnp.save does not create missing parent directories; without this the
    # first save fails with FileNotFoundError on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)

    for scenario in scenarios:
        scene = apebench.scenarios.scenario_dict[scenario]()
        ref_trj = scene.get_ref_sample_data()
        # The file name comes from the instantiated scene's own name, which is
        # not necessarily identical to the dictionary key used above.
        jnp.save(out_dir / f"{scene.get_scenario_name()}.npy", ref_trj)


if __name__ == "__main__":
    main()