"""Classical baseline: direct optimization of force densities.
Solves the inverse form-finding problem by optimizing force densities q
to minimize shape error, using the same differentiable FDM decoder as
the neural approach. This provides the classical comparison baseline
for evaluating neural encoder performance.
The optimization uses bounded L-BFGS-B (via jaxopt) with box constraints
that respect the compression/tension signs of each edge.
References
----------
[1] Schek, H.J. (1974). The force density method for form finding and
computation of general networks. CMAME, 3(1):115-134.
[2] Pastrana, R. et al. (2025). Neural FDM. ICLR 2025. Section 4.
"""
from __future__ import annotations
from time import perf_counter
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrn
from jaxopt import ScipyBoundedMinimize
from neural_fdm.builders import (
build_fd_decoder_parametrized,
build_loss_function,
)
from neural_fdm.helpers import compute_l_physics
def _build_bounds(mesh, q0, qmin, qmax):
"""Construct box constraints respecting edge stress signs."""
bound_low, bound_up = [], []
for edge in mesh.edges():
if mesh.edge_attribute(edge, "tag") == "cable":
bound_low.append(qmin)
bound_up.append(qmax)
else:
bound_low.append(-qmax)
bound_up.append(-qmin)
return jnp.array(bound_low), jnp.array(bound_up)
def _init_q(mesh, key, qmin=0.1, qmax=10.0, initial_q=None):
"""Initialize force densities with correct signs."""
num_edges = mesh.number_of_edges()
signs = []
for edge in mesh.edges():
sign = -1.0
if mesh.edge_attribute(edge, "tag") == "cable":
sign = 1.0
signs.append(sign)
signs = jnp.array(signs)
if initial_q is not None:
return initial_q
return jrn.uniform(key, shape=(num_edges,), minval=qmin, maxval=qmax) * signs
def solve_classical(
    xyz_target,
    config,
    mesh,
    structure,
    method="L-BFGS-B",
    maxiter=2000,
    tol=1e-6,
    initial_q=None,
    key=None,
    qmin=0.1,
    qmax=10.0,
):
    """Solve form-finding via direct optimization of force densities.

    A parametrized FDM decoder is built from the initial force densities,
    its ``q`` leaf is marked as the only trainable part, and the loss built
    from ``config`` is minimized over ``q`` with ``ScipyBoundedMinimize``
    under the box constraints returned by ``_build_bounds``.

    Parameters
    ----------
    xyz_target : jax.Array
        Target vertex positions (flat, shape N*3).
    config : dict
        YAML configuration (same as used for neural training). The ``"fdm"``
        entry is passed to the decoder builder.
    mesh : FDMesh
        The FDM mesh.
    structure : EquilibriumMeshStructure
        The connectivity structure.
    method : str
        Scipy optimizer method ("L-BFGS-B" or "SLSQP").
    maxiter : int
        Maximum optimization iterations.
    tol : float
        Convergence tolerance.
    initial_q : jax.Array, optional
        Initial force densities (warm-start). If None, random init.
    key : PRNGKey, optional
        Random key for q initialization. When both ``key`` and ``initial_q``
        are None, ``PRNGKey(0)`` is used.
    qmin : float
        Minimum absolute q value for bounds and init.
    qmax : float
        Maximum absolute q value for bounds and init.

    Returns
    -------
    result : dict
        ``q_opt``, ``xyz_opt`` (reshaped to ``(-1, 3)``), ``l_shape`` (sum of
        absolute coordinate differences to the target), ``l_physics``,
        ``loss_total``, ``success``, ``n_iters`` and ``runtime_ms`` (wall
        time of the optimizer run only, in milliseconds). If the optimizer
        state lacks the corresponding attribute, ``loss_total`` falls back to
        ``l_shape``, ``success`` to ``True`` and ``n_iters`` to ``-1``.

    Raises
    ------
    ValueError
        If ``qmin`` is greater than ``qmax``.
    """
    # Fail fast: an inverted range yields lower bounds above upper bounds,
    # which would otherwise only surface after the decoder build and warmup.
    if qmin > qmax:
        raise ValueError(f"qmin ({qmin}) must not exceed qmax ({qmax})")

    if key is None and initial_q is None:
        key = jrn.PRNGKey(0)

    fd_params = config["fdm"]
    q0 = _init_q(mesh, key, qmin, qmax, initial_q)

    # Build parametrized decoder (q is the only trainable parameter)
    decoder = build_fd_decoder_parametrized(q0, mesh, fd_params)

    # Build loss function (same as neural training)
    from neural_fdm.builders import build_data_generator

    generator = build_data_generator(config)
    compute_loss = build_loss_function(config, generator)

    # Split: only q is trainable, everything else is static
    filter_spec = jax.tree_util.tree_map(lambda _: False, decoder)
    filter_spec = eqx.tree_at(lambda t: t.q, filter_spec, replace=True)
    diff_model, static_model = eqx.partition(decoder, filter_spec)

    # Loss wrapper for jaxopt — only q is differentiated.
    # Returns (value, grad), matching value_and_grad=True below.
    @eqx.filter_jit
    @eqx.filter_value_and_grad
    def loss_fn(diff_model, xyz):
        model = eqx.combine(diff_model, static_model)
        return compute_loss(model, structure, xyz, aux_data=False)

    # Bounds (only for q), stored in trees shaped like the trainable part
    bound_low, bound_up = _build_bounds(mesh, q0, qmin, qmax)
    bl_tree = eqx.tree_at(lambda t: t.q, diff_model, replace=bound_low)
    bu_tree = eqx.tree_at(lambda t: t.q, diff_model, replace=bound_up)
    bounds = (bl_tree, bu_tree)

    # Optimizer
    opt = ScipyBoundedMinimize(
        fun=loss_fn,
        method=method,
        jit=True,
        tol=tol,
        maxiter=maxiter,
        options={"disp": False},
        value_and_grad=True,
    )

    # JIT warmup: evaluate the loss once before the timed section.
    # The loss receives the target with an added leading axis.
    xyz_in = xyz_target[None, :]
    _ = loss_fn(diff_model, xyz_in)

    # Solve
    t0 = perf_counter()
    diff_opt, opt_res = opt.run(diff_model, bounds, xyz_in)
    runtime_ms = (perf_counter() - t0) * 1000

    # Extract results
    model_opt = eqx.combine(diff_opt, static_model)
    x_hat, (q_opt, xyz_fixed, loads) = model_opt(
        xyz_target, structure, aux_data=True
    )

    # Metrics (same as neural evaluation)
    xyz_pred = jnp.reshape(x_hat, (-1, 3))
    xyz_tgt = jnp.reshape(xyz_target, (-1, 3))
    l_shape = float(jnp.sum(jnp.abs(xyz_pred - xyz_tgt)))
    free_indices = sorted(mesh.vertices_free())
    l_physics = compute_l_physics(x_hat, q_opt, loads, structure, free_indices)

    return {
        "q_opt": q_opt,
        "xyz_opt": xyz_pred,
        "l_shape": l_shape,
        "l_physics": l_physics,
        "loss_total": (
            float(opt_res.fun_val) if hasattr(opt_res, "fun_val") else l_shape
        ),
        "success": (
            bool(opt_res.success) if hasattr(opt_res, "success") else True
        ),
        "n_iters": (
            int(opt_res.iter_num) if hasattr(opt_res, "iter_num") else -1
        ),
        "runtime_ms": runtime_ms,
    }
|