| from typing import Tuple | |
| import jax | |
| import jax.numpy as jnp | |
| import optax | |
| from flax.core import FrozenDict | |
| from flax.training.train_state import TrainState | |
| from physics_engine import energy, rk4_step | |
# Batched versions of the physics primitives. `in_axes` maps the first
# argument (the state) over its leading axis and broadcasts the remaining
# arguments unchanged to every batch element. In this module they are called
# as vmap_rk4(states, dt, gm) and vmap_energy(states, gm).
vmap_rk4 = jax.vmap(rk4_step, in_axes=(0, None, None))
vmap_energy = jax.vmap(energy, in_axes=(0, None))
def generate_trajectories(
    rng: jax.random.PRNGKey,
    num_trajs: int = 256,
    num_steps: int = 200,
    dt: float = 0.05,
    gm: float = 1.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate a batch of orbits with RK4 and return one-step training data.

    Each trajectory starts with:

    - a uniformly random angle in ``[0, 2*pi)``;
    - a uniformly random radius in ``[0.8, 2.0)``;
    - a velocity of random magnitude in ``[0.4, 1.2)``, directed
      perpendicular to the position vector (counter-clockwise).

    States are laid out as ``(x, y, vx, vy)``.

    Args:
        rng: PRNG key used for all sampling.
        num_trajs: Number of independent trajectories.
        num_steps: Number of integrator steps per trajectory.
        dt: Step size forwarded to ``rk4_step``.
        gm: Gravitational parameter forwarded to ``rk4_step``.

    Returns:
        ``(states, targets)``.

        - ``states`` has shape ``(num_trajs, num_steps + 1, 4)``. It holds the
          full trajectory ``s_0 .. s_num_steps``, where ``s_0`` is the sampled
          initial state.
        - ``targets = states[:, 1:]``.

        Aligned one-step pairs are therefore ``(states[:, :-1], targets)``.
        The stated shapes assume ``rk4_step`` returns a next state with the
        same shape as its input state (TODO confirm).
    """
    rng_pos, rng_vel = jax.random.split(rng)
    # Angle and radius need independent keys. Drawing both from one key
    # yields identical underlying uniforms, i.e. perfectly correlated samples.
    rng_angle, rng_radius = jax.random.split(rng_pos)
    angles = jax.random.uniform(rng_angle, (num_trajs,), minval=0.0, maxval=2.0 * jnp.pi)
    radii = jax.random.uniform(rng_radius, (num_trajs,), minval=0.8, maxval=2.0)
    x0 = radii * jnp.cos(angles)
    y0 = radii * jnp.sin(angles)

    v_mags = jax.random.uniform(rng_vel, (num_trajs,), minval=0.4, maxval=1.2)
    # Rotate the radial direction by +90 degrees so the velocity is tangential.
    v_angles = angles + jnp.pi / 2.0
    vx0 = v_mags * jnp.cos(v_angles)
    vy0 = v_mags * jnp.sin(v_angles)
    init_states = jnp.stack([x0, y0, vx0, vy0], axis=-1)  # (num_trajs, 4)

    def scan_step(state, _unused):
        next_state, _aux = vmap_rk4(state, dt, gm)
        # Emit the post-step state so the stacked output is s_1 .. s_T.
        # Emitting the incoming carry would put s_0 first and duplicate it
        # once init_states is prepended below.
        return next_state, next_state

    _, trajectory = jax.lax.scan(scan_step, init_states, length=num_steps)
    # scan stacks along a new leading time axis: (T, N, 4) -> (N, T, 4).
    trajectory = jnp.swapaxes(trajectory, 0, 1)
    states = jnp.concatenate([init_states[:, None, :], trajectory], axis=1)
    targets = states[:, 1:, :]
    return states, targets
def angular_momentum(state: jnp.ndarray) -> jnp.ndarray:
    """Return the z-component of r x v, i.e. ``x * vy - y * vx``.

    The last axis of ``state`` is read as ``(x, y, vx, vy)``. All leading
    axes are preserved in the result.
    """
    pos_x, pos_y, vel_x, vel_y = (state[..., k] for k in range(4))
    return pos_x * vel_y - pos_y * vel_x
def loss_fn(
    params: FrozenDict,
    state_in: jnp.ndarray,
    state_target: jnp.ndarray,
    apply_fn: callable,
    lambda_energy: float = 0.1,
    gm: float = 1.0,
    lambda_angular: float = 0.1,
) -> Tuple[jnp.ndarray, dict]:
    """Compute the physics-regularised one-step prediction loss.

    The total loss is the sum of three terms:

    - MSE between predicted and target states;
    - ``lambda_energy`` times the mean absolute energy error;
    - ``lambda_angular`` times the mean absolute angular-momentum error,
      where both quantities are compared against the target state.

    Args:
        params: Model parameters, passed as ``{"params": params}`` to
            ``apply_fn``.
        state_in: Input states. The last axis is assumed to be
            ``(x, y, vx, vy)``; presumably shape ``(batch, 4)`` (confirm
            against caller).
        state_target: Target next states, same layout as the prediction.
        apply_fn: Model apply function, e.g. a Flax ``Module.apply``.
        lambda_energy: Weight of the energy term.
        gm: Gravitational parameter forwarded to ``energy``.
        lambda_angular: Weight of the angular-momentum term.

    Returns:
        ``(total_loss, metrics)``. ``metrics`` holds the keys ``"loss"``,
        ``"mse"``, ``"energy_loss"`` and ``"angular_loss"``.
    """
    state_pred = apply_fn({"params": params}, state_in)
    mse_loss = jnp.mean((state_pred - state_target) ** 2)

    # energy() is vmapped over rows of 4 components, so flatten any leading
    # axes first. Assumes energy() returns one value per state (TODO confirm).
    e_pred = vmap_energy(state_pred.reshape(-1, 4), gm)
    e_target = vmap_energy(state_target.reshape(-1, 4), gm)
    energy_loss = jnp.mean(jnp.abs(e_pred - e_target))

    # Compare against the target's angular momentum, mirroring the energy
    # term. A batch-wide variance would wrongly penalise the legitimate
    # differences in L between distinct trajectories in the same batch.
    l_pred = angular_momentum(state_pred)
    l_target = angular_momentum(state_target)
    angular_loss = jnp.mean(jnp.abs(l_pred - l_target))

    total_loss = mse_loss + lambda_energy * energy_loss + lambda_angular * angular_loss
    metrics = {
        "loss": total_loss,
        "mse": mse_loss,
        "energy_loss": energy_loss,
        "angular_loss": angular_loss,
    }
    return total_loss, metrics
def create_train_state(
    rng: jax.random.PRNGKey,
    model: "OrbitMLP",
    learning_rate: float = 1e-3,
    decay_steps: int = 2000,
    alpha: float = 1e-4,
) -> TrainState:
    """Initialise model parameters and an AdamW optimiser with cosine LR decay.

    Args:
        rng: PRNG key for parameter initialisation.
        model: Flax module. It is initialised with a dummy float32 input of
            shape ``(1, 4)``.
        learning_rate: Initial (peak) learning rate of the cosine schedule.
        decay_steps: Number of optimiser steps over which the learning rate
            decays. After that the rate stays at its floor, so set this to
            the planned training length.
        alpha: Floor of the schedule, as a fraction of ``learning_rate``.

    Returns:
        A ``TrainState`` wrapping ``model.apply``, the initial params and the
        optimiser.
    """
    dummy_input = jnp.zeros((1, 4), dtype=jnp.float32)
    params = model.init(rng, dummy_input)["params"]
    schedule = optax.cosine_decay_schedule(
        init_value=learning_rate, decay_steps=decay_steps, alpha=alpha
    )
    tx = optax.adamw(schedule)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def make_train_step(
    model: "OrbitMLP",
    lambda_energy: float = 0.1,
    gm: float = 1.0,
    lambda_angular: float = 0.1,
):
    """Build one optimisation step with the loss weights baked in.

    Args:
        model: Module whose ``apply`` is passed to ``loss_fn``.
        lambda_energy: Weight of the energy term in ``loss_fn``.
        gm: Gravitational parameter forwarded to ``loss_fn``.
        lambda_angular: Weight of the angular-momentum term in ``loss_fn``.

    Returns:
        ``train_step(state, batch_in, batch_target) -> (new_state, metrics)``.
        It differentiates ``loss_fn`` with respect to ``state.params`` and
        applies the gradients through the state's optimiser. The function is
        not jitted here; callers may wrap it in ``jax.jit``.
    """
    # Differentiate only w.r.t. the first argument (params). The batch
    # arrays ride along as non-differentiated inputs.
    loss_and_grad = jax.value_and_grad(
        lambda p, x, y: loss_fn(
            p,
            x,
            y,
            apply_fn=model.apply,
            lambda_energy=lambda_energy,
            gm=gm,
            lambda_angular=lambda_angular,
        ),
        has_aux=True,
    )

    def train_step(
        state: TrainState,
        batch_in: jnp.ndarray,
        batch_target: jnp.ndarray,
    ) -> Tuple[TrainState, dict]:
        (_total, metrics), grads = loss_and_grad(state.params, batch_in, batch_target)
        return state.apply_gradients(grads=grads), metrics

    return train_step
def make_predict_trajectory(model: "OrbitMLP"):
    """Return a function that rolls ``model`` out autoregressively.

    The returned ``predict_trajectory(params, init_state, num_steps)`` works
    as follows:

    - It feeds each prediction back in as the next input for ``num_steps``
      steps.
    - It returns the rollout with ``init_state`` prepended, so the leading
      axis has length ``num_steps + 1``.
    - ``num_steps`` is used as the ``lax.scan`` length, so it must be static
      under ``jit``.
    """
    def predict_trajectory(
        params: FrozenDict,
        init_state: jnp.ndarray,
        num_steps: int,
    ) -> jnp.ndarray:
        variables = {"params": params}

        def advance(current, _unused):
            # The model expects a batch axis: add one, then strip it again.
            following = model.apply(variables, jnp.expand_dims(current, 0))[0]
            return following, following

        _last, rollout = jax.lax.scan(advance, init_state, xs=None, length=num_steps)
        return jnp.concatenate([jnp.expand_dims(init_state, 0), rollout], axis=0)

    return predict_trajectory