File size: 4,371 Bytes
fc7d689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Pure-numpy API for external tool integration.

This module provides a clean interface that accepts and returns only
numpy arrays and Python dicts, with no JAX objects in the public API.

Note: The mesh topology, loads, and boundary conditions are defined
by the YAML config file, not by the caller. The vertices argument
provides the target shape that the model maps to force densities.

Example
-------
>>> from neural_fdm.interop import predict_equilibrium
>>> result = predict_equilibrium(
...     vertices=target_xyz,  # (N, 3) target positions
...     model_path="data/formfinder_bezier.eqx",
...     config_path="scripts/bezier.yml",
... )
>>> result["vertices"]      # (N, 3) predicted equilibrium positions
>>> result["force_densities"]  # (E,) per-edge force densities
"""

from __future__ import annotations

import numpy as np


def _as_vertex_array(vertices) -> np.ndarray:
    """Coerce ``vertices`` to an (N, 3) numpy array, raising ValueError otherwise.

    A flat array of length 3N is also accepted and reshaped to (N, 3),
    because earlier versions flattened their input and so accepted flat arrays.
    """
    xyz = np.asarray(vertices)
    if xyz.ndim == 1 and xyz.size % 3 == 0:
        xyz = xyz.reshape(-1, 3)
    if xyz.ndim != 2 or xyz.shape[1] != 3 or xyz.shape[0] == 0:
        raise ValueError(
            f"vertices must have shape (N, 3) with N >= 1, got {xyz.shape}"
        )
    return xyz


def _resolve_model_name(config: dict, model_name: str | None) -> str:
    """Return ``model_name`` if given, else infer it from the config's loss section."""
    if model_name is not None:
        return model_name
    # A bare `loss:` key in YAML loads as None, so fall back to an empty dict.
    loss_config = config.get("loss") or {}
    return "variational_formfinder" if "vae" in loss_config else "formfinder"


def predict_equilibrium(
    vertices: np.ndarray,
    model_path: str,
    config_path: str,
    model_name: str | None = None,
) -> dict[str, np.ndarray | float]:
    """Predict equilibrium shape using a trained neural FDM model.

    The mesh topology and loads are determined by the config file.
    The vertices provide the target shape for the encoder.
    Model type is auto-detected from config (VAE if loss.vae section
    exists) or can be overridden via model_name.

    Parameters
    ----------
    vertices : array-like (N, 3)
        Target vertex positions. A flat array of length 3N is also accepted.
    model_path : str
        Path to trained model file (.eqx).
    config_path : str
        Path to YAML configuration file.
    model_name : str, optional
        Model type ("formfinder" or "variational_formfinder").
        Auto-detected from config if not specified.

    Returns
    -------
    result : dict
        Dictionary with keys:
        - "vertices": ndarray (N, 3) - predicted equilibrium positions
        - "force_densities": ndarray (E,) - force density per edge
        - "forces": ndarray (E,) - axial force per edge
        - "lengths": ndarray (E,) - member lengths
        - "residuals": ndarray (N, 3) - force residuals at vertices
        - "inference_time_ms": float - prediction time in milliseconds

    Raises
    ------
    ValueError
        If ``vertices`` is not (N, 3) or flat 3N with N >= 1, or if the
        config file does not contain a YAML mapping.
    """
    # Validate cheaply before paying for the JAX import and model loading.
    xyz_target = _as_vertex_array(vertices)

    import time

    import jax.numpy as jnp
    import jax.random as jrn
    import yaml

    from neural_fdm.builders import (
        build_connectivity_structure_from_generator,
        build_data_generator,
        build_neural_model,
    )
    from neural_fdm.helpers import (
        edges_lengths,
        edges_vectors,
        vertices_residuals_from_xyz,
    )
    from neural_fdm.serialization import load_model
    from neural_fdm.variational import VariationalAutoEncoder

    # Load config.
    # NOTE(review): FullLoader is not recommended for untrusted YAML; switch to
    # yaml.safe_load if configs never rely on Python-specific tags.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(config, dict):
        raise ValueError(f"Config file {config_path!r} must contain a YAML mapping")

    # Build infrastructure from config
    seed = config.get("seed")
    key = jrn.PRNGKey(0 if seed is None else seed)
    generator = build_data_generator(config)
    structure = build_connectivity_structure_from_generator(config, generator)

    # Auto-detect model type from config, or use explicit override
    model_name = _resolve_model_name(config, model_name)
    model_skeleton = build_neural_model(model_name, config, generator, key)
    model = load_model(model_path, model_skeleton)

    # The model consumes a flat (3N,) coordinate vector.
    xyz_flat = jnp.array(xyz_target.reshape(-1))

    # Predict. Block on the result so JAX's asynchronous dispatch does not
    # make the measured time under-report the actual computation.
    t0 = time.perf_counter()
    x_hat, aux = model(xyz_flat, structure, aux_data=True)
    x_hat.block_until_ready()
    t1 = time.perf_counter()

    # Unpack aux_data: VAE returns ((q, xyz_fixed, loads), mu, log_sigma)
    # Deterministic returns (q, xyz_fixed, loads)
    if isinstance(model, VariationalAutoEncoder):
        (q, _xyz_fixed, loads_jax), _mu, _log_sigma = aux
    else:
        q, _xyz_fixed, loads_jax = aux

    # Post-process: edge geometry, axial forces (q * length) and residuals.
    xyz_pred = jnp.reshape(x_hat, (-1, 3))
    vectors = edges_vectors(xyz_pred, structure.connectivity)
    lengths = jnp.ravel(edges_lengths(vectors))
    forces = q * lengths
    residuals = vertices_residuals_from_xyz(q, loads_jax, xyz_pred, structure)

    return {
        "vertices": np.array(xyz_pred),
        "force_densities": np.array(q),
        "forces": np.array(forces),
        "lengths": np.array(lengths),
        "residuals": np.array(residuals),
        "inference_time_ms": (t1 - t0) * 1000,
    }