"""Compare a trained OrbitMLP rollout against an RK4 reference integration.

Loads the most recently modified ``*.flax`` checkpoint from ``models/``, rolls
out both the network and an RK4 integrator from one random initial state,
plots orbit / total energy / angular momentum side by side into
``results/inference_orbit.png``, and prints summary error metrics.
"""

import glob
import os
import pickle

import jax
import jax.numpy as jnp
import matplotlib

matplotlib.use("Agg")  # headless backend: the script only writes a PNG
import matplotlib.pyplot as plt
import numpy as np
from flax import serialization

from model import OrbitMLP
from physics_engine import energy as compute_energy, rk4_step
from train import angular_momentum, make_predict_trajectory


def load_latest_model(models_dir: str = "models"):
    """Load parameters from the most recently modified ``*.flax`` file.

    Args:
        models_dir: Directory searched (non-recursively) for ``*.flax`` files.

    Returns:
        The restored parameter state as nested dicts of arrays (no template
        pytree is used, so no structure matching is performed).

    Raises:
        FileNotFoundError: If ``models_dir`` contains no ``*.flax`` file.
    """
    flax_files = glob.glob(os.path.join(models_dir, "*.flax"))
    if not flax_files:
        raise FileNotFoundError(f"No .flax files found in {models_dir}/")
    latest = max(flax_files, key=os.path.getmtime)
    print(f"Loading model: {latest}")
    with open(latest, "rb") as f:
        bytes_params = f.read()
    # No parameter template is available here, so restore the raw msgpack
    # state directly instead of passing an unrelated object as a from_bytes
    # target (which flax would ignore and return the raw state anyway).
    return serialization.msgpack_restore(bytes_params)


def generate_random_initial_state(rng, gm: float = 1.0):
    """Sample an initial state ``[x, y, vx, vy]`` with tangential velocity.

    The position angle is uniform on [0, 2*pi) and the radius uniform on
    [1, 2). The speed is uniform on [0.5, 1.1), directed 90 degrees
    counter-clockwise from the position vector.

    Args:
        rng: JAX PRNG key.
        gm: Gravitational parameter. NOTE(review): currently unused; the
            speed range is not scaled by it.

    Returns:
        A float32 array of shape (4,): ``[x0, y0, vx0, vy0]``.
    """
    rng_angle, rng_radius, rng_vel = jax.random.split(rng, 3)
    # Each draw needs its own key: reusing one key for angle and radius would
    # make the radius a deterministic function of the angle.
    angle = jax.random.uniform(rng_angle, (), minval=0.0, maxval=2.0 * jnp.pi)
    radius = jax.random.uniform(rng_radius, (), minval=1.0, maxval=2.0)
    x0 = radius * jnp.cos(angle)
    y0 = radius * jnp.sin(angle)

    v_mag = jax.random.uniform(rng_vel, (), minval=0.5, maxval=1.1)
    v_angle = angle + jnp.pi / 2.0  # perpendicular to the radius vector
    vx0 = v_mag * jnp.cos(v_angle)
    vy0 = v_mag * jnp.sin(v_angle)
    return jnp.array([x0, y0, vx0, vy0], dtype=jnp.float32)


def _rk4_rollout(init_state, dt: float, gm: float, num_steps: int) -> np.ndarray:
    """Integrate ``num_steps`` RK4 steps; return a (num_steps + 1, 4) float32 array."""
    traj = np.zeros((num_steps + 1, 4), dtype=np.float32)
    traj[0] = np.array(init_state)
    state = init_state
    for i in range(num_steps):
        state, _ = rk4_step(state, dt, gm)
        traj[i + 1] = np.array(state)
    return traj


def _plot_comparison(rk4_traj, nn_traj, rk4_energy, nn_energy, rk4_L, nn_L, out_path):
    """Save a 1x3 figure: orbit in the x-y plane, total energy, angular momentum."""
    num_steps = len(rk4_traj) - 1
    steps = np.arange(num_steps + 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    ax = axes[0]
    ax.plot(rk4_traj[:, 0], rk4_traj[:, 1], "b-", label="RK4 (truth)", alpha=0.7)
    ax.plot(nn_traj[:, 0], nn_traj[:, 1], "r--", label="OrbitMLP", alpha=0.7)
    ax.scatter(0, 0, c="k", marker="o", s=80, label="Central mass")
    ax.scatter(rk4_traj[0, 0], rk4_traj[0, 1], c="g", marker="x", s=100, label="Start")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"Orbit Comparison ({num_steps} steps)")
    ax.legend()
    ax.set_aspect("equal")

    ax = axes[1]
    ax.plot(steps, rk4_energy, "b-", label="RK4 energy", alpha=0.7)
    ax.plot(steps, nn_energy, "r--", label="NN energy", alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Total Energy")
    ax.set_title("Energy Conservation")
    ax.legend()

    ax = axes[2]
    ax.plot(steps, rk4_L, "b-", label="RK4 angular momentum", alpha=0.7)
    ax.plot(steps, nn_L, "r--", label="NN angular momentum", alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Angular Momentum L")
    ax.set_title("Angular Momentum Conservation")
    ax.legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def main():
    """Run the NN-vs-RK4 comparison, save the plot, and print error metrics."""
    os.makedirs("results", exist_ok=True)
    gm = 1.0
    dt = 0.05
    num_steps = 500

    params = load_latest_model()
    model = OrbitMLP()
    rng = jax.random.PRNGKey(0)
    init_state = generate_random_initial_state(rng, gm=gm)
    print(f"Initial state: {init_state}")

    predict_trajectory = make_predict_trajectory(model)
    # NOTE(review): dt is not passed to the network rollout; presumably the
    # model was trained at this same step size -- confirm against train.py.
    nn_traj = np.array(predict_trajectory(params, init_state, num_steps))
    rk4_traj = _rk4_rollout(init_state, dt, gm, num_steps)
    if nn_traj.shape != rk4_traj.shape:
        raise ValueError(
            f"NN trajectory shape {nn_traj.shape} does not match "
            f"RK4 trajectory shape {rk4_traj.shape}"
        )

    rk4_energy = np.array([compute_energy(s, gm) for s in rk4_traj])
    nn_energy = np.array([compute_energy(s, gm) for s in nn_traj])
    rk4_L = np.array([angular_momentum(s) for s in rk4_traj])
    nn_L = np.array([angular_momentum(s) for s in nn_traj])

    out_path = "results/inference_orbit.png"
    _plot_comparison(rk4_traj, nn_traj, rk4_energy, nn_energy, rk4_L, nn_L, out_path)
    print(f"Inference plot saved to {out_path}")

    # State columns are [x, y, vx, vy]; the position error uses only x and y.
    mse_pos = np.mean((nn_traj[:, :2] - rk4_traj[:, :2]) ** 2)
    mse_state = np.mean((nn_traj - rk4_traj) ** 2)
    print(f"Position MSE vs RK4: {mse_pos:.6e}")
    print(f"Full-state MSE vs RK4: {mse_state:.6e}")
    print(f"NN energy drift (final - initial): {nn_energy[-1] - nn_energy[0]:.6e}")
    print(f"RK4 energy drift (final - initial): {rk4_energy[-1] - rk4_energy[0]:.6e}")
    print(f"NN L variance: {np.var(nn_L):.6e}")
    print(f"RK4 L variance: {np.var(rk4_L):.6e}")


if __name__ == "__main__":
    main()