File size: 433 Bytes
b2cc771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from pathlib import Path

import apebench
import jax.numpy as jnp

# Export one reference sample trajectory per scenario, each generated in 3D
# with 32 points per spatial dimension.

# Scenario keys looked up in ``apebench.scenarios.scenario_dict``.
SCENARIOS = [
    "phy_unbal_adv",
    "diff_burgers",  # Uses the three-channel version in 3D
    "diff_ks",
    "phy_gs",
    "phy_sh",
]

# Destination directory (relative to the current working directory); one
# ``<scenario name>.npy`` file is written per scenario.
OUTPUT_DIR = Path("ref_sample_rollouts")

# ``jnp.save`` (NumPy's ``save``) does not create missing parent directories,
# so create the output directory up front instead of failing with
# FileNotFoundError on the first write in a fresh checkout.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for scenario in SCENARIOS:
    scene = apebench.scenarios.scenario_dict[scenario](
        num_spatial_dims=3, num_points=32, num_test_samples=10
    )

    ref_trj = scene.get_ref_sample_data()

    # The file name comes from the scenario object's own name, not from the
    # dictionary key used to construct it.
    jnp.save(OUTPUT_DIR / f"{scene.get_scenario_name()}.npy", ref_trj)