"""Form-finding helpers: FDM equilibrium computations. Implements the Force Density Method (FDM) equilibrium equations for pin-jointed bar systems. References ---------- [1] Schek, H.J. (1974). The Force Density Method for form finding and computation of general networks. CMAME, 3(1):115-134. [2] Pastrana, R. et al. (2025). Neural FDM. ICLR 2025. Eq. 1, 8. """ import jax.numpy as jnp from jax_fdm.equilibrium import EquilibriumParametersState as FDParametersState from jax_fdm.equilibrium import EquilibriumState, LoadState, nodes_load_from_faces # =============================================================================== # Load helpers # =============================================================================== def calculate_area_loads(x, structure, load): """ Convert area loads into vertex loads. Parameters ---------- x: `jax.Array` The 3D coordinates of the vertices. structure: `jax_fdm.EquilibriumStructure` A structure with the discretization of the shape. load: `float` The vertical load per unit area in the `z` direction. Returns ------- vertices_load: `jax.Array` The 3D vertex loads. """ x = jnp.reshape(x, (-1, 3)) # need to convert loads into face loads num_faces = structure.num_faces faces_load_xy = jnp.zeros(shape=(num_faces, 2)) # (num_faces, xy) faces_load_z = jnp.ones(shape=(num_faces, 1)) * load # (num_faces, xy) faces_load = jnp.hstack((faces_load_xy, faces_load_z)) vertices_load = nodes_load_from_faces( x, faces_load, structure, is_local=False ) return vertices_load def calculate_constant_loads(x, structure, load): """ Create constant vertical vertex loads. Parameters ---------- x: `jax.Array` The 3D coordinates of the vertices. structure: `jax_fdm.EquilibriumStructure` A structure with the discretization of the shape. load: `float` The vertical load per vertex in the `z` direction. Returns ------- vertices_load: `jax.Array` The 3D vertex loads. """ num_vertices = structure.num_vertices # (num_vertices, xy) vertices_load_xy = jnp.zeros(shape=(num_vertices, 2)) # (num_vertices, xy) vertices_load_z = jnp.ones(shape=(num_vertices, 1)) * load return jnp.hstack((vertices_load_xy, vertices_load_z)) # =============================================================================== # Form-finding helpers # =============================================================================== def edges_vectors(xyz, connectivity): """ Calculate the unnormalized edge directions (nodal coordinate differences). Parameters ---------- xyz: `jax.Array` The 3D coordinates of the vertices. connectivity: `jax.Array` The connectivity matrix of the structure. Returns ------- vectors: `jax.Array` The edge vectors. """ return connectivity @ xyz def edges_lengths(vectors): """ Compute the length of the edge vectors. Parameters ---------- vectors: `jax.Array` The edge vectors. Returns ------- lengths: `jax.Array` The lengths. """ return jnp.linalg.norm(vectors, axis=1, keepdims=True) def edges_forces(q, lengths): """ Calculate the force in the edges. Parameters ---------- q: `jax.Array` The force densities. lengths: `jax.Array` The edge lengths. Returns ------- forces: `jax.Array` The forces in the edges. """ return jnp.reshape(q, (-1, 1)) * lengths def vertices_residuals(q, loads, vectors, connectivity): """ Compute the residual forces on the vertices of the structure. Parameters ---------- q: `jax.Array` The force densities. loads: `jax.Array` The loads on the vertices. vectors: `jax.Array` The edge vectors. connectivity: `jax.Array` The connectivity matrix of the structure. Returns ------- residuals: `jax.Array` The residual forces on the vertices. """ return loads - connectivity.T @ (q[:, None] * vectors) def vertices_residuals_from_xyz(q, loads, xyz, structure): """ Compute the residual forces on the vertices of the structure. Parameters ---------- q: `jax.Array` The force densities. loads: `jax.Array` The loads on the vertices. xyz: `jax.Array` The 3D coordinates of the vertices. structure: `jax_fdm.EquilibriumStructure` A structure with the discretization of the shape. Returns ------- residuals: `jax.Array` The residual forces on the vertices. """ connectivity = structure.connectivity xyz = jnp.reshape(xyz, (-1, 3)) vectors = edges_vectors(xyz, connectivity) return vertices_residuals(q, loads, vectors, connectivity) def calculate_equilibrium_state(q, xyz, loads_nodes, structure): """ Assembles an equilibrium state object. Parameters ---------- q: `jax.Array` The force densities. xyz: `jax.Array` The 3D coordinates of the vertices. loads_nodes: `jax.Array` The loads on the vertices. structure: `jax_fdm.EquilibriumStructure` A structure with the discretization of the shape. Returns ------- state: `jax_fdm.EquilibriumState` The equilibrium state. """ connectivity = structure.connectivity vectors = edges_vectors(xyz, connectivity) lengths = edges_lengths(vectors) residuals = vertices_residuals(q, loads_nodes, vectors, connectivity) forces = edges_forces(q, lengths) return EquilibriumState( xyz=xyz, residuals=residuals, lengths=lengths, forces=forces, loads=loads_nodes, vectors=vectors ) def calculate_fd_params_state(q, xyz_fixed, loads_nodes): """ Assembles an simulation parameters state. Parameters ---------- q: `jax.Array` The force densities. xyz_fixed: `jax.Array` The 3D coordinates of the fixed vertices. loads_nodes: `jax.Array` The loads on the vertices. Returns ------- state: `jax_fdm.EquilibriumParametersState` The current state of the simulation parameters. """ return FDParametersState(q, xyz_fixed, LoadState(loads_nodes, 0.0, 0.0)) # ============================================================================= # Reaction forces at supports # ============================================================================= def compute_reactions(q, loads, xyz, structure): """ Compute reaction forces at support (fixed) nodes. In equilibrium, the residual at free nodes is zero. At fixed nodes, the residual represents the reaction force the support must provide: R_i = sum_j K_ij (x_j - x_i) * q_j - P_i This is the negative of the residual at fixed nodes. Parameters ---------- q : jax.Array (E,) Force densities per edge. loads : jax.Array (N, 3) Applied loads at all vertices. xyz : jax.Array (N, 3) Vertex positions. structure : EquilibriumStructure The mesh structure with connectivity. Returns ------- reactions : jax.Array (N_fixed, 3) Reaction force vectors at each support node. indices_fixed : jax.Array (N_fixed,) Indices of the fixed (support) nodes. """ # Compute residuals at ALL nodes (not just free) residuals_all = vertices_residuals_from_xyz(q, loads, xyz, structure) # Reactions are the negative residuals at fixed nodes # At equilibrium, free node residuals are ~0 # At fixed nodes, residuals represent the imbalance that supports carry indices_fixed = structure.indices_fixed reactions = -residuals_all[indices_fixed] return reactions, indices_fixed def compute_total_reactions(reactions): """ Compute total reaction forces (sum over all supports). Parameters ---------- reactions : jax.Array (N_fixed, 3) Reaction force vectors per support. Returns ------- total : jax.Array (3,) Sum of all reactions [Rx, Ry, Rz]. """ return jnp.sum(reactions, axis=0) def compute_l_physics(x_hat, q, loads, structure, free_indices=None): """Compute L_physics: Euclidean norm (L2) of residual forces at free nodes. Matches the paper's physics loss definition (losses.py compute_error_residual): sqrt(sum(residual_vectors_free^2)). This is the standard metric used throughout all benchmarks for consistency. Parameters ---------- x_hat : jax.Array Predicted vertex positions (flat or (N,3)). q : jax.Array Force densities per edge. loads : jax.Array Applied loads per vertex (N,3). structure : EquilibriumStructure Mesh connectivity structure. free_indices : array-like, optional Indices of free (non-support) nodes. If None, uses all nodes. Returns ------- l_physics : float Euclidean norm of residual forces at free nodes. """ xyz = jnp.reshape(x_hat, (-1, 3)) residuals = vertices_residuals_from_xyz(q, loads, xyz, structure) if free_indices is not None: residuals = residuals[jnp.array(free_indices), :] return float(jnp.sqrt(jnp.sum(jnp.square(residuals))))