orbits / train.py
asgeirr89's picture
Upload folder using huggingface_hub
6cd7e16 verified
Raw
History Blame Contribute Delete
4.08 kB
from typing import Tuple
import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict
from flax.training.train_state import TrainState
from physics_engine import energy, rk4_step
# Batched wrappers around the single-state helpers imported from physics_engine.
# In both cases only the first positional argument is mapped over its leading
# axis; the remaining arguments (in_axes=None) are passed unchanged to every call.
vmap_rk4 = jax.vmap(rk4_step, in_axes=(0, None, None))  # used in this file as vmap_rk4(states, dt, gm)
vmap_energy = jax.vmap(energy, in_axes=(0, None))  # used in this file as vmap_energy(states, gm)
def generate_trajectories(
    rng: jax.random.PRNGKey,
    num_trajs: int = 256,
    num_steps: int = 200,
    dt: float = 0.05,
    gm: float = 1.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sample random initial conditions and integrate them with ``vmap_rk4``.

    Each trajectory starts at a random polar angle in [0, 2*pi) and radius in
    [0.8, 2.0). The initial velocity has a random magnitude in [0.4, 1.2) and
    points at ``angle + pi/2``, i.e. perpendicular to the position vector.

    Args:
        rng: JAX PRNG key used for all sampling.
        num_trajs: Number of independent trajectories to generate.
        num_steps: Number of integration steps per trajectory.
        dt: Step size forwarded to ``vmap_rk4``.
        gm: Forwarded to ``vmap_rk4`` (presumably the gravitational
            parameter G*M -- confirm against ``physics_engine``).

    Returns:
        A ``(states, targets)`` tuple where ``states`` has shape
        ``(num_trajs, num_steps + 1, 4)`` and holds ``[x, y, vx, vy]`` for the
        initial state followed by every integrated state, and ``targets`` has
        shape ``(num_trajs, num_steps, 4)`` with ``targets[:, t]`` being the
        state one step after ``states[:, t]``.
    """
    rng_pos, rng_vel = jax.random.split(rng)
    # A PRNG key must be consumed only once: drawing angles and radii from the
    # same key yields the same underlying random bits, which makes the radius
    # a deterministic function of the angle. Derive one sub-key per draw.
    rng_angle, rng_radius = jax.random.split(rng_pos)

    angles = jax.random.uniform(rng_angle, (num_trajs,), minval=0.0, maxval=2.0 * jnp.pi)
    radii = jax.random.uniform(rng_radius, (num_trajs,), minval=0.8, maxval=2.0)
    x0 = radii * jnp.cos(angles)
    y0 = radii * jnp.sin(angles)

    v_mags = jax.random.uniform(rng_vel, (num_trajs,), minval=0.4, maxval=1.2)
    v_angles = angles + jnp.pi / 2.0  # velocity perpendicular to the position vector
    vx0 = v_mags * jnp.cos(v_angles)
    vy0 = v_mags * jnp.sin(v_angles)

    init_states = jnp.stack([x0, y0, vx0, vy0], axis=-1)

    def scan_step(state, _):
        # vmap_rk4 returns a pair; only its first element (the next state) is used.
        next_state, _aux = vmap_rk4(state, dt, gm)
        # Emit the state *after* the step. Emitting the incoming state would
        # repeat the initial state once it is prepended below and would drop
        # the final integrated state.
        return next_state, next_state

    _, trajectory = jax.lax.scan(scan_step, init_states, length=num_steps)
    # scan stacks outputs along a leading time axis; move to (traj, time, 4).
    trajectory = jnp.swapaxes(trajectory, 0, 1)

    init_states_expanded = init_states[:, None, :]
    states = jnp.concatenate([init_states_expanded, trajectory], axis=1)
    targets = states[:, 1:, :]
    return states, targets
def angular_momentum(state: jnp.ndarray) -> jnp.ndarray:
    """Return ``x * vy - y * vx`` taken along the last axis of ``state``.

    The last axis is read as ``[x, y, vx, vy]`` (indices 0 to 3); all leading
    axes are preserved in the result.
    """
    return state[..., 0] * state[..., 3] - state[..., 1] * state[..., 2]
def loss_fn(
    params: FrozenDict,
    state_in: jnp.ndarray,
    state_target: jnp.ndarray,
    apply_fn: callable,
    lambda_energy: float = 0.1,
    gm: float = 1.0,
    lambda_angular: float = 0.1,
) -> Tuple[jnp.ndarray, dict]:
    """Combine a prediction MSE with energy and angular-momentum penalties.

    Args:
        params: Model parameters, passed to ``apply_fn`` as ``{"params": params}``.
        state_in: Model input.
        state_target: Reference output the prediction is compared against.
        apply_fn: Callable invoked as ``apply_fn({"params": params}, state_in)``.
        lambda_energy: Weight of the energy term.
        gm: Forwarded to ``vmap_energy``.
        lambda_angular: Weight of the angular-momentum term.

    Returns:
        ``(total_loss, metrics)`` where ``metrics`` carries the total under
        ``"loss"`` plus the individual ``"mse"``, ``"energy_loss"`` and
        ``"angular_loss"`` terms.
    """
    prediction = apply_fn({"params": params}, state_in)

    def _flat_energy(states):
        # Collapse all leading axes so every row is a single 4-component state.
        return vmap_energy(states.reshape(-1, 4), gm)

    residual = prediction - state_target
    mse = jnp.mean(residual ** 2)

    # Mean absolute difference between predicted and reference energies.
    energy_gap = _flat_energy(prediction) - _flat_energy(state_target)
    energy_term = jnp.mean(jnp.abs(energy_gap))

    # NOTE(review): this is the variance over *every* predicted state in the
    # batch, not a comparison with the target's angular momentum. If a batch
    # mixes states from different trajectories this pulls them toward a common
    # value -- confirm that is intended.
    angular_term = jnp.var(angular_momentum(prediction))

    total = mse + lambda_energy * energy_term + lambda_angular * angular_term
    return total, dict(
        loss=total,
        mse=mse,
        energy_loss=energy_term,
        angular_loss=angular_term,
    )
def create_train_state(
    rng: jax.random.PRNGKey,
    model: "OrbitMLP",
    learning_rate: float = 1e-3,
    *,
    decay_steps: int = 2000,
    alpha: float = 1e-4,
) -> TrainState:
    """Initialise model parameters and wrap them in a Flax ``TrainState``.

    The optimiser is AdamW driven by a cosine-decay learning-rate schedule.

    Args:
        rng: PRNG key used for parameter initialisation.
        model: Module exposing ``init`` and ``apply``; it is initialised with
            a zero input of shape ``(1, 4)``.
        learning_rate: Initial value of the cosine schedule.
        decay_steps: Number of optimiser steps over which the schedule decays.
            Previously fixed at 2000; set it to the planned number of updates.
        alpha: ``alpha`` argument of ``optax.cosine_decay_schedule`` (the
            floor of the schedule as a fraction of ``learning_rate``).

    Returns:
        A ``TrainState`` holding ``model.apply``, the freshly initialised
        parameters and the optimiser.
    """
    dummy_input = jnp.zeros((1, 4), dtype=jnp.float32)
    variables = model.init(rng, dummy_input)
    params = variables["params"]

    schedule = optax.cosine_decay_schedule(
        init_value=learning_rate,
        decay_steps=decay_steps,
        alpha=alpha,
    )
    tx = optax.adamw(schedule)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def make_train_step(
    model: "OrbitMLP",
    lambda_energy: float = 0.1,
    gm: float = 1.0,
    lambda_angular: float = 0.1,
):
    """Build a jitted single-step training function bound to ``model``.

    The returned callable takes ``(state, batch_in, batch_target)``, evaluates
    ``loss_fn`` with ``model.apply`` and the weights given here, applies the
    resulting gradients to ``state`` and returns ``(new_state, metrics)``.
    """

    def train_step(
        state: TrainState,
        batch_in: jnp.ndarray,
        batch_target: jnp.ndarray,
    ) -> Tuple[TrainState, dict]:
        def objective(params):
            return loss_fn(
                params,
                batch_in,
                batch_target,
                apply_fn=model.apply,
                lambda_energy=lambda_energy,
                gm=gm,
                lambda_angular=lambda_angular,
            )

        # has_aux=True: loss_fn returns (loss, metrics); only the loss is differentiated.
        grad_fn = jax.value_and_grad(objective, has_aux=True)
        (_, metrics), grads = grad_fn(state.params)
        return state.apply_gradients(grads=grads), metrics

    return jax.jit(train_step)
def make_predict_trajectory(model: "OrbitMLP"):
    """Build an autoregressive rollout function bound to ``model``.

    The returned ``predict_trajectory(params, init_state, num_steps)`` feeds
    the model its own output ``num_steps`` times, starting from ``init_state``,
    and returns the initial state followed by every predicted state stacked
    along axis 0 (``num_steps + 1`` rows in total).
    """

    def predict_trajectory(
        params: FrozenDict,
        init_state: jnp.ndarray,
        num_steps: int,
    ) -> jnp.ndarray:
        variables = {"params": params}

        def advance(current, _unused):
            # The model is applied to a batch of one; strip the batch axis again.
            batched_out = model.apply(variables, current[None, :])
            following = batched_out[0]
            return following, following

        _, rollout = jax.lax.scan(advance, init_state, length=num_steps)
        start = init_state[None, :]
        return jnp.concatenate([start, rollout], axis=0)

    return predict_trajectory