# orbits/predict.py
# Inference script: roll out a trained OrbitMLP and compare it with an RK4 reference.
# Source: Hugging Face upload by asgeirr89 ("Upload folder using huggingface_hub"),
# commit 6cd7e16 (4.06 kB).
import glob
import os
import pickle
import jax
import jax.numpy as jnp
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from flax import serialization
from model import OrbitMLP
from physics_engine import energy as compute_energy, rk4_step
from train import angular_momentum, make_predict_trajectory
def load_latest_model(models_dir: str = "models"):
    """Load the most recently modified ``*.flax`` checkpoint in *models_dir*.

    The file is decoded with Flax's msgpack serializer and returned as the raw
    restored state tree (presumably the nested ``params`` dict written by the
    training script -- confirm against ``train.py``).

    Args:
        models_dir: Directory searched (non-recursively) for ``*.flax`` files.

    Returns:
        The deserialized parameter tree.

    Raises:
        FileNotFoundError: If *models_dir* contains no regular ``*.flax`` file.
    """
    # glob also matches directories whose names end in ".flax"; only real
    # files can be opened and deserialized.
    flax_files = [
        path
        for path in glob.glob(os.path.join(models_dir, "*.flax"))
        if os.path.isfile(path)
    ]
    if not flax_files:
        raise FileNotFoundError(f"No .flax files found in {models_dir}/")
    # "Latest" means most recently modified, not lexicographically last.
    latest = max(flax_files, key=os.path.getmtime)
    print(f"Loading model: {latest}")
    with open(latest, "rb") as f:
        encoded = f.read()
    # No template pytree is available here, so decode the raw msgpack state
    # directly. This matches ``from_bytes`` with an unregistered placeholder
    # target, but without allocating a throwaway PRNG key as a fake target.
    return serialization.msgpack_restore(encoded)
def generate_random_initial_state(rng, gm: float = 1.0):
    """Sample a random initial state ``[x0, y0, vx0, vy0]`` around the origin.

    Position: polar angle uniform in [0, 2*pi) and radius uniform in [1, 2),
    drawn with independent keys. Velocity: magnitude uniform in [0.5, 1.1),
    directed perpendicular to the position vector (rotated +90 degrees,
    i.e. counter-clockwise).

    Args:
        rng: JAX PRNG key.
        gm: Gravitational parameter. Not used by this function's body; kept
            so the signature stays compatible with existing callers.

    Returns:
        float32 array of shape (4,): ``[x0, y0, vx0, vy0]``.
    """
    rng_pos, rng_vel = jax.random.split(rng)
    # Every draw needs its own key: reusing one key for angle and radius gives
    # both the same underlying uniform sample, which ties radius to angle.
    rng_angle, rng_radius = jax.random.split(rng_pos)
    angle = jax.random.uniform(rng_angle, (), minval=0.0, maxval=2.0 * jnp.pi)
    radius = jax.random.uniform(rng_radius, (), minval=1.0, maxval=2.0)
    x0 = radius * jnp.cos(angle)
    y0 = radius * jnp.sin(angle)
    v_mag = jax.random.uniform(rng_vel, (), minval=0.5, maxval=1.1)
    # Velocity points along the tangent direction (position angle + 90 deg).
    v_angle = angle + jnp.pi / 2.0
    vx0 = v_mag * jnp.cos(v_angle)
    vy0 = v_mag * jnp.sin(v_angle)
    return jnp.array([x0, y0, vx0, vy0], dtype=jnp.float32)
def _rk4_rollout(init_state, dt: float, gm: float, num_steps: int) -> np.ndarray:
    """Integrate *init_state* with ``rk4_step``; return a (num_steps + 1, 4) float32 array."""
    traj = np.zeros((num_steps + 1, 4), dtype=np.float32)
    traj[0] = np.asarray(init_state)
    state = init_state
    for i in range(1, num_steps + 1):
        # rk4_step returns a pair; only the new state (first element) is used.
        state, _ = rk4_step(state, dt, gm)
        traj[i] = np.asarray(state)
    return traj


def _plot_comparison(rk4_traj, nn_traj, rk4_energy, nn_energy, rk4_L, nn_L, out_path):
    """Save a 3-panel figure (x-y orbits, total energy, angular momentum) to *out_path*."""
    num_points = len(rk4_traj)
    steps = np.arange(num_points)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    ax = axes[0]
    ax.plot(rk4_traj[:, 0], rk4_traj[:, 1], "b-", label="RK4 (truth)", alpha=0.7)
    ax.plot(nn_traj[:, 0], nn_traj[:, 1], "r--", label="OrbitMLP", alpha=0.7)
    ax.scatter(0, 0, c="k", marker="o", s=80, label="Central mass")
    ax.scatter(rk4_traj[0, 0], rk4_traj[0, 1], c="g", marker="x", s=100, label="Start")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    # Title follows the actual rollout length instead of a hard-coded count.
    ax.set_title(f"Orbit Comparison ({num_points - 1} steps)")
    ax.legend()
    ax.set_aspect("equal")

    ax = axes[1]
    ax.plot(steps, rk4_energy, "b-", label="RK4 energy", alpha=0.7)
    ax.plot(steps, nn_energy, "r--", label="NN energy", alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Total Energy")
    ax.set_title("Energy Conservation")
    ax.legend()

    ax = axes[2]
    ax.plot(steps, rk4_L, "b-", label="RK4 angular momentum", alpha=0.7)
    ax.plot(steps, nn_L, "r--", label="NN angular momentum", alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Angular Momentum L")
    ax.set_title("Angular Momentum Conservation")
    ax.legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def main():
    """Compare an OrbitMLP rollout against an RK4 reference from one random start.

    Loads the newest checkpoint from ``models/``, rolls both integrators out
    for ``num_steps`` steps from the same initial state, and saves
    ``results/inference_orbit.png``. It then prints trajectory error and
    energy / angular-momentum conservation metrics.
    """
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)
    gm = 1.0
    # NOTE(review): dt only drives the RK4 reference. predict_trajectory is not
    # given a step size, so this presumably must equal the dt used in training
    # for the comparison to be meaningful -- confirm against train.py.
    dt = 0.05
    num_steps = 500

    params = load_latest_model()
    model = OrbitMLP()
    rng = jax.random.PRNGKey(0)
    init_state = generate_random_initial_state(rng, gm=gm)
    print(f"Initial state: {init_state}")

    predict_trajectory = make_predict_trajectory(model)
    nn_traj = np.array(predict_trajectory(params, init_state, num_steps))
    rk4_traj = _rk4_rollout(init_state, dt, gm, num_steps)
    # Everything below compares the two trajectories row by row.
    if nn_traj.shape != rk4_traj.shape:
        raise ValueError(
            f"NN trajectory shape {nn_traj.shape} does not match "
            f"RK4 trajectory shape {rk4_traj.shape}"
        )

    rk4_energy = np.array([compute_energy(state, gm) for state in rk4_traj])
    nn_energy = np.array([compute_energy(state, gm) for state in nn_traj])
    rk4_L = np.array([angular_momentum(state) for state in rk4_traj])
    nn_L = np.array([angular_momentum(state) for state in nn_traj])

    out_path = os.path.join(results_dir, "inference_orbit.png")
    _plot_comparison(rk4_traj, nn_traj, rk4_energy, nn_energy, rk4_L, nn_L, out_path)
    print(f"Inference plot saved to {out_path}")

    # State columns are [x, y, vx, vy]; position error uses only x and y.
    mse_pos = np.mean((nn_traj[:, :2] - rk4_traj[:, :2]) ** 2)
    mse_state = np.mean((nn_traj - rk4_traj) ** 2)
    print(f"Position MSE vs RK4: {mse_pos:.6e}")
    print(f"Full-state MSE vs RK4 (x, y, vx, vy): {mse_state:.6e}")
    print(f"NN energy drift (final - initial): {nn_energy[-1] - nn_energy[0]:.6e}")
    print(f"RK4 energy drift (final - initial): {rk4_energy[-1] - rk4_energy[0]:.6e}")
    print(f"NN L variance: {np.var(nn_L):.6e}")
    print(f"RK4 L variance: {np.var(rk4_L):.6e}")


if __name__ == "__main__":
    main()