# vae-fdm / src /neural_fdm /builders.py
# Efradeca's picture
# Upload folder using huggingface_hub
# fc7d689 verified
from functools import partial
import jax
import jax.numpy as jnp
import optax
from jax_fdm.equilibrium import EquilibriumMeshStructure, EquilibriumModel
from neural_fdm.generators import (
BezierSurfaceAsymmetricPointGenerator,
BezierSurfaceLerpPointGenerator,
BezierSurfacePointGenerator,
BezierSurfaceSymmetricDoublePointGenerator,
BezierSurfaceSymmetricPointGenerator,
CircularTubePointGenerator,
EllipticalTubePointGenerator,
TubePointGenerator,
)
from neural_fdm.losses import (
compute_loss,
compute_loss_shell,
compute_loss_shell_vae,
compute_loss_tower,
compute_loss_tower_vae,
)
from neural_fdm.mesh import create_mesh_from_bezier_generator, create_mesh_from_tube_generator
from neural_fdm.models import (
AutoEncoder,
AutoEncoderPiggy,
FDDecoder,
FDDecoderParametrized,
MLPDecoder,
MLPDecoderXL,
MLPEncoder,
)
# ===============================================================================
# Tower shape generator bounds
# ===============================================================================
def ellipse_minmax_values():
    """
    Bounds (radius 1, radius 2, rotation) for a family of unrotated ellipses.

    Returns
    -------
    minval: `list` of `float`
        The lower bounds of the ellipse parameters.
    maxval: `list` of `float`
        The upper bounds of the ellipse parameters.

    Notes
    -----
    The radii are scale factors relative to the reference radius of a tower.
    The rotation entry is 0.0 in both bounds.
    """
    radius_lo, radius_hi = 0.5, 1.5
    rotation = 0.0

    lower = [radius_lo, radius_lo, rotation]
    upper = [radius_hi, radius_hi, rotation]

    return lower, upper
def ellipse_rotated_minmax_values():
    """
    Bounds (radius 1, radius 2, rotation) for a family of rotating ellipses.

    Returns
    -------
    minval: `list` of `float`
        The lower bounds of the ellipse parameters.
    maxval: `list` of `float`
        The upper bounds of the ellipse parameters.

    Notes
    -----
    The radii are scale factors relative to the reference radius of a tower.
    The rotation bounds are symmetric about zero.
    """
    radius_lo, radius_hi = 0.5, 1.5
    rotation_hi = 15.0

    lower = [radius_lo, radius_lo, -rotation_hi]
    upper = [radius_hi, radius_hi, rotation_hi]

    return lower, upper
def get_tower_generator_minmax_values(name, bounds):
    """
    Look up the radii and rotation bounds for a tower generator.

    Parameters
    ----------
    name: `str`
        The name of the tower generator. It is accepted for symmetry with the
        Bezier variant but is not used by this function.
    bounds: `str`
        The name of the bounds to use: "straight" or "twisted".

    Returns
    -------
    minval: `jax.Array`
        The minimum values for the generator.
    maxval: `jax.Array`
        The maximum values for the generator.

    Raises
    ------
    KeyError
        If `bounds` does not name a supported experiment.
    """
    bounds_fns = {
        "straight": ellipse_minmax_values,
        "twisted": ellipse_rotated_minmax_values,
    }

    if bounds not in bounds_fns:
        raise KeyError(f"Experiment bounds: {bounds} is currently unsupported!")

    # fetch the (radius 1, radius 2, rotation) bounds as plain lists
    lower, upper = bounds_fns[bounds]()

    # convert both bounds to arrays before handing them out
    return jnp.array(lower), jnp.array(upper)
# ===============================================================================
# Bezier shape generator bounds
# ===============================================================================
def pillow_minmax_values():
    """
    Bounds on the 3D coordinates for a family of pillow shapes.

    Only the z coordinate of the first row differs between the two bounds;
    every other entry is zero in both.

    Returns
    -------
    minval: `list` of `list` of `float`
        The minimum values for the pillow on a quarter tile.
    maxval: `list` of `list` of `float`
        The maximum values for the pillow on a quarter tile.
    """
    def rows(z_first):
        # four rows of xyz values; fresh inner lists so no row is shared
        return [[0.0, 0.0, z_first]] + [[0.0, 0.0, 0.0] for _ in range(3)]

    return rows(1.0), rows(10.0)
def dome_minmax_values():
    """
    Bounds on the 3D coordinates for a family of dome shapes.

    Returns
    -------
    minval: `list` of `list` of `float`
        The minimum values for the dome on a quarter tile.
    maxval: `list` of `list` of `float`
        The maximum values for the dome on a quarter tile.
    """
    z_lo, z_hi = 1.0, 10.0
    reach = 5.0

    # row 0 varies in z, row 1 in x, row 2 in y; row 3 is pinned at the origin
    lower = [
        [0.0, 0.0, z_lo],
        [-reach, 0.0, 0.0],
        [0.0, -reach, 0.0],
        [0.0, 0.0, 0.0],
    ]
    upper = [
        [0.0, 0.0, z_hi],
        [reach, 0.0, 0.0],
        [0.0, reach, 0.0],
        [0.0, 0.0, 0.0],
    ]

    return lower, upper
def saddle_minmax_values():
    """
    Bounds on the 3D coordinates for a family of saddle shapes.

    Returns
    -------
    minval: `list` of `list` of `float`
        The minimum values for the saddle on a quarter tile.
    maxval: `list` of `list` of `float`
        The maximum values for the saddle on a quarter tile.
    """
    z_lo, z_hi = 1.0, 10.0
    reach = 5.0

    # same as the dome bounds, except that row 1 may also rise up to z_hi
    lower = [
        [0.0, 0.0, z_lo],
        [-reach, 0.0, 0.0],
        [0.0, -reach, 0.0],
        [0.0, 0.0, 0.0],
    ]
    upper = [
        [0.0, 0.0, z_hi],
        [reach, 0.0, z_hi],
        [0.0, reach, 0.0],
        [0.0, 0.0, 0.0],
    ]

    return lower, upper
def get_bezier_generator_minmax_values(name, bounds):
    """
    Assemble the 3D coordinate bounds for a family of Bezier shapes.

    Parameters
    ----------
    name: `str`
        The name of the Bezier generator. Its underscore-separated tokens
        ("lerp", "symmetric", "double", "asymmetric") decide how the
        quarter-tile bounds are replicated.
    bounds: `str`
        The name of the bounds to use: "pillow", "dome" or "saddle".

    Returns
    -------
    minval: `jax.Array` or `tuple` of `jax.Array`
        The minimum values for the generator. For "lerp" generators this is
        a pair (quarter-tile bounds, asymmetric bounds).
    maxval: `jax.Array` or `tuple` of `jax.Array`
        The maximum values for the generator, shaped like `minval`.

    Raises
    ------
    KeyError
        If `bounds` does not name a supported experiment.
    """
    tile_bounds_fns = {
        "pillow": pillow_minmax_values,
        "dome": dome_minmax_values,
        "saddle": saddle_minmax_values
    }

    if bounds not in tile_bounds_fns:
        raise KeyError(f"Experiment bounds: {bounds} is currently unsupported!")

    # bounds on a quarter tile (i.e. assuming double symmetry)
    lower, upper = tile_bounds_fns[bounds]()

    tokens = name.split("_")

    # blending generators need both the quarter-tile and the asymmetric bounds
    if "lerp" in tokens:
        return _get_bezier_generator_minmax_values_blend(lower, upper)

    if "symmetric" not in tokens:
        if "asymmetric" in tokens:
            # no symmetry: replicate the tile bounds four times
            lower, upper = _get_bezier_generator_minmax_values_asymmetric(lower, upper)
    elif "double" not in tokens:
        # single symmetry: replicate the tile bounds twice
        lower, upper = _get_bezier_generator_minmax_values_symmetric(lower, upper)
    # doubly symmetric generators use the quarter-tile bounds unchanged

    return jnp.array(lower), jnp.array(upper)
def _get_bezier_generator_minmax_values_symmetric(minval, maxval):
minval = minval + minval
maxval = maxval + maxval
return minval, maxval
def _get_bezier_generator_minmax_values_asymmetric(minval, maxval):
minval = minval + minval + minval + minval
maxval = maxval + maxval + maxval + maxval
return minval, maxval
def _get_bezier_generator_minmax_values_blend(minval, maxval):
    """
    Pair the given bounds with their asymmetric (four-fold) counterparts.

    Returns two tuples of arrays: (minval, asymmetric minval) and
    (maxval, asymmetric maxval).
    """
    asym_min, asym_max = _get_bezier_generator_minmax_values_asymmetric(minval, maxval)

    lower_pair = (jnp.array(minval), jnp.array(asym_min))
    upper_pair = (jnp.array(maxval), jnp.array(asym_max))

    return lower_pair, upper_pair
# ===============================================================================
# Data generators
# ===============================================================================
def build_tube_point_generator(generator_params):
    """
    Build a generator that samples random points on a tube.

    Parameters
    ----------
    generator_params: `dict`
        The hyperparameters for the generator. The keys read are "name",
        "bounds", "height", "radius", "num_sides", "num_levels" and
        "num_rings". The last underscore-separated token of "name" selects
        the generator class ("ellipse" or "circle").

    Returns
    -------
    generator: `TubePointGenerator`
        The generator.

    Raises
    ------
    ValueError
        If the generator kind encoded in the name is not supported.
    """
    full_name = generator_params["name"]
    bounds = generator_params["bounds"]

    # geometric and discretization settings, in constructor order
    tube_settings = [
        generator_params[key]
        for key in ("height", "radius", "num_sides", "num_levels", "num_rings")
    ]

    # wiggle bounds for task
    minval, maxval = get_tower_generator_minmax_values(full_name, bounds)

    generator_classes = {
        "ellipse": EllipticalTubePointGenerator,
        "circle": CircularTubePointGenerator,
    }

    # the generator kind is the trailing token of the name
    name = generator_params["name"].split("_")[-1]
    generator_cls = generator_classes.get(name)
    if not generator_cls:
        raise ValueError(f"Generator {name} is not supported yet!")

    return generator_cls(*tube_settings, minval, maxval)
def build_bezier_point_generator(generator_params):
    """
    Build a generator that samples random points on a Bezier surface.

    Parameters
    ----------
    generator_params: `dict`
        The hyperparameters for the generator. The keys read are "name",
        "num_uv", "size", "num_points", "bounds" and, optionally,
        "lerp_factor" (`None` when absent).

    Returns
    -------
    generator: `BezierSurfacePointGenerator`
        The generator.

    Raises
    ------
    ValueError
        If the generator name is not supported.
    """
    name = generator_params["name"]
    # the same resolution is used along both parametric directions
    num_u = generator_params["num_uv"]
    num_v = generator_params["num_uv"]
    size = generator_params["size"]
    num_pts = generator_params["num_points"]
    bounds_name = generator_params["bounds"]
    lerp_factor = generator_params.get("lerp_factor")

    # wiggle bounds for task
    minval, maxval = get_bezier_generator_minmax_values(name, bounds_name)

    # evenly spaced parameters on [0, 1] in each direction
    u, v = (jnp.linspace(0.0, 1.0, count) for count in (num_u, num_v))

    generator_classes = {
        "bezier_symmetric": BezierSurfaceSymmetricPointGenerator,
        "bezier_symmetric_double": BezierSurfaceSymmetricDoublePointGenerator,
        "bezier_asymmetric": BezierSurfaceAsymmetricPointGenerator,
        "bezier_lerp": BezierSurfaceLerpPointGenerator
    }

    generator_cls = generator_classes.get(name)
    if not generator_cls:
        raise ValueError(f"Generator {name} is not supported yet!")

    # lerp_factor is forwarded to every generator class, not only to the lerp one
    return generator_cls(size, num_pts, u, v, minval, maxval, lerp_factor)
def build_data_generator(config):
    """
    Build a generator that samples random points on a target shape.

    Parameters
    ----------
    config: `dict`
        The configuration for the generator. Only `config["generator"]` is
        read; the first underscore-separated token of its "name" entry
        ("bezier" or "tower") selects the builder.

    Returns
    -------
    generator: `PointGenerator`
        The generator.

    Raises
    ------
    ValueError
        If the generator family encoded in the name is not supported.
    """
    generator_params = config["generator"]

    builders = {
        "bezier": build_bezier_point_generator,
        "tower": build_tube_point_generator
    }

    # the generator family is the leading token of the name
    name = generator_params["name"].split("_")[0]
    if name not in builders:
        raise ValueError(f"Generator {name} is not supported yet!")

    # delegate to the family-specific builder
    return builders[name](generator_params)
# ===============================================================================
# Mesh
# ===============================================================================
def build_mesh_from_generator(config, generator):
    """
    Generate a JAX FDM mesh according to the generator type.

    Parameters
    ----------
    config: `dict`
        The configuration for the generator.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Raises
    ------
    ValueError
        If the generator is neither a Bezier surface nor a tube generator.
    """
    # checked in order: Bezier surfaces first, then tubes
    mesh_builders = (
        (BezierSurfacePointGenerator, create_mesh_from_bezier_generator),
        (TubePointGenerator, create_mesh_from_tube_generator),
    )

    for generator_cls, mesh_builder in mesh_builders:
        if isinstance(generator, generator_cls):
            return mesh_builder(generator, config)

    raise ValueError(f"Cannot make meshes with generator {generator}!")
# ===============================================================================
# Structure (Graph)
# ===============================================================================
def build_connectivity_structure_from_generator(config, generator):
    """
    Build a structure from the mesh produced for a generator.

    Parameters
    ----------
    config: `dict`
        The configuration for the generator.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    """
    # the structure is derived from the base FD mesh of the generator
    return EquilibriumMeshStructure.from_mesh(
        build_mesh_from_generator(config, generator)
    )
# ===============================================================================
# Activation functions
# ===============================================================================
def get_activation_fn(name):
    """
    Fetch an activation function by name.

    Parameters
    ----------
    name: `str`
        The name of the activation function: "elu", "relu" or "softplus".

    Returns
    -------
    activation_fn: `Callable`
        The activation function.

    Raises
    ------
    KeyError
        If the name is not one of the supported activations.
    """
    activations = {
        "elu": jax.nn.elu,
        "relu": jax.nn.relu,
        "softplus": jax.nn.softplus
    }

    if name not in activations:
        raise KeyError(f"Activation name: {name} is currently unsupported!")

    return activations[name]
# ===============================================================================
# Optimizers
# ===============================================================================
def get_optimizer_fn(name):
    """
    Fetch the optimizer function.

    Parameters
    ----------
    name: `str`
        The name of the optimizer: "adam" or "sgd".

    Returns
    -------
    optimizer_fn: `Callable`
        The optimizer function.

    Raises
    ------
    KeyError
        If the name is not one of the supported optimizers.
    """
    optimizers = {
        "adam": optax.adam,
        "sgd": optax.sgd,
    }

    optimizer_fn = optimizers.get(name)
    if optimizer_fn is None:
        raise KeyError(f"Optimizer name: {name} is currently unsupported!")

    return optimizer_fn
def build_optimizer(config):
    """
    Construct an optimizer.

    Parameters
    ----------
    config: `dict`
        The configuration for the optimizer. The entries "name",
        "learning_rate" and "clip_norm" of `config["optimizer"]` are read.
        A "clip_norm" that converts to 0.0 disables gradient clipping.

    Returns
    -------
    optimizer: `optax.GradientTransformation`
        The optimizer.

    Raises
    ------
    TypeError
        If the learning rate is not a `float`.
    KeyError
        If the optimizer name is unsupported or a required entry is missing.
    """
    params = config["optimizer"]
    name = params["name"]
    learning_rate = params["learning_rate"]

    # Validate explicitly: an `assert` would be stripped under `python -O`
    # and let a non-float (e.g. a string from a config file) slip through.
    if not isinstance(learning_rate, float):
        raise TypeError(
            "The learning rate must be a float, "
            f"got {type(learning_rate).__name__}: {learning_rate!r}"
        )

    optimizer_fn = get_optimizer_fn(name)
    optimizer = optimizer_fn(learning_rate=learning_rate)

    clip_norm = float(params["clip_norm"])
    if clip_norm:
        print(f"Optimizing with {name} with learning rate {learning_rate} and gradient clipping to global max norm of {clip_norm}")
        # clip gradients first, then apply the optimizer update
        optimizer = optax.chain(
            optax.clip_by_global_norm(clip_norm),
            optimizer
        )
    else:
        print(f"Optimizing with {name} with learning rate {learning_rate}")

    return optimizer
# ===============================================================================
# Loss functions
# ===============================================================================
def build_loss_function(config, generator):
    """
    Build a loss function.

    Parameters
    ----------
    config: `dict`
        The configuration for the loss function. The task is selected from
        `config["generator"]["name"]` ("bezier" or "tower" must occur in it),
        and the VAE variants are used when `config["loss"]` has a "vae" key.
    generator: `PointGenerator`
        A generator. It is only read for tower tasks.

    Returns
    -------
    loss_fn: `Callable`
        The loss function.

    Raises
    ------
    ValueError
        If the task name contains neither "bezier" nor "tower".

    Notes
    -----
    For tower tasks, `config["loss"]["shape"]` is updated in place with
    quantities taken from the generator.
    """
    task_name = config["generator"]["name"]
    loss_params = config["loss"]
    is_vae = "vae" in loss_params

    if "bezier" in task_name:
        _loss_fn = compute_loss_shell_vae if is_vae else compute_loss_shell
    elif "tower" in task_name:
        # Store the shape and height dimensions for the loss evaluation
        loss_params["shape"]["dims"] = generator.shape_tube
        loss_params["shape"]["levels_compression"] = generator.levels_rings_comp
        loss_params["shape"]["levels_tension"] = generator.levels_rings_tension
        _loss_fn = compute_loss_tower_vae if is_vae else compute_loss_tower
    else:
        # Fail with a clear message instead of an unbound-variable error below
        raise ValueError(f"Cannot build a loss function for task {task_name}!")

    loss_fn = partial(
        compute_loss,
        loss_fn=_loss_fn,
        loss_params=loss_params
    )

    return loss_fn
# ===============================================================================
# Force density solver
# ===============================================================================
def build_fd_model():
    """
    Build a force density model with this module's fixed solver settings.

    Returns
    -------
    fd_model: `jax_fdm.EquilibriumModel`
        The force density model.

    Notes
    -----
    This is a dense model, because batching rule is undefined to vectorize a sparse one.
    """
    settings = {
        "tmax": 1,
        "eta": 1e-6,
        "is_load_local": False,
        "itersolve_fn": None,
        "implicit_diff": True,
        "verbose": False,
    }

    return EquilibriumModel(**settings)
def calculate_edges_mask(mesh):
    """
    Build a per-edge mask that zeroes out the fully supported edges.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Returns
    -------
    mask_edges: `jax.Array`
        The mask array, in `mesh.edges()` order, with 0s for fully supported
        edges and 1s for every other edge. Edges tagged "ring" that are not
        on the boundary are always 1, even when fully supported.
    """
    def mask_value(edge):
        is_supported = mesh.is_edge_fully_supported(edge)
        # NOTE: for tower task
        is_inner_ring = (
            mesh.edge_attribute(edge, "tag") == "ring"
            and not mesh.is_edge_on_boundary(*edge)
        )
        if is_inner_ring or not is_supported:
            return 1.0
        return 0.0

    # the float flags are cast to the requested integer dtype on conversion
    return jnp.array([mask_value(edge) for edge in mesh.edges()], dtype=jnp.int64)
def calculate_edges_stress_signs(mesh):
    """
    Build an integer array that tells which mesh edges are in compression or tension.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Returns
    -------
    signs: `jax.Array`
        The array, in `mesh.edges()` order, with -1s for compression and
        1s for tension. Only edges tagged "cable" are in tension.
    """
    # NOTE: the "cable" tag is for the tower task; all other edges compress
    signs = [
        1 if mesh.edge_attribute(edge, "tag") == "cable" else -1
        for edge in mesh.edges()
    ]

    return jnp.array(signs, dtype=jnp.int64)
# ===============================================================================
# Decoders
# ===============================================================================
def build_fd_decoder_parametrized(q0, mesh, params):
    """
    Build a force density decoder for direct optimization.

    Parameters
    ----------
    q0: `jax.Array`
        The initial force densities.
    mesh: `jax_fdm.FDMesh`
        The mesh.
    params: `dict`
        The hyperparameters for the decoder. Only the "load" entry is read.

    Returns
    -------
    decoder: `eqx.Module`
        The force density decoder.
    """
    load = params["load"]

    # the decoder wraps a fresh FD model and the supported-edges mask
    return FDDecoderParametrized(
        q0,
        build_fd_model(),
        load,
        calculate_edges_mask(mesh),
    )
def build_fd_decoder(mesh, params):
    """
    Build a force density decoder to connect with a neural encoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    params: `dict`
        The hyperparameters for the decoder. Only the "load" entry is read.

    Returns
    -------
    decoder: `eqx.Module`
        The force density decoder.
    """
    load = params["load"]

    # the decoder wraps a fresh FD model and the supported-edges mask
    return FDDecoder(
        build_fd_model(),
        load,
        calculate_edges_mask(mesh),
    )
def build_neural_decoder(mesh, key, params):
    """
    Build an MLP decoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        The pair (neural network hyperparameters, force density
        hyperparameters). Only the "load" entry of the latter is read.

    Returns
    -------
    decoder: `eqx.Module`
        The decoder. It is an `MLPDecoderXL` when "include_params_xl" is
        truthy, and an `MLPDecoder` otherwise.
    """
    nn_params, fd_params = params

    # neural network hyperparameters
    include_xl = nn_params["include_params_xl"]
    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]

    load = fd_params["load"]

    # mesh sizes
    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()
    num_vertices_free = len(list(mesh.vertices_free()))
    num_vertices_fixed = len(list(mesh.vertices_fixed()))

    mask_edges = calculate_edges_mask(mesh)

    # one input per edge; the XL decoder additionally gets three values per
    # fixed vertex and one value per vertex
    if include_xl:
        decoder_cls = MLPDecoderXL
        in_size = num_edges + num_vertices_fixed * 3 + num_vertices
    else:
        decoder_cls = MLPDecoder
        in_size = num_edges

    # three outputs per free vertex
    return decoder_cls(
        load=load,
        mask_edges=mask_edges,
        in_size=in_size,
        out_size=num_vertices_free * 3,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        key=key
    )
# ===============================================================================
# Encoders
# ===============================================================================
def build_neural_encoder(mesh, key, params, generator):
    """
    Build an encoder (MLP or GNN based on config).

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        The pair (neural network hyperparameters, generator hyperparameters).
        An "encoder_type" of "gnn" in the former dispatches to the GNN
        builder; it defaults to "mlp".
    generator: `PointGenerator`
        The generator. It is only read for tower tasks.

    Returns
    -------
    encoder: `eqx.Module`
        The encoder.
    """
    nn_params, generator_params = params

    # hand over to the GNN builder when requested
    if nn_params.get("encoder_type", "mlp") == "gnn":
        return build_gnn_encoder(mesh, key, params, generator)

    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]
    final_activation_name = nn_params["final_activation_fn_name"]

    # mesh sizes
    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()

    edges_signs = calculate_edges_stress_signs(mesh)

    # three inputs per point: every vertex by default, ring points for towers
    if "tower" in generator_params["name"]:
        num_points = generator_params["num_rings"] * generator_params["num_sides"]
        slice_out = True
        slice_indices = generator.indices_rings_comp_ravel
    else:
        num_points = num_vertices
        slice_out = False
        slice_indices = None
    in_size = num_points * 3

    q_shift = nn_params["shift"]

    # one output per edge
    return MLPEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        in_size=in_size,
        out_size=num_edges,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        final_activation=get_activation_fn(final_activation_name),
        key=key
    )
def build_gnn_encoder(mesh, key, params, generator):
    """
    Build a GNN encoder for variable-topology form-finding.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        The pair (neural network hyperparameters, generator hyperparameters).
        "hidden_layer_size", "hidden_layer_num" and "shift" default to 128,
        4 and 0.0 when absent.
    generator: `PointGenerator`
        The generator. It is only read for tower tasks.

    Returns
    -------
    encoder: `GNNEncoder`
        The GNN encoder.
    """
    # imported lazily so the GNN modules are only needed when a GNN is built
    from neural_fdm.gnn import GNNEncoder
    from neural_fdm.graph import edge_index_from_mesh

    nn_params, generator_params = params

    hidden_layer_size = nn_params.get("hidden_layer_size", 128)
    num_layers = nn_params.get("hidden_layer_num", 4)
    q_shift = nn_params.get("shift", 0.0)

    edges_signs = calculate_edges_stress_signs(mesh)
    edge_index = edge_index_from_mesh(mesh)

    # tower tasks slice the encoder input with the generator's ring indices
    is_tower_task = "tower" in generator_params["name"]
    slice_out = False
    slice_indices = None
    if is_tower_task:
        slice_out = True
        slice_indices = generator.indices_rings_comp_ravel

    encoder = GNNEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        node_feat_dim=3,
        edge_feat_dim=4,
        hidden_dim=hidden_layer_size,
        num_layers=num_layers,
        edge_index=edge_index,
        key=key,
    )

    return encoder
def build_variational_gnn_encoder(mesh, key, params, generator):
    """
    Build a variational GNN encoder for VAE form-finding.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        The pair (neural network hyperparameters, generator hyperparameters).
        "hidden_layer_size", "hidden_layer_num" and "shift" default to 128,
        4 and 0.0 when absent.
    generator: `PointGenerator`
        The generator. It is only read for tower tasks.

    Returns
    -------
    encoder: `VariationalGNNEncoder`
        The variational GNN encoder.
    """
    from neural_fdm.gnn import VariationalGNNEncoder
    from neural_fdm.graph import edge_index_from_mesh

    nn_params, generator_params = params

    hidden_dim = nn_params.get("hidden_layer_size", 128)
    num_layers = nn_params.get("hidden_layer_num", 4)
    q_shift = nn_params.get("shift", 0.0)

    edges_signs = calculate_edges_stress_signs(mesh)
    edge_index = edge_index_from_mesh(mesh)

    # tower tasks slice the encoder input with the generator's ring indices
    if "tower" in generator_params["name"]:
        slice_out, slice_indices = True, generator.indices_rings_comp_ravel
    else:
        slice_out, slice_indices = False, None

    encoder = VariationalGNNEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        node_feat_dim=3,
        edge_feat_dim=4,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        edge_index=edge_index,
        key=key,
    )

    return encoder
def build_variational_encoder(mesh, key, params, generator):
    """
    Build a variational encoder for VAE form-finding.

    Dispatches to MLP or GNN based on encoder_type in config.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        The pair (neural network hyperparameters, generator hyperparameters).
        "encoder_type" in the former defaults to "mlp".
    generator: `PointGenerator`
        The generator. It is only read for tower tasks.

    Returns
    -------
    encoder: `eqx.Module`
        The variational encoder.
    """
    from neural_fdm.variational import VariationalMLPEncoder

    nn_params, generator_params = params

    # hand over to the GNN builder when requested
    if nn_params.get("encoder_type", "mlp") == "gnn":
        return build_variational_gnn_encoder(mesh, key, params, generator)

    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]

    # mesh sizes
    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()

    edges_signs = calculate_edges_stress_signs(mesh)

    # three inputs per point: every vertex by default, ring points for towers
    if "tower" in generator_params["name"]:
        num_points = generator_params["num_rings"] * generator_params["num_sides"]
        slice_out = True
        slice_indices = generator.indices_rings_comp_ravel
    else:
        num_points = num_vertices
        slice_out = False
        slice_indices = None
    in_size = num_points * 3

    q_shift = nn_params["shift"]

    # one output per edge
    encoder = VariationalMLPEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        in_size=in_size,
        out_size=num_edges,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        key=key,
    )

    return encoder
def build_variational_formfinder(mesh, key, params, generator):
    """
    Build a VAE formfinder: variational encoder + FDM decoder.

    The FDM decoder is built exactly as for the deterministic formfinder.
    Only the encoder changes to produce distribution parameters.

    Reference: Pastrana et al. (ICLR 2025), Section 6.1.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple`
        The pair (encoder hyperparameters, force density hyperparameters).
        The former is forwarded unchanged to the variational encoder builder.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `VariationalAutoEncoder`
        The variational autoencoder model.
    """
    from neural_fdm.variational import VariationalAutoEncoder

    encoder_params, fd_params = params

    # only the first subkey is consumed; the second one is left unused
    encoder_key, _ = jax.random.split(key)

    model = VariationalAutoEncoder(
        build_variational_encoder(mesh, encoder_key, encoder_params, generator),
        build_fd_decoder(mesh, fd_params),
    )

    return model
# ===============================================================================
# Autoencoder models
# ===============================================================================
def build_neural_formfinder(mesh, key, params, generator):
    """
    Instantiate an autoencoder model with a neural encoder and a mechanical decoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple`
        The pair (encoder hyperparameters, force density hyperparameters).
        The former is forwarded unchanged to the encoder builder.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The autoencoder model.
    """
    encoder_params, fd_params = params

    # neural encoder first, then the force density decoder
    return AutoEncoder(
        build_neural_encoder(mesh, key, encoder_params, generator),
        build_fd_decoder(mesh, fd_params),
    )
def build_neural_autoencoder(mesh, key, params, generator):
    """
    Instantiate a fully neural autoencoder model.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple`
        The pair (encoder hyperparameters, decoder hyperparameters), each
        forwarded unchanged to its builder.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The autoencoder model.
    """
    encoder_params, decoder_params = params

    # NOTE(review): the same `key` seeds both the encoder and the decoder.
    # JAX recommends one subkey per consumer; confirm the reuse is intended
    # before changing it, since a split would alter seeded initializations.
    return AutoEncoder(
        build_neural_encoder(mesh, key, encoder_params, generator),
        build_neural_decoder(mesh, key, decoder_params),
    )
def build_neural_autoencoder_piggy(mesh, key, params, generator):
    """
    Instantiate an autoencoder model with a piggybacking neural decoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple`
        The triple (encoder hyperparameters, neural decoder hyperparameters,
        force density hyperparameters), each forwarded to its builder.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The autoencoder model.
    """
    encoder_params, decoder_params, fd_params = params

    # NOTE(review): the same `key` seeds both the encoder and the piggyback
    # decoder. JAX recommends one subkey per consumer; confirm the reuse is
    # intended before changing it, since a split would alter seeded
    # initializations.
    return AutoEncoderPiggy(
        build_neural_encoder(mesh, key, encoder_params, generator),
        build_fd_decoder(mesh, fd_params),
        build_neural_decoder(mesh, key, decoder_params),
    )
def build_neural_model(name, config, generator, model_key):
    """
    Build a neural model.

    Parameters
    ----------
    name: `str`
        The name of the model: "formfinder", "autoencoder", "piggy" or
        "variational_formfinder".
    config: `dict`
        The configuration for the model. The entries "fdm", "decoder",
        "encoder" and "generator" are read.
    generator: `PointGenerator`
        The generator.
    model_key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    model: `eqx.Module`
        The autoencoder model.

    Raises
    ------
    ValueError
        If the model name is not supported.
    """
    # generate base FD mesh
    mesh = build_mesh_from_generator(config, generator)

    fd_params = config["fdm"]
    decoder_params = config["decoder"]
    # encoders also need the generator hyperparameters
    encoder_params = (config["encoder"], config["generator"])
    # neural decoders also need the force density hyperparameters
    neural_decoder_params = (decoder_params, fd_params)

    # (model name, builder, hyperparameters) for each supported model
    recipes = (
        ("formfinder", build_neural_formfinder, (encoder_params, fd_params)),
        ("autoencoder", build_neural_autoencoder, (encoder_params, neural_decoder_params)),
        ("piggy", build_neural_autoencoder_piggy, (encoder_params, neural_decoder_params, fd_params)),
        ("variational_formfinder", build_variational_formfinder, (encoder_params, fd_params)),
    )

    for model_name, build_fn, params in recipes:
        if name == model_name:
            return build_fn(mesh, model_key, params, generator)

    raise ValueError(f"Model name {name} is unsupported")