# vae-fdm / src/neural_fdm/classical.py
# Efradeca's picture
# Upload folder using huggingface_hub
# fc7d689 verified
"""Classical baseline: direct optimization of force densities.
Solves the inverse form-finding problem by optimizing force densities q
to minimize shape error, using the same differentiable FDM decoder as
the neural approach. This provides the classical comparison baseline
for evaluating neural encoder performance.
The optimization uses a bounded SciPy minimizer through jaxopt (L-BFGS-B by
default, SLSQP also accepted) with box constraints that respect the
compression/tension sign of each edge.
References
----------
[1] Schek, H.J. (1974). The force density method for form finding and
computation of general networks. CMAME, 3(1):115-134.
[2] Pastrana, R. et al. (2025). Neural FDM. ICLR 2025. Section 4.
"""
from __future__ import annotations
from time import perf_counter
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrn
from jaxopt import ScipyBoundedMinimize
from neural_fdm.builders import (
build_fd_decoder_parametrized,
build_loss_function,
)
from neural_fdm.helpers import compute_l_physics
def _build_bounds(mesh, q0, qmin, qmax):
"""Construct box constraints respecting edge stress signs."""
bound_low, bound_up = [], []
for edge in mesh.edges():
if mesh.edge_attribute(edge, "tag") == "cable":
bound_low.append(qmin)
bound_up.append(qmax)
else:
bound_low.append(-qmax)
bound_up.append(-qmin)
return jnp.array(bound_low), jnp.array(bound_up)
def _init_q(mesh, key, qmin=0.1, qmax=10.0, initial_q=None):
"""Initialize force densities with correct signs."""
num_edges = mesh.number_of_edges()
signs = []
for edge in mesh.edges():
sign = -1.0
if mesh.edge_attribute(edge, "tag") == "cable":
sign = 1.0
signs.append(sign)
signs = jnp.array(signs)
if initial_q is not None:
return initial_q
return jrn.uniform(key, shape=(num_edges,), minval=qmin, maxval=qmax) * signs
def solve_classical(
    xyz_target,
    config,
    mesh,
    structure,
    method="L-BFGS-B",
    maxiter=2000,
    tol=1e-6,
    initial_q=None,
    key=None,
    qmin=0.1,
    qmax=10.0,
):
    """Solve form-finding via direct optimization of force densities.

    Parameters
    ----------
    xyz_target : jax.Array
        Target vertex positions (flat, shape N*3).
    config : dict
        YAML configuration (same as used for neural training). Must contain
        an ``"fdm"`` entry.
    mesh : FDMesh
        The FDM mesh.
    structure : EquilibriumMeshStructure
        The connectivity structure.
    method : str
        Scipy optimizer method ("L-BFGS-B" or "SLSQP").
    maxiter : int
        Maximum optimization iterations.
    tol : float
        Convergence tolerance.
    initial_q : jax.Array, optional
        Initial force densities (warm-start). If None, random init.
    key : PRNGKey, optional
        Random key for q initialization. If both ``key`` and ``initial_q``
        are None, ``jrn.PRNGKey(0)`` is used.
    qmin : float
        Minimum absolute q value for bounds and init. Must be >= 0.
    qmax : float
        Maximum absolute q value for bounds and init. Must be >= qmin.

    Returns
    -------
    result : dict
        Keys:

        - ``q_opt``: optimized force densities.
        - ``xyz_opt``: predicted positions, shape ``(-1, 3)``.
        - ``l_shape``: sum of absolute position errors.
        - ``l_physics``: physics residual from ``compute_l_physics``.
        - ``loss_total``: final optimizer objective.
        - ``success``: optimizer success flag.
        - ``n_iters``: number of optimizer iterations.
        - ``runtime_ms``: wall time of ``opt.run``.

    Raises
    ------
    ValueError
        If ``qmin < 0`` or ``qmin > qmax``. Such values would produce
        inverted bounds, or bounds that let an edge's q change sign.
    """
    if qmin < 0 or qmin > qmax:
        raise ValueError(
            f"Expected 0 <= qmin <= qmax, got qmin={qmin}, qmax={qmax}."
        )
    if key is None and initial_q is None:
        key = jrn.PRNGKey(0)

    fd_params = config["fdm"]
    q0 = _init_q(mesh, key, qmin, qmax, initial_q)

    # Build parametrized decoder (q is the only trainable parameter)
    decoder = build_fd_decoder_parametrized(q0, mesh, fd_params)

    # Build loss function (same as neural training)
    from neural_fdm.builders import build_data_generator

    generator = build_data_generator(config)
    compute_loss = build_loss_function(config, generator)

    # Split: only q is trainable, everything else is static
    filter_spec = jax.tree_util.tree_map(lambda _: False, decoder)
    filter_spec = eqx.tree_at(lambda t: t.q, filter_spec, replace=True)
    diff_model, static_model = eqx.partition(decoder, filter_spec)

    # Loss wrapper for jaxopt: only q is differentiated. Returns
    # (value, grads), matching value_and_grad=True below.
    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def loss_fn(diff_model, xyz):
        model = eqx.combine(diff_model, static_model)
        return compute_loss(model, structure, xyz, aux_data=False)

    # Bounds (only for q): same pytree structure as diff_model, with the
    # q leaf replaced by the per-edge lower/upper bound arrays.
    bound_low, bound_up = _build_bounds(mesh, q0, qmin, qmax)
    bl_tree = eqx.tree_at(lambda t: t.q, diff_model, replace=bound_low)
    bu_tree = eqx.tree_at(lambda t: t.q, diff_model, replace=bound_up)
    bounds = (bl_tree, bu_tree)

    # Optimizer
    opt = ScipyBoundedMinimize(
        fun=loss_fn,
        method=method,
        jit=True,
        tol=tol,
        maxiter=maxiter,
        options={"disp": False},
        value_and_grad=True,
    )

    # JIT warmup: compile loss_fn outside the timed region. Blocking keeps
    # asynchronously dispatched execution of this call out of the timer.
    # NOTE(review): with jit=True, jaxopt wraps loss_fn in its own jax.jit,
    # so the first call inside opt.run may still include some compile time.
    xyz_in = xyz_target[None, :]
    jax.block_until_ready(loss_fn(diff_model, xyz_in))

    # Solve. Block on the result so runtime_ms covers completed work.
    t0 = perf_counter()
    diff_opt, opt_res = opt.run(diff_model, bounds, xyz_in)
    diff_opt = jax.block_until_ready(diff_opt)
    runtime_ms = (perf_counter() - t0) * 1000

    # Extract results
    model_opt = eqx.combine(diff_opt, static_model)
    x_hat, (q_opt, xyz_fixed, loads) = model_opt(
        xyz_target, structure, aux_data=True
    )

    # Metrics (same as neural evaluation)
    xyz_pred = jnp.reshape(x_hat, (-1, 3))
    xyz_tgt = jnp.reshape(xyz_target, (-1, 3))
    l_shape = float(jnp.sum(jnp.abs(xyz_pred - xyz_tgt)))
    free_indices = sorted(mesh.vertices_free())
    l_physics = compute_l_physics(x_hat, q_opt, loads, structure, free_indices)

    return {
        "q_opt": q_opt,
        "xyz_opt": xyz_pred,
        "l_shape": l_shape,
        "l_physics": l_physics,
        "loss_total": float(opt_res.fun_val) if hasattr(opt_res, "fun_val") else l_shape,
        "success": bool(opt_res.success) if hasattr(opt_res, "success") else True,
        "n_iters": int(opt_res.iter_num) if hasattr(opt_res, "iter_num") else -1,
        "runtime_ms": runtime_ms,
    }