| from functools import partial |
|
|
| import jax |
| import jax.numpy as jnp |
| import optax |
| from jax_fdm.equilibrium import EquilibriumMeshStructure, EquilibriumModel |
|
|
| from neural_fdm.generators import ( |
| BezierSurfaceAsymmetricPointGenerator, |
| BezierSurfaceLerpPointGenerator, |
| BezierSurfacePointGenerator, |
| BezierSurfaceSymmetricDoublePointGenerator, |
| BezierSurfaceSymmetricPointGenerator, |
| CircularTubePointGenerator, |
| EllipticalTubePointGenerator, |
| TubePointGenerator, |
| ) |
| from neural_fdm.losses import ( |
| compute_loss, |
| compute_loss_shell, |
| compute_loss_shell_vae, |
| compute_loss_tower, |
| compute_loss_tower_vae, |
| ) |
| from neural_fdm.mesh import create_mesh_from_bezier_generator, create_mesh_from_tube_generator |
| from neural_fdm.models import ( |
| AutoEncoder, |
| AutoEncoderPiggy, |
| FDDecoder, |
| FDDecoderParametrized, |
| MLPDecoder, |
| MLPDecoderXL, |
| MLPEncoder, |
| ) |
|
|
| |
| |
| |
|
|
def ellipse_minmax_values():
    """
    Bounds (radius 1, radius 2, rotation) for a family of ellipses with zero rotation.

    Returns
    -------
    minval: `list` of `float`
        Lower bounds, ordered as (radius 1, radius 2, rotation).
    maxval: `list` of `float`
        Upper bounds, ordered as (radius 1, radius 2, rotation).

    Notes
    -----
    The radii are scale factors relative to the reference radius of a tower.
    The rotation range is collapsed to 0.0, so these ellipses never rotate.
    """
    ranges = (
        (0.5, 1.5),  # radius 1
        (0.5, 1.5),  # radius 2
        (0.0, 0.0),  # rotation
    )
    minval = [low for low, _ in ranges]
    maxval = [high for _, high in ranges]

    return minval, maxval
|
|
|
|
def ellipse_rotated_minmax_values():
    """
    Bounds (radius 1, radius 2, rotation) for a family of rotating ellipses.

    Returns
    -------
    minval: `list` of `float`
        Lower bounds, ordered as (radius 1, radius 2, rotation).
    maxval: `list` of `float`
        Upper bounds, ordered as (radius 1, radius 2, rotation).

    Notes
    -----
    The radii are scale factors relative to the reference radius of a tower.
    The rotation is allowed to range symmetrically between -15.0 and 15.0.
    """
    ranges = (
        (0.5, 1.5),  # radius 1
        (0.5, 1.5),  # radius 2
        (-15.0, 15.0),  # rotation
    )
    minval = [low for low, _ in ranges]
    maxval = [high for _, high in ranges]

    return minval, maxval
|
|
|
|
def get_tower_generator_minmax_values(name, bounds):
    """
    Look up the radii and rotation bounds for a tower generator.

    Parameters
    ----------
    name: `str`
        The name of the tower generator. It is accepted for interface symmetry
        with the Bezier lookup but does not influence the result.
    bounds: `str`
        The bounds family: ``"straight"`` or ``"twisted"``.

    Returns
    -------
    minval: `jax.Array`
        The lower bounds (radius 1, radius 2, rotation).
    maxval: `jax.Array`
        The upper bounds (radius 1, radius 2, rotation).

    Raises
    ------
    KeyError
        If ``bounds`` is not a known family.
    """
    families = {
        "straight": ellipse_minmax_values,
        "twisted": ellipse_rotated_minmax_values,
    }
    if bounds not in families:
        raise KeyError(f"Experiment bounds: {bounds} is currently unsupported!")

    minval, maxval = families[bounds]()

    return jnp.array(minval), jnp.array(maxval)
|
|
|
|
| |
| |
| |
|
|
def pillow_minmax_values():
    """
    Control-point bounds (3D coordinates) for a family of pillow shapes.

    Only the first control point varies, and only along z (between 1.0 and
    10.0); the other three control points are pinned at the origin.

    Returns
    -------
    minval: `list` of `list` of `float`
        Lower bounds of the four control points on a quarter tile.
    maxval: `list` of `list` of `float`
        Upper bounds of the four control points on a quarter tile.
    """
    control_points = (
        ([0.0, 0.0, 1.0], [0.0, 0.0, 10.0]),
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
    )
    minval = [low for low, _ in control_points]
    maxval = [high for _, high in control_points]

    return minval, maxval
|
|
|
|
def dome_minmax_values():
    """
    Control-point bounds (3D coordinates) for a family of dome shapes.

    The first control point varies along z (1.0 to 10.0), the second along x
    and the third along y (both -5.0 to 5.0); the last one is pinned at the origin.

    Returns
    -------
    minval: `list` of `list` of `float`
        Lower bounds of the four control points on a quarter tile.
    maxval: `list` of `list` of `float`
        Upper bounds of the four control points on a quarter tile.
    """
    control_points = (
        ([0.0, 0.0, 1.0], [0.0, 0.0, 10.0]),
        ([-5.0, 0.0, 0.0], [5.0, 0.0, 0.0]),
        ([0.0, -5.0, 0.0], [0.0, 5.0, 0.0]),
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
    )
    minval = [low for low, _ in control_points]
    maxval = [high for _, high in control_points]

    return minval, maxval
|
|
|
|
def saddle_minmax_values():
    """
    Control-point bounds (3D coordinates) for a family of saddle shapes.

    Same as the dome bounds, except that the second control point may also
    rise along z (from 0.0 up to 10.0).

    Returns
    -------
    minval: `list` of `list` of `float`
        Lower bounds of the four control points on a quarter tile.
    maxval: `list` of `list` of `float`
        Upper bounds of the four control points on a quarter tile.
    """
    control_points = (
        ([0.0, 0.0, 1.0], [0.0, 0.0, 10.0]),
        ([-5.0, 0.0, 0.0], [5.0, 0.0, 10.0]),
        ([0.0, -5.0, 0.0], [0.0, 5.0, 0.0]),
        ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]),
    )
    minval = [low for low, _ in control_points]
    maxval = [high for _, high in control_points]

    return minval, maxval
|
|
|
|
def get_bezier_generator_minmax_values(name, bounds):
    """
    Look up the control-point bounds for a family of Bezier shapes.

    Parameters
    ----------
    name: `str`
        The Bezier generator name. Its underscore-separated tokens select how
        the quarter-tile bounds are expanded: ``"lerp"`` returns both the base
        bounds and a four-fold tiled copy, ``"symmetric"`` (without
        ``"double"``) tiles them twice, and ``"asymmetric"`` tiles them four times.
    bounds: `str`
        The bounds family: ``"pillow"``, ``"dome"`` or ``"saddle"``.

    Returns
    -------
    minval: `jax.Array` or `tuple` of `jax.Array`
        The lower bounds. For ``"lerp"`` names, a pair (base, tiled).
    maxval: `jax.Array` or `tuple` of `jax.Array`
        The upper bounds, with the same structure as ``minval``.

    Raises
    ------
    KeyError
        If ``bounds`` is not a known family.
    """
    families = {
        "pillow": pillow_minmax_values,
        "dome": dome_minmax_values,
        "saddle": saddle_minmax_values,
    }
    if bounds not in families:
        raise KeyError(f"Experiment bounds: {bounds} is currently unsupported!")

    minval, maxval = families[bounds]()
    tokens = name.split("_")

    if "lerp" in tokens:
        return _get_bezier_generator_minmax_values_blend(minval, maxval)

    if "symmetric" in tokens:
        # "symmetric_double" keeps the quarter-tile bounds unchanged.
        if "double" not in tokens:
            minval, maxval = _get_bezier_generator_minmax_values_symmetric(minval, maxval)
    elif "asymmetric" in tokens:
        minval, maxval = _get_bezier_generator_minmax_values_asymmetric(minval, maxval)

    return jnp.array(minval), jnp.array(maxval)
|
|
|
|
| def _get_bezier_generator_minmax_values_symmetric(minval, maxval): |
| minval = minval + minval |
| maxval = maxval + maxval |
|
|
| return minval, maxval |
|
|
|
|
| def _get_bezier_generator_minmax_values_asymmetric(minval, maxval): |
| minval = minval + minval + minval + minval |
| maxval = maxval + maxval + maxval + maxval |
|
|
| return minval, maxval |
|
|
|
|
def _get_bezier_generator_minmax_values_blend(minval, maxval):
    """
    Bounds for a generator that interpolates between two Bezier families.

    Returns
    -------
    minvals: `tuple` of `jax.Array`
        The base lower bounds, followed by the four-fold tiled lower bounds.
    maxvals: `tuple` of `jax.Array`
        The base upper bounds, followed by the four-fold tiled upper bounds.
    """
    tiled_min, tiled_max = _get_bezier_generator_minmax_values_asymmetric(minval, maxval)

    minvals = (jnp.array(minval), jnp.array(tiled_min))
    maxvals = (jnp.array(maxval), jnp.array(tiled_max))

    return minvals, maxvals
|
|
|
|
| |
| |
| |
|
|
def build_tube_point_generator(generator_params):
    """
    Build a generator that samples random points on a tube.

    Parameters
    ----------
    generator_params: `dict`
        The generator hyperparameters. Must define ``name``, ``bounds``,
        ``height``, ``radius``, ``num_sides``, ``num_levels`` and ``num_rings``.

    Returns
    -------
    generator: `TubePointGenerator`
        An elliptical or circular tube generator, chosen by the last
        underscore-separated token of ``name``.

    Raises
    ------
    ValueError
        If the last token of ``name`` is neither ``"ellipse"`` nor ``"circle"``.
    """
    full_name = generator_params["name"]
    bounds = generator_params["bounds"]
    tube_args = (
        generator_params["height"],
        generator_params["radius"],
        generator_params["num_sides"],
        generator_params["num_levels"],
        generator_params["num_rings"],
    )

    minval, maxval = get_tower_generator_minmax_values(full_name, bounds)

    tube_classes = {
        "ellipse": EllipticalTubePointGenerator,
        "circle": CircularTubePointGenerator,
    }
    # The cross-section shape is encoded in the final token of the generator name.
    name = full_name.split("_")[-1]
    generator_cls = tube_classes.get(name)
    if not generator_cls:
        raise ValueError(f"Generator {name} is not supported yet!")

    return generator_cls(*tube_args, minval, maxval)
|
|
|
|
def build_bezier_point_generator(generator_params):
    """
    Build a generator that samples random points on a Bezier surface.

    Parameters
    ----------
    generator_params: `dict`
        The generator hyperparameters. Must define ``name``, ``num_uv``,
        ``size``, ``num_points`` and ``bounds``; ``lerp_factor`` is optional
        and defaults to ``None``.

    Returns
    -------
    generator: `BezierSurfacePointGenerator`
        An instance of the generator class registered under ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not one of the registered Bezier generators.
    """
    name = generator_params["name"]
    num_u = num_v = generator_params["num_uv"]
    size = generator_params["size"]
    num_pts = generator_params["num_points"]
    bounds_name = generator_params["bounds"]
    lerp_factor = generator_params.get("lerp_factor")

    minval, maxval = get_bezier_generator_minmax_values(name, bounds_name)

    # Evenly spaced surface parameters in [0, 1] for both directions.
    u = jnp.linspace(0.0, 1.0, num_u)
    v = jnp.linspace(0.0, 1.0, num_v)

    generator_cls = {
        "bezier_symmetric": BezierSurfaceSymmetricPointGenerator,
        "bezier_symmetric_double": BezierSurfaceSymmetricDoublePointGenerator,
        "bezier_asymmetric": BezierSurfaceAsymmetricPointGenerator,
        "bezier_lerp": BezierSurfaceLerpPointGenerator,
    }.get(name)
    if not generator_cls:
        raise ValueError(f"Generator {name} is not supported yet!")

    return generator_cls(size, num_pts, u, v, minval, maxval, lerp_factor)
|
|
|
|
def build_data_generator(config):
    """
    Build a generator that samples random points on a target shape.

    Parameters
    ----------
    config: `dict`
        The full configuration; ``config["generator"]["name"]`` selects the
        family by its first underscore-separated token (``"bezier"`` or ``"tower"``).

    Returns
    -------
    generator: `PointGenerator`
        The generator.

    Raises
    ------
    ValueError
        If the family token is not supported.
    """
    generator_params = config["generator"]

    builders = {
        "bezier": build_bezier_point_generator,
        "tower": build_tube_point_generator,
    }

    name = generator_params["name"].split("_")[0]
    builder = builders.get(name)
    if not builder:
        raise ValueError(f"Generator {name} is not supported yet!")

    return builder(generator_params)
|
|
|
|
| |
| |
| |
|
|
def build_mesh_from_generator(config, generator):
    """
    Create a JAX FDM mesh with the mesh builder that matches the generator type.

    Parameters
    ----------
    config: `dict`
        The configuration, forwarded to the mesh builder.
    generator: `PointGenerator`
        A Bezier-surface or tube point generator.

    Returns
    -------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Raises
    ------
    ValueError
        If the generator is neither a Bezier-surface nor a tube generator.
    """
    if isinstance(generator, BezierSurfacePointGenerator):
        return create_mesh_from_bezier_generator(generator, config)

    if isinstance(generator, TubePointGenerator):
        return create_mesh_from_tube_generator(generator, config)

    raise ValueError(f"Cannot make meshes with generator {generator}!")
|
|
|
|
| |
| |
| |
|
|
def build_connectivity_structure_from_generator(config, generator):
    """
    Build an equilibrium structure from the mesh produced for a generator.

    Parameters
    ----------
    config: `dict`
        The configuration, forwarded to the mesh builder.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    structure: `jax_fdm.EquilibriumStructure`
        A structure with the discretization of the shape.
    """
    return EquilibriumMeshStructure.from_mesh(build_mesh_from_generator(config, generator))
|
|
|
|
| |
| |
| |
|
|
def get_activation_fn(name):
    """
    Look up an activation function by name.

    Parameters
    ----------
    name: `str`
        One of ``"elu"``, ``"relu"`` or ``"softplus"``.

    Returns
    -------
    activation_fn: `Callable`
        The matching ``jax.nn`` activation.

    Raises
    ------
    KeyError
        If ``name`` is not supported.
    """
    registry = {
        "elu": jax.nn.elu,
        "relu": jax.nn.relu,
        "softplus": jax.nn.softplus,
    }
    if name not in registry:
        raise KeyError(f"Activation name: {name} is currently unsupported!")

    return registry[name]
|
|
|
|
| |
| |
| |
|
|
def get_optimizer_fn(name):
    """
    Fetch the optimizer factory.

    Parameters
    ----------
    name: `str`
        The name of the optimizer: ``"adam"``, ``"adamw"`` or ``"sgd"``.

    Returns
    -------
    optimizer_fn: `Callable`
        The optax optimizer factory. All supported factories accept a
        ``learning_rate`` keyword argument.

    Raises
    ------
    KeyError
        If ``name`` is not supported.
    """
    optimizers = {
        "adam": optax.adam,
        "adamw": optax.adamw,
        "sgd": optax.sgd,
    }

    optimizer_fn = optimizers.get(name)
    if not optimizer_fn:
        raise KeyError(f"Optimizer name: {name} is currently unsupported!")

    return optimizer_fn
|
|
|
|
def build_optimizer(config):
    """
    Construct an optimizer.

    Parameters
    ----------
    config: `dict`
        The configuration. ``config["optimizer"]`` must define ``name`` and
        ``learning_rate``. ``clip_norm`` is optional: if it is missing,
        ``None`` or 0, gradient clipping is disabled.

    Returns
    -------
    optimizer: `optax.GradientTransformation`
        The optimizer. When ``clip_norm`` is non-zero, it is chained after
        clipping by global norm.

    Raises
    ------
    KeyError
        If a required key is missing or the optimizer name is unsupported.
    ValueError, TypeError
        If ``learning_rate`` or ``clip_norm`` cannot be converted to ``float``.
    """
    params = config["optimizer"]

    name = params["name"]
    # Coerce instead of asserting the type. Config loaders may hand back ints
    # or strings (e.g. "1e-3"), and `assert` is stripped under `python -O`.
    # This matches how `clip_norm` is handled below.
    learning_rate = float(params["learning_rate"])

    optimizer_fn = get_optimizer_fn(name)
    optimizer = optimizer_fn(learning_rate=learning_rate)

    # A missing or null clip_norm behaves like an explicit 0: no clipping.
    clip_norm = float(params.get("clip_norm") or 0.0)
    if clip_norm:
        print(f"Optimizing with {name} with learning rate {learning_rate} and gradient clipping to global max norm of {clip_norm}")
        optimizer = optax.chain(
            optax.clip_by_global_norm(clip_norm),
            optimizer
        )
    else:
        print(f"Optimizing with {name} with learning rate {learning_rate}")

    return optimizer
|
|
|
|
| |
| |
| |
|
|
def build_loss_function(config, generator):
    """
    Build a loss function.

    Parameters
    ----------
    config: `dict`
        The full configuration. ``config["generator"]["name"]`` selects the
        task, and ``config["loss"]`` holds the loss hyperparameters. The
        presence of a ``"vae"`` key in ``config["loss"]`` selects the VAE variant.
    generator: `PointGenerator`
        A generator. For tower tasks, its ``shape_tube``,
        ``levels_rings_comp`` and ``levels_rings_tension`` attributes are read.

    Returns
    -------
    loss_fn: `Callable`
        ``compute_loss`` with ``loss_fn`` and ``loss_params`` bound.

    Raises
    ------
    ValueError
        If the task name contains neither ``"bezier"`` nor ``"tower"``.

    Notes
    -----
    For tower tasks, ``config["loss"]["shape"]`` is updated in place, so the
    caller's config dict is mutated.
    """
    task_name = config["generator"]["name"]
    loss_params = config["loss"]
    is_vae = "vae" in loss_params

    if "bezier" in task_name:
        _loss_fn = compute_loss_shell_vae if is_vae else compute_loss_shell

    elif "tower" in task_name:
        # Tower losses need the tube dimensions and ring levels from the generator.
        shape_params = loss_params["shape"]
        shape_params["dims"] = generator.shape_tube
        shape_params["levels_compression"] = generator.levels_rings_comp
        shape_params["levels_tension"] = generator.levels_rings_tension

        _loss_fn = compute_loss_tower_vae if is_vae else compute_loss_tower

    else:
        # Fail with a clear message instead of leaving `_loss_fn` unbound.
        raise ValueError(f"Loss function for task {task_name} is not supported yet!")

    loss_fn = partial(
        compute_loss,
        loss_fn=_loss_fn,
        loss_params=loss_params
    )

    return loss_fn
|
|
|
|
| |
| |
| |
|
|
def build_fd_model():
    """
    Create the force density equilibrium model shared by all FD decoders.

    Returns
    -------
    fd_model: `jax_fdm.EquilibriumModel`
        The force density model.

    Notes
    -----
    A dense model is used because a sparse one cannot be vectorized: its
    batching rule is undefined.
    """
    return EquilibriumModel(
        tmax=1,
        eta=1e-6,
        is_load_local=False,
        itersolve_fn=None,
        implicit_diff=True,
        verbose=False,
    )
|
|
|
|
def calculate_edges_mask(mesh):
    """
    Build an integer mask over the mesh edges that flags fully supported edges.

    An edge gets 0 when ``mesh.is_edge_fully_supported(edge)`` is true, unless
    it is tagged ``"ring"`` and is not on the mesh boundary. Every other edge
    gets 1.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Returns
    -------
    mask_edges: `jax.Array`
        One entry per edge, in ``mesh.edges()`` order: 0 for masked (fully
        supported) edges and 1 for all others.
    """
    mask_edges = []
    for edge in mesh.edges():
        is_masked = bool(mesh.is_edge_fully_supported(edge))

        # Interior ring edges stay active even when reported as fully supported.
        if is_masked and mesh.edge_attribute(edge, "tag") == "ring":
            is_masked = bool(mesh.is_edge_on_boundary(*edge))

        mask_edges.append(0 if is_masked else 1)

    return jnp.array(mask_edges, dtype=jnp.int64)
|
|
|
|
def calculate_edges_stress_signs(mesh):
    """
    Assign each mesh edge an expected stress sign from its ``"tag"`` attribute.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.

    Returns
    -------
    signs: `jax.Array`
        One entry per edge, in ``mesh.edges()`` order: 1 (tension) for edges
        tagged ``"cable"`` and -1 (compression) for every other edge.
    """
    signs = [
        1 if mesh.edge_attribute(edge, "tag") == "cable" else -1
        for edge in mesh.edges()
    ]

    return jnp.array(signs, dtype=jnp.int64)
|
|
|
|
| |
| |
| |
|
|
def build_fd_decoder_parametrized(q0, mesh, params):
    """
    Create a force density decoder whose force densities are optimized directly.

    Parameters
    ----------
    q0: `jax.Array`
        The initial force densities.
    mesh: `jax_fdm.FDMesh`
        The mesh, used to compute the edge mask.
    params: `dict`
        The decoder hyperparameters; ``params["load"]`` is forwarded to the decoder.

    Returns
    -------
    decoder: `eqx.Module`
        The parametrized force density decoder.
    """
    load = params["load"]
    fd_model = build_fd_model()
    mask_edges = calculate_edges_mask(mesh)

    return FDDecoderParametrized(q0, fd_model, load, mask_edges)
|
|
|
|
def build_fd_decoder(mesh, params):
    """
    Create a force density decoder to pair with a neural encoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh, used to compute the edge mask.
    params: `dict`
        The decoder hyperparameters; ``params["load"]`` is forwarded to the decoder.

    Returns
    -------
    decoder: `eqx.Module`
        The force density decoder.
    """
    load = params["load"]
    fd_model = build_fd_model()
    mask_edges = calculate_edges_mask(mesh)

    return FDDecoder(fd_model, load, mask_edges)
|
|
|
|
def build_neural_decoder(mesh, key, params):
    """
    Create an MLP decoder that maps force densities to free-vertex coordinates.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key used to initialize the MLP.
    params: `tuple` of `dict`
        A pair ``(nn_params, fd_params)``. ``nn_params`` must define
        ``include_params_xl``, ``hidden_layer_size``, ``hidden_layer_num`` and
        ``activation_fn_name``; ``fd_params`` must define ``load``.

    Returns
    -------
    decoder: `eqx.Module`
        An ``MLPDecoder``, or an ``MLPDecoderXL`` when ``include_params_xl`` is truthy.
        Its output size is three values per free vertex.
    """
    nn_params, fd_params = params

    include_xl = nn_params["include_params_xl"]
    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]

    load = fd_params["load"]

    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()
    num_free = sum(1 for _ in mesh.vertices_free())
    num_fixed = sum(1 for _ in mesh.vertices_fixed())

    mask_edges = calculate_edges_mask(mesh)

    if include_xl:
        # The XL input is enlarged by 3 values per fixed vertex and 1 per vertex
        # (presumably fixed coordinates and per-vertex loads; confirm in MLPDecoderXL).
        decoder_cls = MLPDecoderXL
        in_size = num_edges + num_fixed * 3 + num_vertices
    else:
        decoder_cls = MLPDecoder
        in_size = num_edges

    return decoder_cls(
        load=load,
        mask_edges=mask_edges,
        in_size=in_size,
        out_size=num_free * 3,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        key=key,
    )
|
|
|
|
| |
| |
| |
|
|
def build_neural_encoder(mesh, key, params, generator):
    """
    Create an encoder that predicts force densities from target points.

    Dispatches to ``build_gnn_encoder`` when ``encoder_type`` is ``"gnn"``;
    otherwise builds an ``MLPEncoder``.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key used to initialize the network.
    params: `tuple` of `dict`
        A pair ``(nn_params, generator_params)``.
    generator: `PointGenerator`
        The generator. For tower tasks it provides the output slice indices.

    Returns
    -------
    encoder: `eqx.Module`
        The encoder, with one output per mesh edge.
    """
    nn_params, generator_params = params

    if nn_params.get("encoder_type", "mlp") == "gnn":
        return build_gnn_encoder(mesh, key, params, generator)

    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]
    final_activation_name = nn_params["final_activation_fn_name"]

    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()

    edges_signs = calculate_edges_stress_signs(mesh)

    # Towers feed only the ring points to the encoder; other tasks feed every vertex.
    is_tower_task = "tower" in generator_params["name"]
    if is_tower_task:
        num_input_points = generator_params["num_rings"] * generator_params["num_sides"]
        slice_indices = generator.indices_rings_comp_ravel
    else:
        num_input_points = num_vertices
        slice_indices = None
    in_size = num_input_points * 3
    slice_out = is_tower_task

    q_shift = nn_params["shift"]

    return MLPEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        in_size=in_size,
        out_size=num_edges,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        final_activation=get_activation_fn(final_activation_name),
        key=key,
    )
|
|
|
|
def build_gnn_encoder(mesh, key, params, generator):
    """
    Build a GNN encoder for variable-topology form-finding.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh. Its connectivity provides the graph edge index.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        A pair ``(nn_params, generator_params)``. ``nn_params`` may define
        ``hidden_layer_size`` (default 128), ``hidden_layer_num`` (default 4)
        and ``shift`` (default 0.0). ``generator_params`` must define ``name``.
    generator: `PointGenerator`
        The generator. It is only read for tower tasks, to get the output
        slice indices.

    Returns
    -------
    encoder: `GNNEncoder`
        The GNN encoder.
    """
    # Function-scope imports: these modules are only loaded when a GNN encoder is built.
    from neural_fdm.gnn import GNNEncoder
    from neural_fdm.graph import edge_index_from_mesh

    nn_params, generator_params = params

    hidden_layer_size = nn_params.get("hidden_layer_size", 128)
    num_layers = nn_params.get("hidden_layer_num", 4)
    q_shift = nn_params.get("shift", 0.0)

    edges_signs = calculate_edges_stress_signs(mesh)
    edge_index = edge_index_from_mesh(mesh)

    is_tower_task = "tower" in generator_params["name"]
    slice_out = False
    slice_indices = None
    if is_tower_task:
        slice_out = True
        slice_indices = generator.indices_rings_comp_ravel

    encoder = GNNEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=slice_out,
        slice_indices=slice_indices,
        node_feat_dim=3,
        edge_feat_dim=4,
        hidden_dim=hidden_layer_size,
        num_layers=num_layers,
        edge_index=edge_index,
        key=key,
    )

    return encoder
|
|
|
|
def build_variational_gnn_encoder(mesh, key, params, generator):
    """
    Build a variational GNN encoder for VAE form-finding.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh. Its connectivity provides the graph edge index.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        A pair ``(nn_params, generator_params)``. ``nn_params`` may define
        ``hidden_layer_size`` (default 128), ``hidden_layer_num`` (default 4)
        and ``shift`` (default 0.0).
    generator: `PointGenerator`
        The generator. For tower tasks it provides the output slice indices.

    Returns
    -------
    encoder: `VariationalGNNEncoder`
        The variational GNN encoder.
    """
    from neural_fdm.gnn import VariationalGNNEncoder
    from neural_fdm.graph import edge_index_from_mesh

    nn_params, generator_params = params
    hidden_dim = nn_params.get("hidden_layer_size", 128)
    num_layers = nn_params.get("hidden_layer_num", 4)
    q_shift = nn_params.get("shift", 0.0)
    edges_signs = calculate_edges_stress_signs(mesh)
    edge_index = edge_index_from_mesh(mesh)

    is_tower_task = "tower" in generator_params["name"]
    slice_indices = generator.indices_rings_comp_ravel if is_tower_task else None

    return VariationalGNNEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=is_tower_task,
        slice_indices=slice_indices,
        node_feat_dim=3,
        edge_feat_dim=4,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        edge_index=edge_index,
        key=key,
    )
|
|
|
|
def build_variational_encoder(mesh, key, params, generator):
    """
    Build a variational encoder for VAE form-finding.

    Dispatches to ``build_variational_gnn_encoder`` when ``encoder_type`` is
    ``"gnn"``; otherwise builds a ``VariationalMLPEncoder``.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key.
    params: `tuple` of `dict`
        A pair ``(nn_params, generator_params)``.
    generator: `PointGenerator`
        The generator. For tower tasks it provides the output slice indices.

    Returns
    -------
    encoder: `eqx.Module`
        The variational encoder, with one output per mesh edge.
    """
    from neural_fdm.variational import VariationalMLPEncoder

    nn_params, generator_params = params

    if nn_params.get("encoder_type", "mlp") == "gnn":
        return build_variational_gnn_encoder(mesh, key, params, generator)

    width = nn_params["hidden_layer_size"]
    depth = nn_params["hidden_layer_num"]
    activation_name = nn_params["activation_fn_name"]

    num_vertices = mesh.number_of_vertices()
    num_edges = mesh.number_of_edges()
    edges_signs = calculate_edges_stress_signs(mesh)

    # Towers feed only the ring points to the encoder; other tasks feed every vertex.
    is_tower_task = "tower" in generator_params["name"]
    if is_tower_task:
        num_input_points = generator_params["num_rings"] * generator_params["num_sides"]
        slice_indices = generator.indices_rings_comp_ravel
    else:
        num_input_points = num_vertices
        slice_indices = None
    in_size = num_input_points * 3

    q_shift = nn_params["shift"]

    return VariationalMLPEncoder(
        edges_signs=edges_signs,
        q_shift=q_shift,
        slice_out=is_tower_task,
        slice_indices=slice_indices,
        in_size=in_size,
        out_size=num_edges,
        width_size=width,
        depth=depth,
        activation=get_activation_fn(activation_name),
        key=key,
    )
|
|
|
|
def build_variational_formfinder(mesh, key, params, generator):
    """
    Build a VAE formfinder: a variational encoder paired with an FDM decoder.

    The FDM decoder is identical to the deterministic formfinder's; only the
    encoder changes, so that it produces distribution parameters.

    Reference: Pastrana et al. (ICLR 2025), Section 6.1.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key. The first half of its split seeds the encoder.
    params: `tuple`
        A pair ``(encoder_params, fd_params)``.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `VariationalAutoEncoder`
        The variational formfinder.
    """
    from neural_fdm.variational import VariationalAutoEncoder

    encoder_params, fd_params = params
    encoder_key, _ = jax.random.split(key)

    return VariationalAutoEncoder(
        build_variational_encoder(mesh, encoder_key, encoder_params, generator),
        build_fd_decoder(mesh, fd_params),
    )
|
|
|
|
| |
| |
| |
|
|
def build_neural_formfinder(mesh, key, params, generator):
    """
    Create an autoencoder with a neural encoder and a force density decoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key for the encoder. The FD decoder takes no key.
    params: `tuple`
        A pair ``(encoder_params, fd_params)``.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The ``AutoEncoder``.
    """
    encoder_params, fd_params = params

    return AutoEncoder(
        build_neural_encoder(mesh, key, encoder_params, generator),
        build_fd_decoder(mesh, fd_params),
    )
|
|
|
|
def build_neural_autoencoder(mesh, key, params, generator):
    """
    Instantiate a fully neural autoencoder model.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key. It is split so that the encoder and the decoder are
        initialized from independent random streams.
    params: `tuple`
        A pair ``(encoder_params, decoder_params)``.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The autoencoder model.
    """
    enc_params, dec_params = params

    # JAX keys must not be reused. Passing the same key to both networks
    # would draw their initial weights from identical random bits.
    encoder_key, decoder_key = jax.random.split(key)

    encoder = build_neural_encoder(mesh, encoder_key, enc_params, generator)
    decoder = build_neural_decoder(mesh, decoder_key, dec_params)

    model = AutoEncoder(encoder, decoder)

    return model
|
|
|
|
def build_neural_autoencoder_piggy(mesh, key, params, generator):
    """
    Instantiate an autoencoder model with a piggybacking neural decoder.

    Parameters
    ----------
    mesh: `jax_fdm.FDMesh`
        The mesh.
    key: `jax.random.PRNGKey`
        The random key. It is split so that the encoder and the piggyback
        decoder are initialized from independent random streams.
    params: `tuple`
        A triple ``(encoder_params, decoder_params, fd_params)``.
    generator: `PointGenerator`
        The generator.

    Returns
    -------
    model: `eqx.Module`
        The ``AutoEncoderPiggy`` model.
    """
    enc_params, dec_params, fd_params = params

    # JAX keys must not be reused. The FD decoder takes no key, so two
    # subkeys are enough for the two neural networks.
    encoder_key, piggy_key = jax.random.split(key)

    encoder = build_neural_encoder(mesh, encoder_key, enc_params, generator)

    decoder = build_fd_decoder(mesh, fd_params)

    decoder_piggy = build_neural_decoder(mesh, piggy_key, dec_params)

    model = AutoEncoderPiggy(encoder, decoder, decoder_piggy)

    return model
|
|
|
|
def build_neural_model(name, config, generator, model_key):
    """
    Build a neural model of the requested kind.

    Parameters
    ----------
    name: `str`
        One of ``"formfinder"``, ``"autoencoder"``, ``"piggy"`` or
        ``"variational_formfinder"``.
    config: `dict`
        The configuration. It must define ``fdm``, ``decoder``, ``encoder``
        and ``generator`` sections.
    generator: `PointGenerator`
        The generator.
    model_key: `jax.random.PRNGKey`
        The random key.

    Returns
    -------
    model: `eqx.Module`
        The model.

    Raises
    ------
    ValueError
        If ``name`` is not supported. This is raised after the mesh is built.
    """
    mesh = build_mesh_from_generator(config, generator)

    fd_params = config["fdm"]
    decoder_params = config["decoder"]
    encoder_params = (config["encoder"], config["generator"])

    decoder_bundle = (decoder_params, fd_params)
    builders = {
        "formfinder": (build_neural_formfinder, (encoder_params, fd_params)),
        "autoencoder": (build_neural_autoencoder, (encoder_params, decoder_bundle)),
        "piggy": (build_neural_autoencoder_piggy, (encoder_params, decoder_bundle, fd_params)),
        "variational_formfinder": (build_variational_formfinder, (encoder_params, fd_params)),
    }
    if name not in builders:
        raise ValueError(f"Model name {name} is unsupported")

    build_fn, params = builders[name]

    return build_fn(mesh, model_key, params, generator)
|
|