# vae-fdm/src/neural_fdm/models.py
# Efradeca's picture
# Upload folder using huggingface_hub
# fc7d689 verified
import equinox as eqx
import jax.numpy as jnp
from jax.lax import stop_gradient
from jax_fdm.equilibrium import EquilibriumModel
from jaxtyping import Array, Bool, Float
from neural_fdm.helpers import (
calculate_area_loads,
calculate_equilibrium_state,
calculate_fd_params_state,
)
# ===============================================================================
# Autoencoders
# ===============================================================================
class AutoEncoder(eqx.Module):
    """
    A model that pipes an encoder to a decoder.

    Parameters
    ----------
    encoder: `eqx.Module`
        The encoder.
    decoder: `eqx.Module`
        The decoder.
    """
    encoder: eqx.Module
    decoder: eqx.Module

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, x, structure, aux_data=False, *args, **kwargs):
        """
        Predict a shape that approximates the target shape.

        Parameters
        ----------
        x: `jax.Array`
            The target shape. It must be a flat vector.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.
        aux_data: `bool`, optional
            Whether to return auxiliary data. The auxiliary data is a tuple of the
            force density parameters, the fixed node positions, and the applied loads.
        *args, **kwargs
            Accepted for signature compatibility and ignored.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.
        data: `tuple` of `jax.Array`
            The auxiliary data if `aux_data` is `True`.
        """
        # NOTE: x must be a flat vector
        q = self._encode(x, structure)
        x_hat = self.decoder(q, x, structure, aux_data)

        return x_hat

    def _encode(self, x, structure):
        """
        Run the encoder, handing `structure` to encoders that require it.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        q: `jax.Array`
            The latent representation.
        """
        # GNN encoders need structure for edge connectivity.
        # The import is kept local to this method, as in the original code.
        from neural_fdm.gnn import GNNEncoder

        if isinstance(self.encoder, GNNEncoder):
            return self.encoder(x, structure=structure)

        return self.encoder(x)

    def encode(self, x, structure=None):
        """
        Generate the latent representation of a target shape.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`, optional
            A structure with the discretization of the shape. When given, it is
            forwarded to GNN encoders exactly as `__call__` does. When omitted,
            the encoder is called with `x` alone.

        Returns
        -------
        q: `jax.Array`
            The latent representation.
        """
        if structure is None:
            # Historical behavior: encode from the shape alone.
            return self.encoder(x)

        return self._encode(x, structure)

    def decode(self, q, *args, **kwargs):
        """
        Map a latent representation back to shape space.

        Parameters
        ----------
        q: `jax.Array`
            The latent representation.
        *args, **kwargs
            Extra arguments forwarded verbatim to the decoder.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.
        """
        return self.decoder(q, *args, **kwargs)

    def predict_states(self, x, structure):
        """
        Predict equilibrium and parameter states for visualization.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        eq_state: `jax_fdm.EquilibriumState`
            The current equilibrium state of the structure.
        fd_params_state: `jax_fdm.EquilibriumParametersState`
            The current state of simulation parameters.
        """
        # Predict shape together with the auxiliary parameters
        x_hat, params = self(x, structure, True)

        return build_states(x_hat, params, structure)
class AutoEncoderPiggy(AutoEncoder):
    """
    An autoencoder with a piggybacking decoder.

    Parameters
    ----------
    encoder: `eqx.Module`
        The encoder.
    decoder: `eqx.Module`
        The decoder.
    decoder_piggy: `eqx.Module`
        The piggybacking decoder.
    """
    decoder_piggy: eqx.Module

    def __init__(self, encoder, decoder, decoder_piggy):
        super().__init__(encoder, decoder)
        self.decoder_piggy = decoder_piggy

    def __call__(self, x, structure, aux_data=False, piggy_mode=True):
        """
        Predict a shape that approximates the target shape.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.
        aux_data: `bool`, optional
            Whether to return auxiliary data. The auxiliary data is a tuple of the
            force density parameters, the fixed node positions, and the applied loads.
        piggy_mode: `bool`, optional
            Whether to use the piggybacking decoder. If `True`, the latent
            representation and the main decoder prediction are wrapped in
            `stop_gradient`, so gradients are not backpropagated from the
            piggybacking decoder into the encoder.

        Returns
        -------
        x_hat: `jax.Array` or `tuple` of `jax.Array`
            The predicted shape. If `aux_data` is `True`, this is a tuple of the
            predicted shape and the auxiliary data.
        y_hat: `jax.Array` or `tuple` of `jax.Array`
            The predicted shape from the piggybacking decoder. If `aux_data` is `True`,
            this is a tuple of the predicted shape and the auxiliary data from
            the piggybacking decoder.
        """
        # GNN encoders need structure for edge connectivity. Mirror the
        # dispatch done by AutoEncoder.__call__ so that both models accept
        # the same kinds of encoders.
        from neural_fdm.gnn import GNNEncoder

        if isinstance(self.encoder, GNNEncoder):
            q = self.encoder(x, structure=structure)
        else:
            q = self.encoder(x)

        x_hat = self.decoder(q, x, structure, aux_data)

        if piggy_mode:
            # Cut the gradient flow before feeding the piggybacking decoder
            q = stop_gradient(q)
            x_hat = stop_gradient(x_hat)

        y_hat = self.decoder_piggy(q, x, structure, aux_data)

        return x_hat, y_hat

    def decode(self, q, *args, **kwargs):
        """
        Map a latent representation back to shape space with the piggybacking decoder.

        Parameters
        ----------
        q: `jax.Array`
            The latent representation.
        *args, **kwargs
            Extra arguments forwarded verbatim to the piggybacking decoder.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.
        """
        return self.decoder_piggy(q, *args, **kwargs)

    def predict_states(self, x, structure):
        """
        Predict equilibrium and parameter states for visualization.

        The states are assembled from the piggybacking decoder prediction.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        eq_state: `jax_fdm.EquilibriumState`
            The current equilibrium state of the structure.
        fd_params_state: `jax_fdm.EquilibriumParametersState`
            The current state of simulation parameters.
        """
        # Predict shape; only the piggybacking branch is used here
        _, pred_piggy = self(x, structure, True)
        x_hat, params = pred_piggy

        return build_states(x_hat, params, structure)
# ===============================================================================
# Encoders
# ===============================================================================
class Encoder(eqx.Module):
    """
    A base encoder that can restrict its input to a subset of the target points.

    The mapping itself is delegated to the next `__call__` in the method
    resolution order (e.g. `eqx.nn.MLP` in `MLPEncoder`).

    Parameters
    ----------
    edges_signs: `jax.Array`
        An array of +1s to denote tension and -1s to denote compression on the edges.
    q_shift: `float`, optional
        The minimum value of the latent representation.
    slice_out: `bool`, optional
        Whether to keep only a slice of the target shape before encoding it.
    slice_indices: `jax.Array`, optional
        The indices of the points to keep from the target shape.
    """
    edges_signs: Array
    q_shift: Float
    slice_out: Bool
    slice_indices: Array

    def __init__(
        self,
        edges_signs,
        q_shift=0.0,
        slice_out=False,
        slice_indices=None,
        *args,
        **kwargs
    ):
        # Remaining arguments configure the next class in the MRO
        super().__init__(*args, **kwargs)

        self.edges_signs = edges_signs
        self.q_shift = q_shift
        self.slice_out = slice_out
        self.slice_indices = slice_indices

    def __call__(self, x):
        """
        Map a target shape to a latent representation.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.

        Returns
        -------
        q: `jax.Array`
            The latent representation.
        """
        if not self.slice_out:
            return super().__call__(x)

        # View the flat input as rows of xyz coordinates, keep the requested
        # rows, and flatten again before encoding.
        points = jnp.reshape(x, (-1, 3))
        selected = points[self.slice_indices, :]
        flat_selected = jnp.ravel(selected)

        return super().__call__(flat_selected)
class MLPEncoder(Encoder, eqx.nn.MLP):
    """
    An encoder backed by a multilayer perceptron.

    Parameters
    ----------
    edges_signs: `jax.Array`
        An array of +1s to denote tension and -1s to denote compression on the edges.
    q_shift: `float`, optional
        The minimum value of the latent representation.
    slice_out: `bool`, optional
        Whether to keep only a slice of the target shape before encoding it.
    slice_indices: `jax.Array`, optional
        The indices of the points to keep from the target shape.
    in_size: `int`
        The dimension of the input.
    out_size: `int`
        The dimension of the output latents.
    width_size: `int`
        The size of the hidden layers.
    depth: `int`
        The number of hidden layers, including the output layer.
    activation: `Callable`
        The activation function for the hidden layers.
    final_activation: `Callable`
        The activation function for the output layer.
    key: `jax.random.PRNGKey`
        The random key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, x):
        """
        Map a target shape to a signed latent representation.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.

        Returns
        -------
        q: `jax.Array`
            The latent representation.
        """
        # Raw MLP output. NOTE(review): the original comment states that it
        # must be positive due to a softplus final activation; that depends
        # on the `final_activation` chosen by the caller.
        magnitudes = super().__call__(x)
        shifted = magnitudes + self.q_shift

        # NOTE: negative q denotes compression, positive tension.
        return shifted * self.edges_signs
# ===============================================================================
# Decoders
# ===============================================================================
class Decoder(eqx.Module):
    """
    A base decoder that gathers the simulation parameters and delegates
    the shape prediction to `get_xyz`.

    Parameters
    ----------
    load: `float`
        The area load applied to the structure.
    mask_edges: `jax.Array`
        A mask vector for the latent values to zero out.
    """
    load: Float
    mask_edges: Array

    def __init__(self, load, mask_edges, *args, **kwargs):
        # Remaining arguments configure the next class in the MRO
        super().__init__(*args, **kwargs)

        self.load = load
        self.mask_edges = mask_edges

    def __call__(self, q, x, structure, aux_data=False):
        """
        Map a latent representation to a target shape.

        Parameters
        ----------
        q: `jax.Array`
            The latent representation.
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.
        aux_data: `bool`, optional
            Whether to return auxiliary data. The auxiliary data is a tuple of the
            masked force densities, the fixed node positions, and the applied loads.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.
        data: `tuple` of `jax.Array`
            The auxiliary data if `aux_data` is `True`.
        """
        # Assemble the parameter triple: (q, xyz_fixed, loads)
        params = (
            self.get_q(q),
            self.get_xyz_fixed(x, structure),
            self.get_loads(x, structure),
        )

        x_hat = self.get_xyz(params, structure)

        return (x_hat, params) if aux_data else x_hat

    def get_q(self, q_hat):
        """
        Apply the edge mask to the latent values.

        Parameters
        ----------
        q_hat: `jax.Array`
            The latent representation.

        Returns
        -------
        q: `jax.Array`
            The masked latent representation.
        """
        masked = q_hat * self.mask_edges

        return masked

    def get_xyz_fixed(self, x, structure):
        """
        Extract the positions of the fixed vertices from the target shape.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        xyz_fixed: `jax.Array`
            The fixed vertex positions.
        """
        fixed = structure.indices_fixed
        points = jnp.reshape(x, (-1, 3))

        return points[fixed, :]

    def get_loads(self, x, structure):
        """
        Calculate the applied vertex loads from a global area load.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        loads: `jax.Array`
            The applied vertex loads. All zeros when the area load is falsy.
        """
        if not self.load:
            return jnp.zeros((structure.num_vertices, 3))

        return calculate_area_loads(x, structure, self.load)

    def get_xyz(self, params, structure):
        """
        Lower level method to predict the target shape. It must be implemented by the subclasses.

        Parameters
        ----------
        params: `tuple` of `jax.Array`
            The parameters to predict the target shape from.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
# ===============================================================================
# Physics-based decoders
# ===============================================================================
class FDDecoder(Decoder):
    """
    A physics-based decoder that obtains shapes from a force density model.

    Parameters
    ----------
    model: `jax_fdm.EquilibriumModel`
        The force density model.
    load: `float`
        The area load applied to the structure.
    mask_edges: `jax.Array`
        A mask vector for the latent values to zero out.
    """
    model: EquilibriumModel

    def __init__(self, model, *args, **kwargs):
        self.model = model
        super().__init__(*args, **kwargs)

    def get_xyz(self, params, structure):
        """
        Predict the target shape from the simulation parameters.

        Parameters
        ----------
        params: `tuple` of `jax.Array`
            The force densities, the fixed node positions, and the applied loads.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape, flattened.
        """
        q, xyz_fixed, loads = params

        # NOTE: to predict only free vertices, use instead
        # self.model.nodes_free_positions(q, xyz_fixed, loads, structure)
        xyz = self.model.equilibrium(q, xyz_fixed, loads, structure)

        return jnp.ravel(xyz)
class FDDecoderParametrized(FDDecoder):
    """
    A physics-based force density decoder that stores its own force
    densities, so that they can be optimized directly.

    Parameters
    ----------
    q: `jax.Array`
        The initial force densities.
    model: `jax_fdm.EquilibriumModel`
        The force density model.
    load: `float`
        The area load applied to the structure.
    mask_edges: `jax.Array`
        A mask vector for the latent values to zero out.
    """
    q: Array

    def __init__(self, q, *args, **kwargs):
        self.q = q
        super().__init__(*args, **kwargs)

    def __call__(self, x, structure, aux_data=False, *args, **kwargs):
        """
        Solve equilibrium using the stored force densities.

        Parameters
        ----------
        x: `jax.Array`
            The target shape. It is forwarded to the parent decoder together
            with the stored force densities `self.q`.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.
        aux_data: `bool`, optional
            Whether to return auxiliary data. The auxiliary data is a tuple of the
            force density parameters, the fixed node positions, and the applied loads.
        *args, **kwargs
            Accepted for signature compatibility and ignored.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape.
        data: `tuple` of `jax.Array`
            The auxiliary data if `aux_data` is `True`.
        """
        stored_q = self.q

        return super().__call__(stored_q, x, structure, aux_data)

    def predict_states(self, x, structure):
        """
        Predict equilibrium and parameter states for visualization.

        Parameters
        ----------
        x: `jax.Array`
            The target shape.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        eq_state: `jax_fdm.EquilibriumState`
            The current equilibrium state of the structure.
        fd_params_state: `jax_fdm.EquilibriumParametersState`
            The current state of simulation parameters.
        """
        # Request the auxiliary parameters alongside the predicted shape
        prediction = self(x, structure, True)
        x_hat, params = prediction

        return build_states(x_hat, params, structure)
# ===============================================================================
# Neural decoders
# ===============================================================================
class MLPDecoder(Decoder, eqx.nn.MLP):
    """
    A decoder backed by a multilayer perceptron.

    The network predicts the free vertex positions from the latents; the
    fixed vertex positions are taken from the parameters.

    Parameters
    ----------
    load: `float`
        The area load applied to the structure.
    mask_edges: `jax.Array`
        A mask vector for the latent values to zero out.
    in_size: `int`
        The dimension of the input.
    out_size: `int`
        The dimension of the output.
    width_size: `int`
        The size of the hidden layers.
    depth: `int`
        The number of hidden layers, including the output layer.
    activation: `Callable`
        The activation function for the hidden layers.
    key: `jax.random.PRNGKey`
        The random key.
    """
    # NOTE: Should the inheritance order be reversed?

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_xyz(self, params, structure):
        """
        Map a latent representation to a target shape.

        Parameters
        ----------
        params: `tuple` of `jax.Array`
            The parameters to predict the target shape from. The parameters are
            the force density parameters, the fixed node positions, and the applied loads.
        structure: `jax_fdm.EquilibriumStructure`
            A structure with the discretization of the shape.

        Returns
        -------
        x_hat: `jax.Array`
            The predicted shape, flattened.
        """
        # Only the fixed positions are needed here; the network input is
        # selected by `_get_xyz`.
        _, xyz_fixed, _ = params

        prediction = self._get_xyz(params)

        # Stack free rows on top of fixed rows, then reindex the stack with
        # the structure's `indices_freefixed`.
        order = structure.indices_freefixed
        xyz_free = jnp.reshape(prediction, (-1, 3))
        stacked = jnp.concatenate((xyz_free, xyz_fixed))
        xyz = stacked[order, :]

        return jnp.ravel(xyz)

    def _get_xyz(self, params):
        """
        Evaluate the MLP on the force densities.

        Parameters
        ----------
        params: `tuple` of `jax.Array`
            The parameters to predict the target shape from. The parameters are
            the force density parameters, the fixed node positions, and the applied loads.

        Returns
        -------
        x_hat: `jax.Array`
            The raw MLP output.
        """
        q, _, _ = params

        # NOTE: the MLP is invoked explicitly through its class because of
        # multiple inheritance.
        return eqx.nn.MLP.__call__(self, q)
class MLPDecoderXL(MLPDecoder):
    """
    A MLP decoder that maps latents and the boundary conditions (fixed positions and loads) to a shape.

    It assumes that the load has only a z-component, while x and y are always 0.

    Parameters
    ----------
    load: `float`
        The area load applied to the structure.
    mask_edges: `jax.Array`
        A mask vector for the latent values to zero out.
    in_size: `int`
        The dimension of the input.
    out_size: `int`
        The dimension of the output.
    width_size: `int`
        The size of the hidden layers.
    depth: `int`
        The number of hidden layers, including the output layer.
    activation: `Callable`
        The activation function for the hidden layers.
    key: `jax.random.PRNGKey`
        The random key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _get_xyz(self, params):
        """
        Evaluate the MLP on the latents and the boundary conditions.

        Parameters
        ----------
        params: `tuple` of `jax.Array`
            The parameters to predict the target shape from. The parameters are
            the force density parameters, the fixed node positions, and the applied loads.

        Returns
        -------
        x_hat: `jax.Array`
            The raw MLP output.
        """
        q, xyz_fixed, loads = params

        # Build one long input vector: latents, flattened fixed positions,
        # and the z component of the loads (x and y are assumed to be 0).
        flat_fixed = jnp.ravel(xyz_fixed)
        loads_z = loads[:, 2]
        features = jnp.concatenate((q, flat_fixed, loads_z))

        return eqx.nn.MLP.__call__(self, features)
# ===============================================================================
# Helpers
# ===============================================================================
def build_states(x_hat, params, structure):
    """
    Assemble equilibrium and parameter states for visualization.

    Parameters
    ----------
    x_hat: `jax.Array`
        The predicted shape. It is reshaped into rows of xyz coordinates.
    params: `tuple` of `jax.Array`
        The parameters to predict the target shape from. The parameters are
        the force density parameters, the fixed node positions, and the applied loads.
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.

    Returns
    -------
    eq_state: `jax_fdm.EquilibriumState`
        The current equilibrium state of the structure.
    fd_params_state: `jax_fdm.EquilibriumParametersState`
        The current state of simulation parameters.
    """
    force_densities, xyz_fixed, loads = params

    # Parameter state
    fd_params_state = calculate_fd_params_state(force_densities, xyz_fixed, loads)

    # Equilibrium state, computed from all the vertex positions (xyz_free | xyz_fixed)
    xyz = jnp.reshape(x_hat, (-1, 3))
    eq_state = calculate_equilibrium_state(force_densities, xyz, loads, structure)

    return eq_state, fd_params_state