File size: 4,314 Bytes
"""Graph data structures for GNN-based form-finding."""
from __future__ import annotations
from typing import NamedTuple
import jax.numpy as jnp
from jaxtyping import Array, Float, Int
# ===============================================================================
# Graph data container
# ===============================================================================
class GraphData(NamedTuple):
    """Immutable bundle describing the graph of one mesh.

    Attributes
    ----------
    node_features : Array
        Per-vertex positions, shape ``(N, 3)``.
    edge_index : Array
        COO edge list, shape ``(2, E)``; row 0 holds sender indices and
        row 1 holds receiver indices.
    num_nodes : int
        Vertex count ``N``.
    num_edges : int
        Directed edge count ``E``.
    """

    # Vertex coordinates, one row per node.
    node_features: Float[Array, "N 3"]
    # Row 0 = senders, row 1 = receivers.
    edge_index: Int[Array, "2 E"]
    # Graph sizes, annotated as plain Python ints.
    num_nodes: int
    num_edges: int
# ===============================================================================
# Construction helpers
# ===============================================================================
def structure_to_graph(
    structure,
    xyz_flat: Float[Array, "N3"],
) -> GraphData:
    """Convert an EquilibriumMeshStructure and flat xyz array to a GraphData.

    Parameters
    ----------
    structure : EquilibriumMeshStructure
        The equilibrium structure (carries topology information). Only its
        ``connectivity`` attribute is read.
    xyz_flat : Array
        Flat vertex positions of length ``num_vertices * 3``.

    Returns
    -------
    GraphData
        A graph whose nodes are the mesh vertices and whose edges follow the
        mesh connectivity. ``edge_index`` is always ``int32``.

    Raises
    ------
    ValueError
        If ``structure.connectivity`` is not a 2-D matrix whose column count
        equals the number of vertices implied by ``xyz_flat``.
    """
    node_features = jnp.reshape(xyz_flat, (-1, 3))
    num_nodes = int(node_features.shape[0])

    # The structure stores a connectivity matrix; extract the edge list from it.
    # ``structure.connectivity`` is expected to be a (num_edges, num_nodes)
    # matrix where each row has exactly one +1 and one -1.
    connectivity = jnp.array(structure.connectivity)

    # Validate before indexing: JAX clamps out-of-bounds gathers instead of
    # raising, so a vertex-count mismatch would otherwise silently yield
    # wrong edge features downstream.
    if connectivity.ndim != 2 or connectivity.shape[1] != num_nodes:
        raise ValueError(
            "structure.connectivity must have shape (num_edges, num_nodes) "
            f"with num_nodes={num_nodes} (from xyz_flat); "
            f"got shape {tuple(connectivity.shape)}"
        )
    num_edges = int(connectivity.shape[0])

    # NOTE(review): assumes +1 marks the sender and -1 the receiver. Some
    # libraries use the opposite sign, which would swap direction relative to
    # ``edge_index_from_mesh`` -- confirm against the structure's convention.
    # A row with no +1 (or no -1) makes argmax fall back to index 0.
    senders = jnp.argmax(connectivity, axis=1)
    receivers = jnp.argmax(-connectivity, axis=1)

    # Cast so the dtype does not depend on the x64 flag and matches
    # ``edge_index_from_mesh``.
    edge_index = jnp.stack([senders, receivers], axis=0).astype(jnp.int32)  # (2, E)

    return GraphData(
        node_features=node_features,
        edge_index=edge_index,
        num_nodes=num_nodes,
        num_edges=num_edges,
    )
def edge_index_from_mesh(mesh) -> Int[Array, "2 E"]:
    """Build a COO edge index from the edges reported by an FDMesh.

    Every ``(u, v)`` pair produced by ``mesh.edges()`` contributes ``u`` to the
    sender row and ``v`` to the receiver row, in iteration order.

    NOTE(review): ``u``/``v`` are used directly as node indices. This assumes
    the mesh's vertex keys match the row order of the node feature array --
    confirm for meshes whose keys are non-contiguous.

    Parameters
    ----------
    mesh : FDMesh
        A COMPAS / jax_fdm mesh datastructure exposing ``edges()``.

    Returns
    -------
    edge_index : Array
        ``int32`` array of shape ``(2, num_edges)``; ``(2, 0)`` for a mesh
        without edges.
    """
    edge_pairs = list(mesh.edges())
    sender_ids = [u for u, _ in edge_pairs]
    receiver_ids = [v for _, v in edge_pairs]
    return jnp.array([sender_ids, receiver_ids], dtype=jnp.int32)
# ===============================================================================
# Edge feature computation
# ===============================================================================
def compute_edge_features(
    node_features: Float[Array, "N 3"],
    edge_index: Int[Array, "2 E"],
) -> tuple[Float[Array, "E 3"], Float[Array, "E 1"]]:
    """Compute edge features from node positions and edge connectivity.

    For each edge ``(i, j)`` the features are:

    * **relative_pos** -- ``node_features[j] - node_features[i]`` (shape ``(E, 3)``)
    * **distance** -- Euclidean length of the relative position vector (shape ``(E, 1)``)

    The distance is computed with a gradient-safe norm. Zero-length edges
    (coincident endpoints or self-loops) get distance ``0`` and a gradient of
    ``0`` instead of ``NaN``.

    Parameters
    ----------
    node_features : Array
        Vertex positions of shape ``(N, 3)``.
    edge_index : Array
        COO edge index of shape ``(2, E)``; row 0 = senders, row 1 = receivers.

    Returns
    -------
    relative_positions : Array
        Relative position vectors of shape ``(E, 3)``.
    distances : Array
        Euclidean distances of shape ``(E, 1)``.
    """
    senders = edge_index[0]  # (E,)
    receivers = edge_index[1]  # (E,)
    relative_positions = node_features[receivers] - node_features[senders]  # (E, 3)

    # ``jnp.linalg.norm`` differentiates through sqrt(0) and yields NaN
    # gradients for zero-length edges, poisoning the whole backward pass.
    # Double-``where`` trick: feed sqrt a safe value on those rows and mask
    # the result, so both the value and the gradient there are 0.
    squared_lengths = jnp.sum(
        relative_positions * relative_positions, axis=-1, keepdims=True
    )  # (E, 1)
    is_zero = squared_lengths == 0.0
    safe_squared = jnp.where(is_zero, 1.0, squared_lengths)
    distances = jnp.where(is_zero, 0.0, jnp.sqrt(safe_squared))  # (E, 1)

    return relative_positions, distances
|