| from functools import partial |
|
|
| import equinox as eqx |
| import jax.numpy as jnp |
| import jax.random as jrn |
| import jax.tree_util as jtu |
| from jax import vmap |
| from tqdm import tqdm |
|
|
| from neural_fdm.models import AutoEncoderPiggy |
|
|
|
|
def train_step_piggy(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Perform one optimization step on an autoencoder piggy model.

    The loss function is differentiated twice on the same batch, with the
    last positional flag set to `False` and then to `True`. The two gradient
    pytrees are added leaf by leaf and the sum is applied as one update.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer that turns gradients into parameter updates.
    generator: `PointGenerator`
        The data generator. It is called once per sample key under `vmap`.
    opt_state:
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It is called as
        `loss_fn(model, structure, x, flag_a, flag_b)` and must return a
        `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in the batch.
    key: `jax.random.PRNGKey`
        The random key, split into `batch_size` sample keys.

    Returns
    -------
    loss_vals: `dict` of `float`
        The loss terms reported by the second (last flag `True`) evaluation.
    model: `eqx.Module`
        The updated model.
    opt_state:
        The updated optimizer state.
    """
    # One key per sample; the generator is vectorized over them.
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # A single transformed function serves both evaluations.
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)

    # First pass: last flag False (presumably the main branch -- TODO confirm
    # against the loss function's signature). Its loss outputs are discarded.
    _, grads_main = value_and_grad(model, structure, batch, True, False)

    # Second pass: last flag True (presumably the piggy branch -- TODO confirm).
    # The loss terms returned to the caller come from this pass.
    (_, loss_vals), grads_piggy = value_and_grad(model, structure, batch, True, True)

    # Leaf-wise sum of the two gradient pytrees.
    summed_grads = jtu.tree_map(
        lambda g_main, g_piggy: g_main + g_piggy,
        grads_main,
        grads_piggy,
    )

    param_updates, opt_state = optimizer.update(summed_grads, opt_state)
    model = eqx.apply_updates(model, param_updates)

    return loss_vals, model, opt_state
|
|
|
|
def train_step(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Perform one optimization step on an autoencoder model.

    A fresh batch is sampled, the loss is differentiated with respect to the
    model, and the optimizer update is applied.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer that turns gradients into parameter updates.
    generator: `PointGenerator`
        The data generator. It is called once per sample key under `vmap`.
    opt_state:
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It is called as
        `loss_fn(model, structure, x, aux_data=True)` and must return a
        `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in the batch.
    key: `jax.random.PRNGKey`
        The random key, split into `batch_size` sample keys.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state:
        The updated optimizer state.
    """
    # One key per sample; the generator is vectorized over them.
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # The scalar loss itself is not needed here, only its auxiliary terms.
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)
    (_, loss_vals), grads = value_and_grad(model, structure, batch, aux_data=True)

    param_updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, param_updates)

    return loss_vals, model, opt_state
|
|
|
|
def train_step_vae(model, structure, optimizer, generator, opt_state, *,
                   loss_fn, batch_size, key, beta):
    """
    Update the parameters of a VAE model on a batch of data for one step.

    Differs from `train_step` in two ways:
      1. A PRNG key is passed through the loss function for epsilon sampling.
      2. `beta` is accepted as a JAX array (not a Python float) to avoid
         JIT recompilation at each step.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is called once per sample key under `vmap`.
    opt_state:
        The current optimizer state.
    loss_fn: `functools.partial`
        A partial whose `keywords` may hold a `"loss_fn"` entry (the inner
        loss callable, `None` if absent) and a `"loss_params"` entry (a dict,
        empty if absent). Both are forwarded to `compute_loss_vae`. The
        partial and its dictionaries are only read, never modified.
    batch_size: `int`
        The number of samples to generate in the batch.
    key: `jax.random.PRNGKey`
        The random key. It is split into a data key (batch sampling) and a
        model key (forwarded to `compute_loss_vae`).
    beta: `jnp.float32`
        Current KL weight. Must be a JAX array, not a Python float, to
        prevent JIT recompilation at every training step.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state:
        The updated optimizer state.

    References
    ----------
    Kingma & Welling (2014): Reparameterization trick requires PRNG in forward pass.
    Fu et al. (2019): Beta varies per step via cyclical annealing.
    """
    # Function-scoped import, as in the original implementation.
    from neural_fdm.losses import compute_loss_vae

    # Independent randomness for data sampling and for the model's sampling.
    data_key, model_key = jrn.split(key)

    keys = jrn.split(data_key, batch_size)
    x = vmap(generator)(keys)

    # Build the per-step loss parameters WITHOUT touching the caller's dicts.
    # `dict(...)` is only a shallow copy, so the nested "vae" dict must be
    # rebuilt too. Writing `beta` into the shared one would leak a tracer
    # into `loss_fn.keywords` when this function runs under JIT.
    shared_params = loss_fn.keywords.get("loss_params", {})
    step_params = dict(shared_params)
    step_params["vae"] = {**shared_params.get("vae", {}), "beta": beta}
    inner_loss_fn = loss_fn.keywords.get("loss_fn", None)

    def vae_loss_wrapper(model, structure, x, aux_data=True):
        # `beta` (inside step_params) and `model_key` are closed over, so
        # only `model` is differentiated.
        return compute_loss_vae(
            model, structure, x, inner_loss_fn, step_params,
            aux_data=aux_data, key=model_key
        )

    val_grad_fn = eqx.filter_value_and_grad(vae_loss_wrapper, has_aux=True)
    (_, loss_vals), grads = val_grad_fn(model, structure, x, aux_data=True)

    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state
|
|
|
|
def train_model(model, structure, optimizer, generator, *, loss_fn, num_steps, batch_size, key, callback=None):
    """
    Train a model for a fixed number of optimization steps.

    The step function is chosen from the model type: `train_step_vae` for a
    `VariationalAutoEncoder`, `train_step_piggy` for an `AutoEncoderPiggy`,
    and `train_step` otherwise. It is bound to `loss_fn` and compiled with
    `eqx.filter_jit`.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator.
    loss_fn: `Callable`
        The loss function. For a `VariationalAutoEncoder` it must expose a
        `keywords` mapping (e.g. a `functools.partial`). The optional
        `keywords["loss_params"]["vae"]` dict supplies `"beta_max"`
        (default 1.0), `"cycle_length"` (default `num_steps`) and
        `"warmup_ratio"` (default 0.5) for the beta schedule.
    num_steps: `int`
        The number of steps to train for (number of parameter updates).
    batch_size: `int`
        The number of samples to generate per batch.
    key: `jax.random.PRNGKey`
        The random key.
    callback: `Callable`, optional
        A function called after each step as
        `callback(model, opt_state, loss_vals, step)`.

    Returns
    -------
    model: `eqx.Module`
        The trained model.
    loss_history: `list` of `dict`
        The loss terms returned by every step, in order.
    """
    # The optimizer state is built from the array leaves of the model only.
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    # Function-scoped import, as in the original implementation.
    from neural_fdm.variational import VariationalAutoEncoder
    is_vae = isinstance(model, VariationalAutoEncoder)

    # Choose the step implementation; the VAE check takes priority.
    if is_vae:
        from neural_fdm.variational import compute_beta_schedule

        vae_cfg = loss_fn.keywords.get("loss_params", {}).get("vae", {})
        beta_max = vae_cfg.get("beta_max", 1.0)
        cycle_length = vae_cfg.get("cycle_length", num_steps)
        warmup_ratio = vae_cfg.get("warmup_ratio", 0.5)

        step_impl = train_step_vae
    elif isinstance(model, AutoEncoderPiggy):
        step_impl = train_step_piggy
    else:
        step_impl = train_step

    train_step_fn = eqx.filter_jit(partial(step_impl, loss_fn=loss_fn))

    loss_history = []
    for step in tqdm(range(num_steps)):
        # Advance the key; the second subkey is discarded.
        # NOTE(review): the key passed to the step below is the same one that
        # is split again on the next iteration. Changing this would alter the
        # random stream of existing runs -- confirm before touching.
        key, _discarded = jrn.split(key)

        step_kwargs = {"batch_size": batch_size, "key": key}
        if is_vae:
            # Wrapped as a JAX scalar so the value is passed as an array
            # rather than as a Python number.
            step_kwargs["beta"] = jnp.float32(
                compute_beta_schedule(step, beta_max, cycle_length, warmup_ratio)
            )

        loss_vals, model, opt_state = train_step_fn(
            model, structure, optimizer, generator, opt_state,
            **step_kwargs,
        )

        loss_history.append(loss_vals)

        if callback:
            callback(model, opt_state, loss_vals, step)

    return model, loss_history
|
|