| """Graph data structures for GNN-based form-finding.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import NamedTuple |
|
|
| import jax.numpy as jnp |
| from jaxtyping import Array, Float, Int |
|
|
| |
| |
| |
|
|
class GraphData(NamedTuple):
    """Immutable record describing one mesh graph.

    Attributes
    ----------
    node_features : Array
        Per-node vertex coordinates, shape ``(N, 3)``.
    edge_index : Array
        COO connectivity of shape ``(2, E)``; row 0 holds sender indices and
        row 1 holds receiver indices.
    num_nodes : int
        Count of nodes (mesh vertices).
    num_edges : int
        Count of directed edges.
    """

    # Field order is part of the tuple interface; keep it stable.
    node_features: Float[Array, "N 3"]
    edge_index: Int[Array, "2 E"]
    num_nodes: int
    num_edges: int
|
|
|
|
| |
| |
| |
|
|
def structure_to_graph(
    structure,
    xyz_flat: Float[Array, "N3"],
) -> GraphData:
    """Build a ``GraphData`` from a structure's connectivity and flat coordinates.

    Parameters
    ----------
    structure : EquilibriumMeshStructure
        Object exposing a ``connectivity`` attribute; it is converted with
        ``jnp.array`` and read as one row per edge and one column per node.
    xyz_flat : Array
        Flat vertex coordinates; reshaped to ``(-1, 3)``.

    Returns
    -------
    GraphData
        ``node_features`` is the reshaped coordinate array, ``edge_index`` has
        shape ``(2, E)``, ``num_nodes`` is the number of coordinate rows and
        ``num_edges`` is the number of connectivity rows.
    """
    positions = jnp.reshape(xyz_flat, (-1, 3))

    incidence = jnp.array(structure.connectivity)

    # For every edge row: the sender is the column holding the largest entry,
    # the receiver is the column holding the largest entry of the negated row.
    # NOTE(review): presumably each row carries exactly one positive and one
    # negative entry -- confirm the sign convention against the structure.
    edge_index = jnp.stack(
        [
            jnp.argmax(incidence, axis=1),
            jnp.argmax(-incidence, axis=1),
        ],
        axis=0,
    )

    return GraphData(
        node_features=positions,
        edge_index=edge_index,
        num_nodes=int(positions.shape[0]),
        num_edges=int(incidence.shape[0]),
    )
|
|
|
|
def edge_index_from_mesh(mesh) -> Int[Array, "2 E"]:
    """Build a COO edge index from the edges of a mesh.

    Every ``(u, v)`` pair yielded by ``mesh.edges()`` contributes ``u`` to the
    first row and ``v`` to the second row of the result.

    Parameters
    ----------
    mesh : FDMesh
        Object whose ``edges()`` method yields ``(u, v)`` pairs.

    Returns
    -------
    edge_index : Array
        ``int32`` array of shape ``(2, num_edges)``.
    """
    # Materialise once: ``edges()`` may be a one-shot iterator.
    pairs = [(u, v) for u, v in mesh.edges()]

    # Two explicit rows (rather than a transpose) keep the (2, 0) shape when
    # the mesh has no edges.
    rows = [
        [u for u, _ in pairs],
        [v for _, v in pairs],
    ]
    return jnp.array(rows, dtype=jnp.int32)
|
|
|
|
| |
| |
| |
|
|
def compute_edge_features(
    node_features: Float[Array, "N 3"],
    edge_index: Int[Array, "2 E"],
) -> tuple[Float[Array, "E 3"], Float[Array, "E 1"]]:
    """Compute edge features from node positions and edge connectivity.

    For each edge ``(i, j)`` the features are:

    * **relative_pos** -- ``node_features[j] - node_features[i]`` (shape ``(E, 3)``)
    * **distance** -- Euclidean length of the relative position vector (shape ``(E, 1)``)

    The distance is differentiable everywhere: for a zero-length edge (a
    self-loop or two coincident nodes) the value is ``0`` and its gradient is
    ``0`` rather than ``NaN``.

    Parameters
    ----------
    node_features : Array
        Vertex positions of shape ``(N, 3)``.
    edge_index : Array
        COO edge index of shape ``(2, E)``; row 0 holds sender indices ``i``
        and row 1 holds receiver indices ``j``.

    Returns
    -------
    relative_positions : Array
        Relative position vectors of shape ``(E, 3)``.
    distances : Array
        Euclidean distances of shape ``(E, 1)``.
    """
    senders = edge_index[0]
    receivers = edge_index[1]

    relative_positions = node_features[receivers] - node_features[senders]

    squared_lengths = jnp.sum(
        relative_positions * relative_positions, axis=-1, keepdims=True
    )

    # The derivative of sqrt at 0 is infinite, so differentiating a plain norm
    # through a zero-length edge yields NaN. Feed sqrt a dummy value of 1 on
    # those edges and select 0 afterwards ("double where"), so both the value
    # and the gradient stay finite. The mask is an equality test, so NaN
    # inputs still propagate instead of being silently zeroed.
    is_zero = squared_lengths == 0
    safe_squared = jnp.where(is_zero, 1.0, squared_lengths)
    distances = jnp.where(is_zero, 0.0, jnp.sqrt(safe_squared))

    return relative_positions, distances
|
|