| import jax |
| import jax.numpy as jnp |
| from jax import vmap |
|
|
| from neural_fdm.helpers import vertices_residuals_from_xyz |
| from neural_fdm.models import AutoEncoderPiggy |
|
|
| |
| |
| |
|
|
| def compute_loss( |
| model, |
| structure, |
| x, |
| loss_fn, |
| loss_params, |
| aux_data=False, |
| piggy_mode=False |
| ): |
| """ |
| Compute the model loss according to the model type. |
| |
| Parameters |
| ---------- |
| model: `eqx.Module` |
| The model. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| x: `jax.Array` |
| The target shape. |
| loss_fn: `Callable` |
| The loss function. |
| loss_params: `dict` |
| The scaling parameters to combine the loss' error terms. |
| aux_data: `bool` |
| If true, returns auxiliary data. |
| piggy_mode: `bool` |
| If true, the model is a piggy autoencoder. |
| |
| Returns |
| ------- |
| loss: `float` or `tuple` |
| The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. |
| """ |
| predict_fn = vmap(model, in_axes=(0, None, None, None)) |
| x_hat, data_hat = predict_fn(x, structure, True, piggy_mode) |
|
|
| _loss_fn = _compute_loss |
| if isinstance(model, AutoEncoderPiggy): |
| _loss_fn = _compute_loss_piggy |
|
|
| loss = _loss_fn( |
| loss_fn, |
| loss_params, |
| x, |
| x_hat, |
| data_hat, |
| structure, |
| aux_data, |
| piggy_mode |
| ) |
|
|
| return loss |
|
|
|
|
| def _compute_loss( |
| loss_fn, |
| loss_params, |
| x, |
| x_hat, |
| params_hat, |
| structure, |
| aux_data, |
| piggy_mode=False |
| ): |
| """ |
| Compute the model loss of an autoencoder. |
| |
| Parameters |
| ---------- |
| loss_fn: `Callable` |
| The loss function. |
| loss_params: `dict` |
| The scaling parameters to combine the loss' error terms. |
| x: `jax.Array` |
| The target shape. |
| x_hat: `jax.Array` |
| The predicted shape. |
| params_hat: tuple of `jax.Array` |
| The predicted force densities, loads, and fixed positions. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| aux_data: `bool` |
| If true, returns auxiliary data. |
| piggy_mode: `bool` |
| If true, the model is a piggy autoencoder. |
| |
| Returns |
| ------- |
| loss: `float` or `tuple` |
| The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. |
| """ |
| return loss_fn(x, x_hat, params_hat, structure, loss_params, aux_data) |
|
|
|
|
| def _compute_loss_piggy( |
| loss_fn, |
| loss_params, |
| x, |
| x_data_hat, |
| y_data_hat, |
| structure, |
| aux_data, |
| piggy_mode=True, |
| ): |
| """ |
| Compute the loss of a piggy autoencoder. |
| |
| Parameters |
| ---------- |
| loss_fn: `Callable` |
| The loss function. |
| loss_params: `dict` |
| The scaling parameters to combine the loss' error terms. |
| x: `jax.Array` |
| The target shape. |
| x_data_hat: `tuple` |
| The predicted shape and the predicted parameters. |
| y_data_hat: `tuple` |
| The predicted shape and the predicted parameters. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| aux_data: `bool` |
| If true, returns auxiliary data. |
| piggy_mode: `bool` |
| If true, the model is a piggy autoencoder. |
| |
| Returns |
| ------- |
| loss: `float` or `tuple` |
| The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. |
| """ |
| x_hat, x_params_hat = x_data_hat |
|
|
| if not piggy_mode: |
| loss_data = loss_fn(x, x_hat, x_params_hat, structure, loss_params, aux_data) |
| else: |
| y_hat, y_params_hat = y_data_hat |
| loss_data = loss_fn(x_hat, y_hat, y_params_hat, structure, loss_params, aux_data) |
|
|
| return loss_data |
|
|
|
|
| |
| |
| |
|
|
| def compute_loss_shell( |
| x, |
| x_hat, |
| params_hat, |
| structure, |
| loss_params, |
| aux_data, |
| *args |
| ): |
| """ |
| Compute the loss for the shell task. |
| |
| Parameters |
| ---------- |
| x: `jax.Array` |
| The target shape. |
| x_hat: `jax.Array` |
| The predicted shape. |
| params_hat: tuple of `jax.Array` |
| The predicted force densities, loads, and fixed positions. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| loss_params: `dict` |
| The scaling parameters to combine the loss' error terms. |
| aux_data: `bool` |
| If true, returns auxiliary data. |
| |
| Returns |
| ------- |
| loss: `float` or `tuple` |
| The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. |
| """ |
| shape_params = loss_params["shape"] |
| factor_shape = shape_params["weight"] |
| loss_shape = compute_error_shape_l1(x, x_hat) |
| loss_shape = factor_shape * loss_shape |
|
|
| indices = structure.indices_free |
| residual_params = loss_params["residual"] |
| factor_residual = residual_params["weight"] |
| loss_residual = compute_error_residual(x_hat, params_hat, structure, indices) |
| loss_residual = factor_residual * loss_residual |
|
|
| loss = 0.0 |
| if shape_params["include"]: |
| loss = loss + loss_shape |
| if residual_params["include"]: |
| loss = loss + loss_residual |
|
|
| loss_terms = { |
| "loss": loss, |
| "shape error": loss_shape, |
| "residual error": loss_residual |
| } |
|
|
| if aux_data: |
| return loss, loss_terms |
|
|
| return loss |
|
|
|
|
| def compute_loss_tower( |
| x, |
| x_hat, |
| params_hat, |
| structure, |
| loss_params, |
| aux_data, |
| *args |
| ): |
| """ |
| Compute the loss for the tower task. |
| |
| Parameters |
| ---------- |
| x: `jax.Array` |
| The target shape. |
| x_hat: `jax.Array` |
| The predicted shape. |
| params_hat: tuple of `jax.Array` |
| The predicted force densities, loads, and fixed positions. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| loss_params: `dict` |
| The scaling parameters to combine the loss' error terms. |
| aux_data: `bool` |
| If true, returns auxiliary data. |
| |
| Returns |
| ------- |
| loss: `float` or `tuple` |
| The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. |
| """ |
| |
| shape_params = loss_params["shape"] |
| factor_shape = shape_params["weight"] |
| shape_dims = shape_params["dims"] |
| levels_compression = shape_params["levels_compression"] |
|
|
| def slice_xyz_rings(_x, levels): |
| return jnp.reshape(_x, shape_dims)[levels, :, :].ravel() |
|
|
| slice_xyz_vmap = vmap(slice_xyz_rings, in_axes=(0, None)) |
| xyz_slice = slice_xyz_vmap(x, levels_compression) |
| xyz_hat_slice = slice_xyz_vmap(x_hat, levels_compression) |
| assert xyz_slice.shape == xyz_hat_slice.shape |
|
|
| |
| loss_shape = compute_error_shape_l2(xyz_slice, xyz_hat_slice) |
| loss_shape = factor_shape * loss_shape |
|
|
| |
| height_params = loss_params["shape"] |
| factor_height = height_params["weight"] |
| height_dims = height_params["dims"] |
| levels_tension = height_params["levels_tension"] |
|
|
| def slice_z_rings(_x, levels): |
| return jnp.reshape(_x, height_dims)[levels, :, 2].ravel() |
|
|
| slice_z_vmap = vmap(slice_z_rings, in_axes=(0, None)) |
| z_slice = slice_z_vmap(x, levels_tension) |
| z_hat_slice = slice_z_vmap(x_hat, levels_tension) |
| assert z_slice.shape == z_hat_slice.shape |
|
|
| |
| loss_height = compute_error_shape_l2(z_slice, z_hat_slice) |
| loss_height = factor_height * loss_height |
|
|
| |
| loss_shape = loss_shape + loss_height |
|
|
| |
| indices = structure.indices_free |
| residual_params = loss_params["residual"] |
| factor_residual = residual_params["weight"] |
| loss_residual = compute_error_residual(x_hat, params_hat, structure, indices) |
| loss_residual = factor_residual * loss_residual |
|
|
| |
| regularization_params = loss_params["regularization"] |
| factor_regularization = regularization_params["weight"] |
| q = params_hat[0] |
| regularization = compute_q_regularization(q) |
| regularization = factor_regularization * regularization |
|
|
| loss = 0.0 |
| if shape_params["include"]: |
| loss = loss + loss_shape |
| if residual_params["include"]: |
| loss = loss + loss_residual |
| if regularization_params["include"]: |
| loss = loss + regularization |
|
|
| loss_terms = { |
| "loss": loss, |
| "shape error": loss_shape, |
| "residual error": loss_residual, |
| "regularization": regularization |
| } |
|
|
| if aux_data: |
| return loss, loss_terms |
|
|
| return loss |
|
|
|
|
| |
| |
| |
|
|
|
|
| def compute_loss_shell_vae( |
| x, x_hat, params_hat_vae, structure, loss_params, aux_data, *args |
| ): |
| """Loss for shell task with VAE: reconstruction + beta * KL. |
| |
| Wraps compute_loss_shell and adds the KL divergence term from the |
| variational encoder. The KL weight (beta) follows the cyclical |
| annealing schedule of Fu et al. (NAACL 2019). |
| |
| Parameters |
| ---------- |
| params_hat_vae : tuple |
| ((q, xyz_fixed, loads), mu, log_sigma) - VAE extended aux_data. |
| loss_params : dict |
| Must include 'vae' key with 'beta' value. |
| |
| References |
| ---------- |
| Kingma & Welling (2014), Eq. 7: KL divergence formula. |
| Fu et al. (2019): Cyclical beta annealing. |
| """ |
| from neural_fdm.variational import compute_kl_divergence |
|
|
| |
| params_hat, mu, log_sigma = params_hat_vae |
|
|
| |
| recon_result = compute_loss_shell( |
| x, x_hat, params_hat, structure, loss_params, aux_data |
| ) |
|
|
| |
| kl_loss = compute_kl_divergence(mu, log_sigma) |
|
|
| |
| beta = loss_params.get("vae", {}).get("beta", 0.0) |
|
|
| if aux_data: |
| recon_loss, loss_terms = recon_result |
| loss = recon_loss + beta * kl_loss |
| loss_terms["loss"] = loss |
| loss_terms["kl divergence"] = kl_loss |
| loss_terms["beta"] = jnp.float32(beta) |
| return loss, loss_terms |
| else: |
| loss = recon_result + beta * kl_loss |
| return loss |
|
|
|
|
| def compute_loss_tower_vae( |
| x, x_hat, params_hat_vae, structure, loss_params, aux_data, *args |
| ): |
| """Loss for tower task with VAE: reconstruction + beta * KL. |
| |
| Same pattern as compute_loss_shell_vae but delegates reconstruction |
| to compute_loss_tower (which includes shape, height, residual, and |
| regularization terms). |
| """ |
| from neural_fdm.variational import compute_kl_divergence |
|
|
| params_hat, mu, log_sigma = params_hat_vae |
|
|
| recon_result = compute_loss_tower( |
| x, x_hat, params_hat, structure, loss_params, aux_data |
| ) |
|
|
| kl_loss = compute_kl_divergence(mu, log_sigma) |
| beta = loss_params.get("vae", {}).get("beta", 0.0) |
|
|
| if aux_data: |
| recon_loss, loss_terms = recon_result |
| loss = recon_loss + beta * kl_loss |
| loss_terms["loss"] = loss |
| loss_terms["kl divergence"] = kl_loss |
| loss_terms["beta"] = jnp.float32(beta) |
| return loss, loss_terms |
| else: |
| return recon_result + beta * kl_loss |
|
|
|
|
| def compute_loss_vae( |
| model, structure, x, loss_fn, loss_params, aux_data=False, *, key=None |
| ): |
| """Loss assembler for VAE models. |
| |
| Handles PRNG key threading for the reparameterization trick. |
| Each sample in the batch gets a different key for epsilon sampling. |
| |
| Parameters |
| ---------- |
| model : VariationalAutoEncoder |
| structure : EquilibriumStructure |
| x : Array (B, N*3) |
| Batch of target shapes. |
| loss_fn : callable |
| Task-specific loss (compute_loss_shell_vae or compute_loss_tower_vae). |
| loss_params : dict |
| aux_data : bool |
| key : PRNGKey or None |
| If provided, enables stochastic sampling. If None, uses MAP. |
| """ |
| from jax import vmap |
|
|
| batch_size = x.shape[0] |
|
|
| if key is not None: |
| keys = jax.random.split(key, batch_size) |
|
|
| def predict_one(xi, ki): |
| return model(xi, structure, True, key=ki) |
|
|
| predict_fn = vmap(predict_one, in_axes=(0, 0)) |
| x_hat, data_hat = predict_fn(x, keys) |
| else: |
| def predict_one_det(xi): |
| return model(xi, structure, True, key=None) |
|
|
| predict_fn = vmap(predict_one_det) |
| x_hat, data_hat = predict_fn(x) |
|
|
| return loss_fn(x, x_hat, data_hat, structure, loss_params, aux_data) |
|
|
|
|
| |
| |
| |
|
|
| def compute_error_shape_l1(x, x_hat): |
| """ |
| Calculate the L1 shape reconstruction error, averaged over the batch. |
| |
| Parameters |
| ---------- |
| x: `jax.Array` |
| The target shape. |
| x_hat: `jax.Array` |
| The predicted shape. |
| |
| Returns |
| ------- |
| error: `float` |
| The reconstruction error. |
| """ |
| error = jnp.abs(x - x_hat) |
| batch_error = jnp.sum(error, axis=-1) |
|
|
| return jnp.mean(batch_error, axis=-1) |
|
|
|
|
| def compute_error_shape_l2(x, x_hat): |
| """ |
| Calculate the L2 shape reconstruction error, averaged over the batch. |
| |
| Parameters |
| ---------- |
| x: `jax.Array` |
| The target shape. |
| x_hat: `jax.Array` |
| The predicted shape. |
| |
| Returns |
| ------- |
| error: `float` |
| The reconstruction error. |
| """ |
| error = jnp.square(x - x_hat) |
| batch_error = jnp.sum(error, axis=-1) |
|
|
| return jnp.mean(batch_error, axis=-1) |
|
|
|
|
| |
| |
| |
|
|
| def compute_error_residual(x_hat, params_hat, structure, indices): |
| """ |
| Calculate the residual error, averaged over the batch. This is the physics loss. |
| |
| Parameters |
| ---------- |
| x_hat: `jax.Array` |
| The predicted shape. |
| params_hat: tuple of `jax.Array` |
| The predicted force densities, loads, and fixed positions. |
| structure: `jax_fdm.EquilibriumStructure` |
| A structure with the discretization of the shape. |
| indices: `jax.Array` |
| The indices of the free vertices to calculate the residual at. |
| |
| Returns |
| ------- |
| error: `float` |
| The residual error. |
| """ |
| def calculate_residuals(_x_hat, _params_hat): |
| |
| q_hat, xyz_fixed, loads = _params_hat |
| residual_vectors = vertices_residuals_from_xyz(q_hat, loads, _x_hat, structure) |
| residual_vectors_free = jnp.ravel(residual_vectors[indices, :]) |
|
|
| |
| |
| return jnp.square(residual_vectors_free) |
|
|
| residuals = vmap(calculate_residuals)(x_hat, params_hat) |
| shape_residuals = jnp.sqrt(jnp.sum(residuals, axis=-1)) |
| batch_residual = jnp.mean(shape_residuals, axis=-1) |
|
|
| return batch_residual |
|
|
|
|
| |
| |
| |
|
|
| def compute_q_regularization(q): |
| """ |
| Calculate variance of the force densities for compression and tension. |
| |
| Parameters |
| ---------- |
| q: `jax.Array` |
| The force densities. |
| |
| Returns |
| ------- |
| result: `float` |
| The sum of the two variances. |
| """ |
| sign_q = jnp.sign(q) |
| var_q_pos = jnp.var(q, where=sign_q > 0) |
| var_q_neg = jnp.var(q, where=sign_q < 0) |
|
|
| |
| result = jnp.mean(var_q_pos) + jnp.mean(var_q_neg) |
|
|
| return result |
|
|
|
|
|
|
|
|
| |
| |
| |
|
|
| def print_loss_summary(loss_terms, prefix=None): |
| """ |
| Print a summary of the loss terms. |
| |
| Parameters |
| ---------- |
| loss_terms: `dict` |
| The loss terms. |
| prefix: `str` or `None`, optional |
| The prefix to add to the loss terms printed to the console. |
| """ |
| msg_parts = [] |
| if prefix: |
| msg_parts.append(prefix) |
|
|
| for label, term in loss_terms.items(): |
| part = f"{label.capitalize()}: {term.item():.4f}" |
| msg_parts.append(part) |
|
|
| msg = "\t".join(msg_parts) |
| print(msg) |
|
|