| from typing import Tuple |
|
|
| import jax |
| import jax.numpy as jnp |
|
|
# Array type used for orbit states; kepler_dynamics and energy unpack a state
# along its leading axis as (x, y, vx, vy).
State = jax.Array
|
|
|
|
def kepler_dynamics(state: State, gm: float = 1.0) -> State:
    """Return the time derivative of a planar Kepler state.

    Args:
        state: Array that unpacks along its leading axis into
            ``(x, y, vx, vy)``.
        gm: Factor scaling the central attraction.

    Returns:
        ``(vx, vy, ax, ay)`` stacked along a new leading axis, with
        ``(ax, ay) = -gm * (x, y) / r**3`` and ``r**2 = x**2 + y**2``.
        Where ``r**3 <= 1e-12`` the acceleration is set to zero instead of
        diverging, and the function stays differentiable there.
    """
    x, y, vx, vy = state
    r2 = x * x + y * y

    # The mask only selects a branch; stop_gradient keeps autodiff from
    # evaluating the derivative of sqrt at r2 == 0 just to build it.
    r2_const = jax.lax.stop_gradient(r2)
    regular = r2_const * jnp.sqrt(r2_const) > 1e-12

    # "Double where": the masked branch gets a harmless placeholder as input.
    # With a single jnp.where the unselected 1 / 0 still takes part in
    # differentiation and turns the gradient at the origin into NaN (0 * inf).
    safe_r2 = jnp.where(regular, r2, 1.0)
    inv_r3 = jnp.where(regular, 1.0 / (safe_r2 * jnp.sqrt(safe_r2)), 0.0)

    ax = -gm * x * inv_r3
    ay = -gm * y * inv_r3
    return jnp.stack([vx, vy, ax, ay])
|
|
|
|
def rk4_step(state: State, dt: float = 0.01, gm: float = 1.0) -> Tuple[State, State]:
    """Advance ``state`` by one classical fourth-order Runge-Kutta step.

    Args:
        state: State at the start of the step, as accepted by
            ``kepler_dynamics``.
        dt: Step size.
        gm: Forwarded unchanged to every ``kepler_dynamics`` evaluation.

    Returns:
        A pair ``(next_state, slope_start)``: the state after the step and
        the derivative evaluated at the incoming ``state``.
    """

    def slope(s):
        return kepler_dynamics(s, gm)

    half_dt = 0.5 * dt

    # Four slope samples: start, two midpoint estimates, and end of the step.
    slope_start = slope(state)
    slope_mid_a = slope(state + half_dt * slope_start)
    slope_mid_b = slope(state + half_dt * slope_mid_a)
    slope_end = slope(state + dt * slope_mid_b)

    # Weights 1-2-2-1, normalized by 6.
    weighted = slope_start + 2.0 * slope_mid_a + 2.0 * slope_mid_b + slope_end
    return state + (dt / 6.0) * weighted, slope_start
|
|
|
|
def energy(state: State, gm: float = 1.0) -> jax.Array:
    """Return ``0.5 * (vx**2 + vy**2) - gm / r`` for a planar Kepler state.

    Args:
        state: Array that unpacks along its leading axis into
            ``(x, y, vx, vy)``.
        gm: Factor scaling the potential term.

    Returns:
        Kinetic plus potential term, with ``r = sqrt(x**2 + y**2)``. Where
        ``r <= 1e-12`` the denominator is replaced by ``1e12``, so the
        potential term is about ``-gm * 1e-12`` rather than diverging, and
        the function stays differentiable there.
    """
    x, y, vx, vy = state
    kinetic = 0.5 * (vx * vx + vy * vy)
    r2 = x * x + y * y

    # The mask only selects a branch; stop_gradient keeps autodiff from
    # evaluating the derivative of sqrt at r2 == 0 just to build it.
    regular = jnp.sqrt(jax.lax.stop_gradient(r2)) > 1e-12

    # "Double where": sqrt sees a placeholder on the masked branch. Taking
    # sqrt(0) directly has an infinite derivative, which combined with the
    # zero cotangent from jnp.where gives a NaN gradient at the origin.
    r = jnp.sqrt(jnp.where(regular, r2, 1.0))

    # NOTE(review): the 1e12 fallback makes the potential effectively vanish
    # near the origin; confirm this (rather than 1e-12) is intended.
    potential = -gm / jnp.where(regular, r, 1e12)
    return kinetic + potential