vae-fdm / src /neural_fdm /training.py
Efradeca's picture
Upload folder using huggingface_hub
fc7d689 verified
from functools import partial
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrn
import jax.tree_util as jtu
from jax import vmap
from tqdm import tqdm
from neural_fdm.models import AutoEncoderPiggy
def train_step_piggy(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Run one optimization step for an autoencoder piggy model on a fresh batch.

    The loss is differentiated twice on the same batch, the two gradient
    trees are summed leaf by leaf, and a single optimizer update is applied.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is called once per random key.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It must return a `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms reported by the second (piggy) pass.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.
    """
    # draw one sample per random key
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # the same differentiable loss is evaluated twice below
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)

    # NOTE(review): the two positional booleans are presumably the aux-data
    # and piggy-mode flags of loss_fn -- confirm against its signature.
    # main pass: last flag off; its loss values are not reported
    (_, _), main_grads = value_and_grad(model, structure, batch, True, False)

    # piggy pass: last flag on; its loss values are the ones returned
    (_, loss_vals), piggy_grads = value_and_grad(model, structure, batch, True, True)

    # sum the two gradient trees leaf by leaf
    total_grads = jtu.tree_map(
        lambda g_main, g_piggy: g_main + g_piggy,
        main_grads,
        piggy_grads,
    )

    # take one optimizer step with the combined gradients
    updates, opt_state = optimizer.update(total_grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state
def train_step(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key):
    """
    Run one optimization step for an autoencoder model on a fresh batch.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is called once per random key.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `Callable`
        The loss function. It is called with `aux_data=True` and must
        return a `(loss, loss_vals)` pair.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.
    """
    # draw one sample per random key
    sample_keys = jrn.split(key, batch_size)
    batch = vmap(generator)(sample_keys)

    # loss value, auxiliary loss terms and gradients in one call
    value_and_grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)
    (_, loss_vals), grads = value_and_grad(model, structure, batch, aux_data=True)

    # take one optimizer step
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state
def train_step_vae(model, structure, optimizer, generator, opt_state, *,
                   loss_fn, batch_size, key, beta):
    """
    Training step for VAE models with reparameterization sampling.

    Differs from train_step in two ways:
    1. Passes a PRNG key through the loss function for epsilon sampling.
    2. Accepts beta as a traced JAX float (not Python float) to avoid
       JIT recompilation at each step.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator. It is called once per random key.
    opt_state: `optax.GradientTransformationExtraArgs`
        The current optimizer state.
    loss_fn: `functools.partial`
        The loss function. Its `keywords` are read for the entries
        `"loss_fn"` and `"loss_params"`, which are forwarded to
        `compute_loss_vae`. The partial itself is never called here.
    batch_size: `int`
        The number of samples to generate in each batch.
    key: `jax.random.PRNGKey`
        The random key. It is split in two: one half for data
        generation, the other for VAE sampling.
    beta: `jnp.float32`
        Current KL weight. Must be a JAX array (traced), not a Python float,
        to prevent JIT recompilation at every training step.

    Returns
    -------
    loss_vals: `dict` of `float`
        The values of the loss terms.
    model: `eqx.Module`
        The updated model.
    opt_state: `optax.GradientTransformationExtraArgs`
        The updated optimizer state.

    References
    ----------
    Kingma & Welling (2014): Reparameterization trick requires PRNG in forward pass.
    Fu et al. (2019): Beta varies per step via cyclical annealing.
    """
    # Function-scope import kept from the original code.
    # NOTE(review): presumably avoids a circular import -- confirm.
    from neural_fdm.losses import compute_loss_vae

    # Split key: one for data generation, one for VAE sampling
    data_key, model_key = jrn.split(key)

    # Sample fresh data
    keys = jrn.split(data_key, batch_size)
    x = vmap(generator)(keys)

    # Build per-step loss parameters WITHOUT touching the caller's dicts.
    # A shallow copy of the outer dict is not enough: writing beta into the
    # shared nested "vae" dict would mutate loss_fn.keywords and stash a
    # traced value in a long-lived Python object.
    base_params = loss_fn.keywords.get("loss_params", {})
    step_params = dict(base_params)
    step_params["vae"] = {**base_params.get("vae", {}), "beta": beta}
    inner_loss_fn = loss_fn.keywords.get("loss_fn", None)

    # Loss wrapper that injects the VAE sampling key and the current beta
    def vae_loss_wrapper(model, structure, x, aux_data=True):
        return compute_loss_vae(
            model, structure, x, inner_loss_fn, step_params,
            aux_data=aux_data, key=model_key
        )

    # Calculate updates
    val_grad_fn = eqx.filter_value_and_grad(vae_loss_wrapper, has_aux=True)
    (_, loss_vals), grads = val_grad_fn(model, structure, x, aux_data=True)

    # Apply updates
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    return loss_vals, model, opt_state
def train_model(model, structure, optimizer, generator, *, loss_fn, num_steps, batch_size, key, callback=None):
    """
    Train a model for a fixed number of parameter updates.

    The step function is chosen from the model type: VAE models use
    `train_step_vae` with a per-step beta, `AutoEncoderPiggy` models use
    `train_step_piggy`, and every other model uses `train_step`.

    Parameters
    ----------
    model: `eqx.Module`
        The model to train.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    optimizer: `optax.GradientTransformation`
        The optimizer to use for training.
    generator: `PointGenerator`
        The data generator.
    loss_fn: `Callable`
        The loss function. For VAE models its `keywords` are read, so it
        must be a `functools.partial`; the optional `"vae"` entry of
        `"loss_params"` supplies `beta_max`, `cycle_length` and
        `warmup_ratio`.
    num_steps: `int`
        The number of steps to train for (number of parameter updates).
    batch_size: `int`
        The number of samples to generate per batch.
    key: `jax.random.PRNGKey`
        The random key.
    callback: `Callable`, optional
        A callback function to call after each step.
        The callback function should take the following arguments:
        - model: `eqx.Module`
        - opt_state: `optax.GradientTransformationExtraArgs`
        - loss_vals: `dict` of `float`
        - step: `int`

    Returns
    -------
    model: `eqx.Module`
        The trained model.
    loss_history: `list` of `dict`
        The loss values recorded at every step, in order.
    """
    # the optimizer state is initialized from the array leaves of the model
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    # detect VAE model
    from neural_fdm.variational import VariationalAutoEncoder

    is_vae = isinstance(model, VariationalAutoEncoder)

    # pick the step implementation for this model type
    if is_vae:
        from neural_fdm.variational import compute_beta_schedule

        # beta schedule settings come from the loss_fn partial keywords
        vae_cfg = loss_fn.keywords.get("loss_params", {}).get("vae", {})
        beta_max = vae_cfg.get("beta_max", 1.0)
        cycle_length = vae_cfg.get("cycle_length", num_steps)
        warmup_ratio = vae_cfg.get("warmup_ratio", 0.5)
        step_impl = train_step_vae
    elif isinstance(model, AutoEncoderPiggy):
        step_impl = train_step_piggy
    else:
        step_impl = train_step

    # bind the loss function and compile the step
    train_step_fn = eqx.filter_jit(partial(step_impl, loss_fn=loss_fn))

    # train
    loss_history = []
    for step in tqdm(range(num_steps)):
        # randomness: advance the key for this step
        key, _ = jrn.split(key)

        step_kwargs = {"batch_size": batch_size, "key": key}
        if is_vae:
            # beta is passed as a JAX array to avoid JIT recompilation
            step_kwargs["beta"] = jnp.float32(
                compute_beta_schedule(step, beta_max, cycle_length, warmup_ratio)
            )

        # train step
        loss_vals, model, opt_state = train_step_fn(
            model, structure, optimizer, generator, opt_state,
            **step_kwargs,
        )

        # store loss values
        loss_history.append(loss_vals)

        # callback
        if callback:
            callback(model, opt_state, loss_vals, step)

    return model, loss_history