Efradeca's picture
Upload folder using huggingface_hub
fc7d689 verified
"""Pure-numpy API for external tool integration.

This module provides a clean interface that accepts and returns only
numpy arrays and plain Python values (strings, floats, dicts), with no
JAX objects in the public API.

Note: The mesh topology, loads, and boundary conditions are defined
by the YAML config file, not by the caller. The vertices argument
provides the target shape that the model maps to force densities.

Example
-------
>>> from neural_fdm.interop import predict_equilibrium
>>> result = predict_equilibrium(
...     vertices=target_xyz,  # (N, 3) target positions
...     model_path="data/formfinder_bezier.eqx",
...     config_path="scripts/bezier.yml",
... )
>>> result["vertices"]  # (N, 3) predicted equilibrium positions
>>> result["force_densities"]  # (E,) per-edge force densities
"""
from __future__ import annotations
import numpy as np
def predict_equilibrium(
    vertices: np.ndarray,
    model_path: str,
    config_path: str,
    model_name: str | None = None,
) -> dict[str, np.ndarray | float]:
    """Predict equilibrium shape using a trained neural FDM model.

    The mesh topology and loads are determined by the config file.
    The vertices provide the target shape for the encoder.

    Model type is auto-detected from config (VAE if loss.vae section
    exists) or can be overridden via model_name.

    Parameters
    ----------
    vertices : ndarray (N, 3)
        Target vertex positions. Any array-like is accepted; it is
        flattened in C order before being passed to the model.
    model_path : str
        Path to trained model file (.eqx).
    config_path : str
        Path to YAML configuration file.
    model_name : str, optional
        Model type ("formfinder" or "variational_formfinder").
        Auto-detected from config if not specified.

    Returns
    -------
    result : dict
        Dictionary with keys:
        - "vertices": ndarray (N, 3) - predicted equilibrium positions
        - "force_densities": ndarray (E,) - force density per edge
        - "forces": ndarray (E,) - axial force per edge
        - "lengths": ndarray (E,) - member lengths
        - "residuals": ndarray (N, 3) - force residuals at vertices
        - "inference_time_ms": float - prediction time in milliseconds

    Raises
    ------
    ValueError
        If ``vertices`` is empty or its number of entries is not a
        multiple of 3 (so it cannot represent xyz coordinates).
    """
    # Validate the input before the heavy, lazily imported dependencies are
    # loaded, so malformed input fails fast with a readable message instead
    # of an opaque shape error from inside the model.
    coords = np.asarray(vertices)
    if coords.size == 0 or coords.size % 3 != 0:
        raise ValueError(
            "vertices must contain a non-zero number of xyz triplets; "
            f"got an array of shape {coords.shape}"
        )

    # Imports are kept local to the function rather than at module level.
    import time

    import jax.numpy as jnp
    import jax.random as jrn
    import yaml

    from neural_fdm.builders import (
        build_connectivity_structure_from_generator,
        build_data_generator,
        build_neural_model,
    )
    from neural_fdm.helpers import (
        edges_lengths,
        edges_vectors,
        vertices_residuals_from_xyz,
    )
    from neural_fdm.serialization import load_model
    from neural_fdm.variational import VariationalAutoEncoder

    # Load config.
    # NOTE(review): FullLoader resolves more than plain YAML types. Only pass
    # trusted config files here; consider yaml.safe_load if the configs never
    # rely on Python-specific tags.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Build infrastructure from config
    key = jrn.PRNGKey(config.get("seed", 0))
    generator = build_data_generator(config)
    structure = build_connectivity_structure_from_generator(config, generator)

    # Auto-detect model type from config, or use explicit override
    if model_name is None:
        # A bare `loss:` key parses to None, so fall back to an empty mapping
        # instead of failing on `"vae" in None`.
        loss_config = config.get("loss") or {}
        model_name = (
            "variational_formfinder" if "vae" in loss_config else "formfinder"
        )

    model_skeleton = build_neural_model(model_name, config, generator, key)
    model = load_model(model_path, model_skeleton)

    # Prepare input
    xyz_flat = jnp.array(coords.ravel())

    # Predict. JAX dispatches work asynchronously, so wait for the result
    # before stopping the timer.
    t0 = time.perf_counter()
    x_hat, aux = model(xyz_flat, structure, aux_data=True)
    x_hat.block_until_ready()
    t1 = time.perf_counter()

    # Unpack aux_data: VAE returns ((q, xyz_fixed, loads), mu, log_sigma)
    # Deterministic returns (q, xyz_fixed, loads)
    if isinstance(model, VariationalAutoEncoder):
        (q, xyz_fixed, loads_jax), _mu, _log_sigma = aux
    else:
        q, xyz_fixed, loads_jax = aux

    # Post-process
    xyz_pred = jnp.reshape(x_hat, (-1, 3))
    vectors = edges_vectors(xyz_pred, structure.connectivity)
    lengths_flat = jnp.ravel(edges_lengths(vectors))
    forces_arr = q * lengths_flat
    residuals = vertices_residuals_from_xyz(q, loads_jax, xyz_pred, structure)

    # Convert everything back to numpy so no JAX objects leak to the caller.
    return {
        "vertices": np.array(xyz_pred),
        "force_densities": np.array(q),
        "forces": np.array(forces_arr),
        "lengths": np.array(lengths_flat),
        "residuals": np.array(residuals),
        "inference_time_ms": (t1 - t0) * 1000,
    }