File size: 4,314 Bytes
fc7d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Graph data structures for GNN-based form-finding."""

from __future__ import annotations

from typing import NamedTuple

import jax.numpy as jnp
from jaxtyping import Array, Float, Int

# ===============================================================================
# Graph data container
# ===============================================================================

class GraphData(NamedTuple):
    """Lightweight graph container for a single mesh graph.

    Row ``k`` of ``node_features`` is node ``k``; the integers stored in
    ``edge_index`` are row indices into ``node_features``.

    Parameters
    ----------
    node_features : Array
        Vertex positions of shape ``(N, 3)``.
    edge_index : Array
        Sender/receiver indices in COO format of shape ``(2, E)``: row 0 holds
        the sender of each edge, row 1 the receiver.
    num_nodes : int
        Number of nodes (vertices) in the graph, i.e. ``N``.
    num_edges : int
        Number of directed edges in the graph, i.e. ``E``.

    Notes
    -----
    NOTE(review): as a ``NamedTuple`` this is a JAX pytree, so if an instance
    is passed as an argument to a ``jax.jit``-compiled function, ``num_nodes``
    and ``num_edges`` would be traced rather than kept as Python ints —
    confirm whether they should be treated as static instead.
    """
    # (N, 3) vertex positions; the row order defines the node ids.
    node_features: Float[Array, "N 3"]
    # (2, E) integer COO index: edge_index[0] = senders, edge_index[1] = receivers.
    edge_index: Int[Array, "2 E"]
    # Python int; equals node_features.shape[0] when built by structure_to_graph.
    num_nodes: int
    # Python int; equals edge_index.shape[1] when built by structure_to_graph.
    num_edges: int


# ===============================================================================
# Construction helpers
# ===============================================================================

def structure_to_graph(
    structure,
    xyz_flat: Float[Array, "N3"],
) -> GraphData:
    """Convert an EquilibriumMeshStructure and flat xyz array to a GraphData.

    Parameters
    ----------
    structure : EquilibriumMeshStructure
        The equilibrium structure (carries topology information). Only its
        ``connectivity`` attribute is read.
    xyz_flat : Array
        Flat vertex positions of length ``num_vertices * 3``.

    Returns
    -------
    GraphData
        A graph whose nodes are the mesh vertices and whose edges follow the
        mesh connectivity.

    Raises
    ------
    ValueError
        If ``structure.connectivity`` is not two-dimensional, or if its number
        of columns differs from the number of vertices in ``xyz_flat``.
    """
    node_features = jnp.reshape(xyz_flat, (-1, 3))
    num_nodes = int(node_features.shape[0])

    # The structure stores a connectivity matrix; extract the edge list from it.
    # ``structure.connectivity`` is expected to be a (num_edges, num_nodes)
    # matrix where each row has exactly one +1 (sender) and one -1 (receiver).
    connectivity = jnp.asarray(structure.connectivity)
    if connectivity.ndim != 2:
        raise ValueError(
            "structure.connectivity must be a 2-D (num_edges, num_nodes) "
            f"matrix, got shape {tuple(connectivity.shape)}."
        )
    # JAX clamps out-of-range gather indices instead of raising, so a mismatch
    # between topology and geometry would otherwise silently corrupt the edge
    # features computed from this graph. This only inspects shapes, so it
    # does not need concrete array values.
    if connectivity.shape[1] != num_nodes:
        raise ValueError(
            f"structure.connectivity has {connectivity.shape[1]} columns but "
            f"xyz_flat describes {num_nodes} vertices."
        )
    num_edges = int(connectivity.shape[0])

    # argmax returns the column of the +1 entry; argmax of the negated row
    # returns the column of the -1 entry. A row lacking such an entry yields
    # the index of its first maximum rather than an error.
    senders = jnp.argmax(connectivity, axis=1)
    receivers = jnp.argmax(-connectivity, axis=1)

    edge_index = jnp.stack([senders, receivers], axis=0)  # (2, E)

    return GraphData(
        node_features=node_features,
        edge_index=edge_index,
        num_nodes=num_nodes,
        num_edges=num_edges,
    )


def edge_index_from_mesh(mesh) -> Int[Array, "2 E"]:
    """Build a ``(2, E)`` COO edge index from the edges of an FDMesh.

    Each ``(u, v)`` pair yielded by ``mesh.edges()`` becomes one column of the
    result: ``u`` goes to row 0 (senders) and ``v`` to row 1 (receivers), in
    the order the mesh yields its edges.

    Parameters
    ----------
    mesh : FDMesh
        A COMPAS / jax_fdm mesh datastructure exposing ``edges()``.

    Returns
    -------
    edge_index : Array
        ``int32`` array of shape ``(2, num_edges)``.

    Notes
    -----
    The vertex keys are used verbatim as indices. NOTE(review): this presumably
    assumes the mesh keys are contiguous ``0..N-1`` in the same order as the
    vertex positions — confirm against the caller.
    """
    # Unpacking into exactly two names keeps the strictness of a ``u, v``
    # loop: an edge that is not a 2-sequence raises instead of being truncated.
    edge_pairs = [(u, v) for u, v in mesh.edges()]
    row_senders = [u for u, _ in edge_pairs]
    row_receivers = [v for _, v in edge_pairs]
    return jnp.array([row_senders, row_receivers], dtype=jnp.int32)


# ===============================================================================
# Edge feature computation
# ===============================================================================

def compute_edge_features(
    node_features: Float[Array, "N 3"],
    edge_index: Int[Array, "2 E"],
) -> tuple[Float[Array, "E 3"], Float[Array, "E 1"]]:
    """Compute edge features from node positions and edge connectivity.

    For each edge ``(i, j)`` the features are:

    * **relative_pos** -- ``node_features[j] - node_features[i]``  (shape ``(E, 3)``)
    * **distance** -- Euclidean length of the relative position vector  (shape ``(E, 1)``)

    Parameters
    ----------
    node_features : Array
        Vertex positions of shape ``(N, 3)``.
    edge_index : Array
        COO edge index of shape ``(2, E)``; row 0 holds senders ``i`` and
        row 1 holds receivers ``j``.

    Returns
    -------
    relative_positions : Array
        Relative position vectors of shape ``(E, 3)``.
    distances : Array
        Euclidean distances of shape ``(E, 1)``.

    Notes
    -----
    The distance uses a "safe" square root. Its forward value matches
    ``jnp.linalg.norm``, including NaN/inf propagation. For zero-length edges
    (coincident endpoints) its gradient is ``0`` instead of ``NaN``. A plain
    norm has an undefined derivative at the origin, so a single degenerate
    edge would otherwise poison every gradient computed through this function.
    """
    senders = edge_index[0]    # (E,)
    receivers = edge_index[1]  # (E,)

    relative_positions = node_features[receivers] - node_features[senders]  # (E, 3)

    squared_lengths = jnp.sum(
        relative_positions * relative_positions, axis=-1, keepdims=True
    )  # (E, 1)
    is_degenerate = squared_lengths == 0
    # Double-where trick: feed sqrt a dummy 1.0 on zero-length edges so the
    # branch that is discarded never contributes an inf/NaN derivative.
    # ``== 0`` (not ``> 0``) keeps NaN inputs propagating to a NaN distance.
    safe_squared = jnp.where(is_degenerate, 1.0, squared_lengths)
    distances = jnp.where(is_degenerate, 0.0, jnp.sqrt(safe_squared))  # (E, 1)

    return relative_positions, distances