Spaces:
Sleeping
Sleeping
Rename data_gen.py to practicality_core.py
Browse files- data_gen.py +0 -243
- practicality_core.py +307 -0
data_gen.py
DELETED
|
@@ -1,243 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
data_gen.py — Training / test data for the elastic mesh.
|
| 3 |
-
|
| 4 |
-
Each sample is a triple (A, B, C) where:
|
| 5 |
-
A ∈ ℝ^DIM encodes constraints ("what must be true")
|
| 6 |
-
B ∈ ℝ^DIM encodes objectives ("what we want")
|
| 7 |
-
C ∈ ℝ^DIM is the analytic solution — the feasibility center the mesh must learn to produce
|
| 8 |
-
|
| 9 |
-
Five problem families, each with a geometrically distinct C:
|
| 10 |
-
|
| 11 |
-
1. box_proj — clamp B into axis-aligned box defined by A
|
| 12 |
-
2. halfspace — project B onto hyperplane defined by A
|
| 13 |
-
3. sphere — project B onto sphere surface defined by A
|
| 14 |
-
4. simplex — project B onto probability simplex (A = uniform prior signal)
|
| 15 |
-
5. elastic_bal — per-dimension weighted balance between A-center and B
|
| 16 |
-
|
| 17 |
-
These cover:
|
| 18 |
-
- Bounded feasibility (box)
|
| 19 |
-
- Equality constraints (halfspace)
|
| 20 |
-
- Norm constraints (sphere)
|
| 21 |
-
- Probability/sum=1 (simplex)
|
| 22 |
-
- Soft trade-offs (elastic)
|
| 23 |
-
|
| 24 |
-
The mesh sees ONLY (A, B) during inference; C is what it must reconstruct.
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
import numpy as np
|
| 28 |
-
import json, pathlib, argparse
|
| 29 |
-
from typing import List, Dict
|
| 30 |
-
|
| 31 |
-
DIM = 32 # embedding dimension (set to 768 for LLM-scale)
|
| 32 |
-
SAMPLES_PER_TYPE = 1000 # × 5 types = 5 000 total
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
# ── UTILITIES ─────────────────────────────────────────────────────────────────
|
| 36 |
-
|
| 37 |
-
def normalize(v: np.ndarray) -> np.ndarray:
    """Scale ``v`` to (approximately) unit Euclidean length.

    A tiny constant is added to the norm before dividing, so an all-zero
    vector yields zeros rather than a division-by-zero warning.
    """
    length = np.linalg.norm(v)
    denominator = length + 1e-12
    return v / denominator
|
| 40 |
-
|
| 41 |
-
def pack(*arrays: np.ndarray, dim: int) -> np.ndarray:
    """Join ``arrays`` end to end and force the result to length ``dim``.

    Surplus trailing entries are dropped; a short result is right-padded
    with zeros.
    """
    joined = np.concatenate(arrays)
    shortfall = dim - len(joined)
    if shortfall > 0:
        return np.pad(joined, (0, shortfall))
    return joined[:dim]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# ── PROBLEM TYPE 1: BOX PROJECTION ────────────────────────────────────────────
|
| 50 |
-
#
|
| 51 |
-
# Constraint A : encodes per-dimension box [lo, hi]
|
| 52 |
-
# A[:D/2] = lo[:D/2], A[D/2:] = hi[:D/2]
|
| 53 |
-
# Objective B : unconstrained target point in ℝ^D
|
| 54 |
-
# Solution C : clip(B, lo, hi) — nearest point in box to B
|
| 55 |
-
#
|
| 56 |
-
# Meaning: "stay within resource/capacity bounds while aiming for B"
|
| 57 |
-
|
| 58 |
-
def gen_box(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
    """Generate ``n`` box-projection samples.

    Each sample draws a random axis-aligned box and a target ``B``; the
    solution ``C`` is ``B`` clipped into the box.  ``A`` carries the first
    ``dim // 2`` lower bounds followed by the first ``dim // 2`` upper
    bounds, packed to length ``dim``.
    """
    half_dim = dim // 2

    def _sample() -> Dict:
        # Draw order (center, half-width, target) determines the random stream.
        mid = rng.uniform(-2, 2, dim)
        radius = rng.uniform(0.3, 2.0, dim)
        lower = mid - radius
        upper = mid + radius
        target = rng.uniform(-4, 4, dim)
        clipped = np.clip(target, lower, upper)
        encoded = pack(lower[:half_dim], upper[:half_dim], dim=dim)
        return {
            'A': encoded.tolist(),
            'B': target.tolist(),
            'C': clipped.tolist(),
            'type': 'box_proj',
        }

    return [_sample() for _ in range(n)]
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# ── PROBLEM TYPE 2: HALFSPACE PROJECTION ──────────────────────────────────────
|
| 72 |
-
#
|
| 73 |
-
# Constraint A : encodes a hyperplane nᵀx = b
|
| 74 |
-
# A = normal vector, A[0] carries the offset b
|
| 75 |
-
# Objective B : unconstrained point in ℝ^D
|
| 76 |
-
# Solution C : projection of B onto the hyperplane
|
| 77 |
-
# C = B − (nᵀB − b) · n
|
| 78 |
-
#
|
| 79 |
-
# Meaning: "satisfy one hard equality constraint at minimum cost to B"
|
| 80 |
-
|
| 81 |
-
def gen_halfspace(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
    """Generate ``n`` hyperplane-projection samples.

    A random unit normal and offset ``b`` define the hyperplane; ``C`` is
    the orthogonal projection of the target ``B`` onto it.  ``A`` is the
    normal vector with its first entry overwritten by ``b``.
    """
    samples = []
    for _ in range(n):
        # Draw order (normal, offset, target) determines the random stream.
        unit_normal = normalize(rng.standard_normal(dim))
        offset = float(rng.uniform(-1, 1))
        target = rng.uniform(-3, 3, dim)

        signed_gap = float(np.dot(unit_normal, target)) - offset
        projected = target - signed_gap * unit_normal

        encoded = unit_normal.copy()
        encoded[0] = offset  # offset replaces the first normal component
        samples.append(dict(
            A=encoded.tolist(),
            B=target.tolist(),
            C=projected.tolist(),
            type='halfspace',
        ))
    return samples
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# ── PROBLEM TYPE 3: SPHERE SURFACE ────────────────────────────────────────────
|
| 95 |
-
#
|
| 96 |
-
# Constraint A : encodes a sphere (center, radius)
|
| 97 |
-
# A = center vector, A[0] overwritten with radius r
|
| 98 |
-
# Objective B : external point
|
| 99 |
-
# Solution C : point on sphere surface nearest to B
|
| 100 |
-
# C = center + r · (B − center) / ‖B − center‖
|
| 101 |
-
#
|
| 102 |
-
# Meaning: "satisfy a norm/budget constraint, move toward B as far as allowed"
|
| 103 |
-
|
| 104 |
-
def gen_sphere(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
    """Generate ``n`` sphere-surface projection samples.

    ``C`` is the point on the sphere (random center, radius in [1, 3]) that
    lies on the ray from the center through the target ``B``.  ``A`` is the
    center with its first entry overwritten by the radius.
    """

    def _sample() -> Dict:
        # Draw order (center, radius, target) determines the random stream.
        origin = rng.uniform(-1.5, 1.5, dim)
        radius = float(rng.uniform(1.0, 3.0))
        target = rng.uniform(-4, 4, dim)

        offset = target - origin
        dist = np.linalg.norm(offset)
        if dist < 1e-10:
            # Target coincides with the center: fall back to a fixed unit direction.
            offset = np.ones(dim) / np.sqrt(dim)
            dist = 1.0
        on_surface = origin + radius * offset / dist

        encoded = origin.copy()
        encoded[0] = radius  # radius replaces the first center component
        return {
            'A': encoded.tolist(),
            'B': target.tolist(),
            'C': on_surface.tolist(),
            'type': 'sphere',
        }

    return [_sample() for _ in range(n)]
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# ── PROBLEM TYPE 4: SIMPLEX PROJECTION ────────────────────────────────────────
|
| 123 |
-
#
|
| 124 |
-
# Constraint A : uniform-prior signal (all ones) → encodes simplex constraint Σxᵢ=1, xᵢ≥0
|
| 125 |
-
# Objective B : unconstrained "belief" vector
|
| 126 |
-
# Solution C : nearest point on probability simplex to B
|
| 127 |
-
#
|
| 128 |
-
# Meaning: "find a valid probability distribution closest to unconstrained belief B"
|
| 129 |
-
# Useful for softmax-like problems.
|
| 130 |
-
|
| 131 |
-
def _proj_simplex(v: np.ndarray) -> np.ndarray:
    """Project ``v`` onto the probability simplex (entries >= 0, summing to 1).

    Sort-based method: find the largest prefix of the descending-sorted
    values that stays positive after a common shift, then apply that shift
    and clamp at zero.
    """
    size = len(v)
    descending = np.sort(v)[::-1]
    shifted_cumsum = np.cumsum(descending) - 1.0
    ranks = np.arange(1, size + 1)
    support = np.nonzero(descending * ranks > shifted_cumsum)[0]
    last = int(support[-1])
    threshold = shifted_cumsum[last] / (last + 1.0)
    return np.maximum(v - threshold, 0.0)
|
| 138 |
-
|
| 139 |
-
def gen_simplex(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
    """Generate ``n`` simplex-projection samples.

    ``A`` is an all-ones signal, ``B`` a random belief vector drawn from
    [-1, 3), and ``C`` the projection of ``B`` onto the probability simplex.
    """

    def _sample() -> Dict:
        prior = np.ones(dim)
        belief = rng.uniform(-1.0, 3.0, dim)
        return {
            'A': prior.tolist(),
            'B': belief.tolist(),
            'C': _proj_simplex(belief).tolist(),
            'type': 'simplex',
        }

    return [_sample() for _ in range(n)]
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
# ── PROBLEM TYPE 5: ELASTIC BALANCE ───────────────────────────────────────────
|
| 150 |
-
#
|
| 151 |
-
# Constraint A : encodes soft constraint center + per-dimension tightness weight w ∈ [0,1]
|
| 152 |
-
# A[:D/2] = constraint centers, A[D/2:] = tightness weights
|
| 153 |
-
# Objective B : desired goal point
|
| 154 |
-
# Solution C : per-dimension elastic balance
|
| 155 |
-
# C[j] = w[j] · a_center[j] + (1 − w[j]) · B[j]
|
| 156 |
-
#
|
| 157 |
-
# Meaning: "each dimension is pulled between constraint center and objective,
|
| 158 |
-
# with w[j] controlling how hard the constraint is in that dimension"
|
| 159 |
-
# This is the natural problem for the elastic mesh.
|
| 160 |
-
|
| 161 |
-
def gen_elastic(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
    """Generate ``n`` elastic-balance samples.

    Every dimension of ``C`` is a convex blend of a constraint center and
    the target ``B``, weighted by a per-dimension tightness in [0.05, 0.95).
    ``A`` carries the first ``dim // 2`` centers followed by the first
    ``dim // 2`` weights, packed to length ``dim``.
    """
    half_dim = dim // 2
    samples = []
    for _ in range(n):
        # Draw order (centers, weights, target) determines the random stream.
        anchor = rng.uniform(-2, 2, dim)
        tightness = rng.uniform(0.05, 0.95, dim)
        target = rng.uniform(-3, 3, dim)

        blended = tightness * anchor + (1.0 - tightness) * target
        encoded = pack(anchor[:half_dim], tightness[:half_dim], dim=dim)
        samples.append(dict(
            A=encoded.tolist(),
            B=target.tolist(),
            C=blended.tolist(),
            type='elastic',
        ))
    return samples
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
# ── ASSEMBLY ──────────────────────────────────────────────────────────────────
|
| 174 |
-
|
| 175 |
-
# Problem-family registry.  Insertion order is the generation order, which
# in turn determines how the shared random stream is consumed.
GENERATORS = dict(
    box_proj=gen_box,
    halfspace=gen_halfspace,
    sphere=gen_sphere,
    simplex=gen_simplex,
    elastic=gen_elastic,
)


def generate_all(n_per_type: int = SAMPLES_PER_TYPE,
                 dim: int = DIM,
                 seed: int = 42) -> List[Dict]:
    """Build a shuffled dataset with ``n_per_type`` samples of every family.

    A single generator seeded with ``seed`` drives all families and the
    final shuffle, so the output is reproducible for a given argument set.
    """
    rng = np.random.default_rng(seed)
    pooled: List[Dict] = []
    for make_samples in GENERATORS.values():
        pooled += make_samples(n_per_type, dim, rng)
    order = rng.permutation(len(pooled))
    return [pooled[k] for k in order]
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
# ── MAIN ──────────────────────────────────────────────────────────────────────
|
| 195 |
-
|
| 196 |
-
if __name__ == '__main__':
    # Command-line entry point: generate the dataset, write a 90/10
    # train/test split as JSON, and print per-type summary statistics.
    from collections import Counter

    parser = argparse.ArgumentParser(description='Generate elastic mesh training data')
    parser.add_argument('--dim', type=int, default=DIM, help='embedding dimension')
    parser.add_argument('--n', type=int, default=SAMPLES_PER_TYPE, help='samples per problem type')
    parser.add_argument('--out', type=str, default='data', help='output directory')
    args = parser.parse_args()

    print(f"\n{'─'*50}")
    # Derive the total from the registry instead of hard-coding the family count.
    print(f" Generating {len(GENERATORS) * args.n} samples | dim={args.dim}")
    print(f"{'─'*50}")

    data = generate_all(args.n, args.dim)
    split = int(len(data) * 0.9)
    train, test = data[:split], data[split:]

    out = pathlib.Path(args.out)
    # parents=True so a nested --out path (e.g. "runs/01/data") works too.
    out.mkdir(parents=True, exist_ok=True)
    with open(out / 'train.json', 'w') as f:
        json.dump(train, f)
    with open(out / 'test.json', 'w') as f:
        json.dump(test, f)

    # Per-type statistics
    train_types = Counter(d['type'] for d in train)
    test_types = Counter(d['type'] for d in test)

    print(f"\n Train : {len(train)}")
    print(f" Test : {len(test)}\n")
    print(f" {'Type':<14} {'Train':>8} {'Test':>7} C-norm (mean)")
    print(f" {'─'*14} {'─'*8} {'─'*7} {'─'*14}")
    for t in GENERATORS:
        norms = [np.linalg.norm(d['C']) for d in data if d['type'] == t]
        print(f" {t:<14} {train_types[t]:>8} {test_types[t]:>7} "
              f"{np.mean(norms):.3f} ± {np.std(norms):.3f}")

    # Print vector norms for the first sample of each type.
    print(f"\n Sanity check (first sample per type):")
    seen = set()
    for d in data:
        if d['type'] in seen:
            continue
        seen.add(d['type'])
        A, B, C = map(np.array, [d['A'], d['B'], d['C']])
        err = np.linalg.norm(A - B)
        print(f" [{d['type']:<12}] "
              f"‖A‖={np.linalg.norm(A):.2f} ‖B‖={np.linalg.norm(B):.2f} "
              f"‖C‖={np.linalg.norm(C):.2f} ‖A-B‖={err:.2f}")

    print(f"\n Saved → {out}/train.json {out}/test.json\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
practicality_core.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import sympy as sp
|
| 4 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 5 |
+
from typing import Dict, List, Tuple, Set, Optional, Callable
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from functools import reduce
|
| 8 |
+
from collections import defaultdict, deque
|
| 9 |
+
|
| 10 |
+
USE_GPU = torch.cuda.is_available()
|
| 11 |
+
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
|
| 12 |
+
SOLVE_THRESHOLD = 0.001
|
| 13 |
+
LOG_SPACE_THRESHOLD = 1000.0
|
| 14 |
+
|
| 15 |
+
def safe_round(val, ndigits=8):
    """Round ``val`` to ``ndigits`` decimals when it is a finite number.

    Non-finite values (inf, nan) and anything that cannot be tested for
    finiteness or rounded (strings, None, ...) are returned unchanged.
    """
    try:
        return round(val, ndigits) if math.isfinite(val) else val
    except Exception:
        # Best effort: unroundable input is passed through.  Catching
        # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit alive.
        return val
|
| 19 |
+
|
| 20 |
+
def _c15(v):
    """Clamp ``v`` to the range [-1e15, 1e15] and return it as a float.

    +inf maps to 1e15; -inf and NaN map to -1e15 (NaN fails the ``v > 0``
    test).  Input that cannot be interpreted as a number yields 0.0.
    """
    try:
        if not math.isfinite(v):
            return 1e15 if v > 0 else -1e15
        return max(-1e15, min(1e15, float(v)))
    except Exception:
        # Best effort: non-numeric input falls back to 0.0.  Catching
        # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit alive.
        return 0.0
|
| 25 |
+
|
| 26 |
+
def safe_log(x):
    """Natural log with the argument clamped to at least 1e-7.

    Non-tensor input is converted to a float32 scalar tensor on ``DEVICE``
    first; tensor input is used as given.
    """
    if isinstance(x, torch.Tensor):
        value = x
    else:
        value = torch.tensor(float(x), device=DEVICE, dtype=torch.float32)
    return value.clamp(min=1e-7).log()
|
| 29 |
+
|
| 30 |
+
def safe_sqrt(x):
    """Square root with negative arguments clamped to zero.

    Non-tensor input is converted to a float32 scalar tensor on ``DEVICE``
    first; tensor input is used as given.
    """
    if isinstance(x, torch.Tensor):
        value = x
    else:
        value = torch.tensor(float(x), device=DEVICE, dtype=torch.float32)
    return value.clamp(min=0.0).sqrt()
|
| 33 |
+
|
| 34 |
+
class IV:
    """Closed real interval ``[lo, hi]`` with basic interval arithmetic.

    Operands may be plain ``int``/``float`` scalars or other ``IV``
    instances.  Divisions whose divisor contains zero return the sentinel
    "unbounded" interval ``[-1e18, 1e18]``.
    """

    __slots__ = ("lo", "hi")

    def __init__(self, lo, hi):
        self.lo = float(lo)
        self.hi = float(hi)

    def __repr__(self):
        return f"IV({self.lo!r}, {self.hi!r})"

    def __add__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo + o, self.hi + o)
        return IV(self.lo + o.lo, self.hi + o.hi)

    __radd__ = __add__

    def __sub__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo - o, self.hi - o)
        return IV(self.lo - o.hi, self.hi - o.lo)

    def __rsub__(self, o):
        if isinstance(o, (int, float)):
            return IV(o - self.hi, o - self.lo)
        return o.__sub__(self)

    def __mul__(self, o):
        if isinstance(o, (int, float)):
            a, b = self.lo * o, self.hi * o
            return IV(min(a, b), max(a, b))
        p = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return IV(min(p), max(p))

    __rmul__ = __mul__

    def __truediv__(self, o):
        if isinstance(o, (int, float)):
            if abs(o) < 1e-15:
                return IV(-1e18, 1e18)
            a, b = self.lo / o, self.hi / o
            return IV(min(a, b), max(a, b))
        if o.lo <= 0 <= o.hi:
            return IV(-1e18, 1e18)
        return self * IV(1.0 / o.hi, 1.0 / o.lo)

    def __rtruediv__(self, o):
        # scalar / interval
        return IV(o, o) / self

    def __neg__(self):
        return IV(-self.hi, -self.lo)

    def __pow__(self, n):
        """Raise the interval to the scalar power ``n``.

        Integral exponents (given as ``int`` or as an integral ``float``,
        which is what ``compile_iv`` passes) use exact even/odd rules.
        Negative exponents are evaluated as ``1 / self**(-n)``, so a base
        that may be zero gives the unbounded interval instead of an
        inverted range or a ZeroDivisionError.
        """
        if isinstance(n, float) and n.is_integer():
            n = int(n)
        if n == 0:
            return IV(1.0, 1.0)
        if n < 0:
            return IV(1.0, 1.0) / (self ** (-n))
        if isinstance(n, int):
            if n % 2 == 0:
                if self.lo >= 0:
                    return IV(self.lo ** n, self.hi ** n)
                if self.hi <= 0:
                    return IV(self.hi ** n, self.lo ** n)
                return IV(0.0, max(abs(self.lo) ** n, abs(self.hi) ** n))
            # Odd powers are monotone increasing, so endpoints map to endpoints.
            return IV(self.lo ** n, self.hi ** n)
        # Fractional positive exponent: partly negative bases get a
        # [0, max] over-approximation; otherwise the map is monotone.
        if self.lo < 0:
            return IV(0.0, max(abs(self.lo) ** n, self.hi ** n))
        return IV(self.lo ** n, self.hi ** n)

    def contains_zero(self):
        return self.lo <= 0.0 <= self.hi

    def width(self):
        return max(0.0, self.hi - self.lo)

    def mid(self):
        return (self.lo + self.hi) * 0.5
|
| 66 |
+
|
| 67 |
+
def compile_iv(expr, variables):
    """Compile a SymPy expression into an interval evaluator.

    Returns a callable ``f(box)`` where ``box`` maps variable names to
    ``IV`` intervals and the result is an ``IV`` for ``expr``.  Numbers,
    symbols, sums, products and powers are supported; any other node
    evaluates to the unbounded interval ``[-1e18, 1e18]``.  A symbol
    missing from ``box`` is treated as unbounded.  ``variables`` is
    accepted but not used.
    """

    def build(node):
        if node.is_Number:
            const = float(node)
            return lambda box: IV(const, const)

        if node.is_Symbol:
            key = str(node)
            return lambda box: box.get(key, IV(-1e18, 1e18))

        if node.is_Add:
            terms = [build(arg) for arg in node.args]

            def eval_sum(box):
                return reduce(lambda acc, nxt: acc + nxt, (t(box) for t in terms))

            return eval_sum

        if node.is_Mul:
            factors = [build(arg) for arg in node.args]

            def eval_product(box):
                return reduce(lambda acc, nxt: acc * nxt, (f(box) for f in factors))

            return eval_product

        if node.is_Pow:
            base_fn = build(node.args[0])
            exponent = node.args[1]
            if exponent.is_Number:
                power = float(exponent)
                return lambda box: base_fn(box) ** power
            # Symbolic exponent: use the midpoint of its interval as the power.
            exponent_fn = build(exponent)
            return lambda box: base_fn(box) ** exponent_fn(box).mid()

        # Unsupported node type: no bound information available.
        return lambda box: IV(-1e18, 1e18)

    return build(expr)
|
| 79 |
+
|
| 80 |
+
def _hc4(box, constraints):
    """Interval feasibility filter for ``box`` against ``constraints``.

    Each active constraint's ``fast_iv`` is evaluated on the box.  The box
    is rejected when an equality's interval excludes zero, when a "geq"
    inequality's interval lies entirely below -1e-10, when a "leq"
    inequality's interval lies entirely above 1e-10, or when no branch of
    an "or_eq" constraint can be zero.

    Constraints with ``weight == 0.0`` or without a ``fast_iv`` are
    skipped, and a constraint whose evaluation raises is treated as
    "cannot rule out".

    Returns:
        A shallow copy of ``box`` when nothing proves it infeasible
        (intervals are not narrowed), otherwise ``None``.
    """
    cur = dict(box)
    for mc in constraints:
        if getattr(mc, 'weight', 1.0) == 0.0:
            continue

        if mc.kind == "or_eq":
            # Feasible as soon as one branch may hold; a branch that cannot
            # be evaluated counts as possibly holding.
            for bmc in mc.branches:
                if bmc.fast_iv is None:
                    break
                try:
                    if bmc.fast_iv(cur).contains_zero():
                        break
                except Exception:
                    break
            else:
                return None
            continue

        if mc.fast_iv is None:
            continue
        try:
            riv = mc.fast_iv(cur)
            if mc.kind == "equality":
                infeasible = not riv.contains_zero()
            elif mc.kind == "inequality":
                infeasible = ((mc.direction == "geq" and riv.hi < -1e-10) or
                              (mc.direction == "leq" and riv.lo > 1e-10))
            else:
                infeasible = False
        except Exception:
            # Best effort: an evaluation failure must not prune the box.
            continue
        if infeasible:
            return None
    return cur
|
| 101 |
+
|
| 102 |
+
@dataclass
class MathConstraint:
    """One compiled constraint plus the artefacts derived from it."""

    # Constraint category; the code in this file uses "equality",
    # "inequality" and "or_eq".
    kind: str
    # Source expression text (with "^" already rewritten to "**").
    expr_str: str
    # Sense of the constraint; "geq", "leq" and "eq" are used in this file.
    direction: str
    # Penalty multiplier; a weight of 0.0 disables the constraint.
    weight: float = 1.0
    # Interval evaluator: box (name -> IV) -> IV of the residual.
    fast_iv: Optional[Callable] = field(default=None, repr=False)
    # Tensor evaluator of the residual, one positional arg per entry of syms_used.
    torch_func: Optional[Callable] = field(default=None, repr=False)
    # Names of the problem variables that appear in the expression.
    syms_used: List[str] = field(default_factory=list)
    # Parsed SymPy residual expression, when parsing succeeded.
    parsed: Optional[sp.Expr] = field(default=None, repr=False)
    scope: str = "root"
    # Alternative equality branches (populated for "or_eq" constraints).
    branches: List['MathConstraint'] = field(default_factory=list)
    # Variable name -> list of {"syms": [...], "func": callable} closed-form solutions.
    projections: Dict[str, List[Dict]] = field(default_factory=dict)
|
| 111 |
+
|
| 112 |
+
PROJECTION_CACHE = {}
|
| 113 |
+
|
| 114 |
+
def compile_mc(kind, expr_str, direction, variables, weight=1.0, scope="root", branches=None):
    """Parse and compile one constraint into a ``MathConstraint``.

    Args:
        kind: "equality", "inequality" or "or_eq".
        expr_str: Expression text; "^" is rewritten to "**".  A relational
            such as ``x <= 5`` is converted to the residual ``lhs - rhs``.
        direction: Constraint sense, stored on the result unchanged.
        variables: Names of the problem variables.  Any other symbol in
            the expression is substituted with 1.0.
        weight: Penalty multiplier stored on the result.
        scope: Scope label stored on the result.
        branches: For "or_eq", the expression strings of the alternative
            equalities; each is compiled recursively.

    Returns:
        The compiled constraint.  Compilation is best effort: if parsing or
        code generation fails, the constraint is returned with whatever
        fields were filled in before the failure (possibly none).
    """
    expr_str = expr_str.replace("^", "**")
    mc = MathConstraint(kind=kind, expr_str=expr_str, direction=direction, weight=weight, scope=scope)

    if kind == "or_eq" and branches:
        for b_str in branches:
            b_mc = compile_mc("equality", b_str, "eq", variables, weight, scope)
            mc.branches.append(b_mc)
            mc.syms_used.extend(b_mc.syms_used)
        # De-duplicate while keeping first-seen order.
        mc.syms_used = list(dict.fromkeys(mc.syms_used))

        def _or_iv(box, _mcs=mc.branches):
            # Hull of all branch intervals; unbounded when none can be evaluated.
            rivs = []
            for b in _mcs:
                if b.fast_iv:
                    try:
                        rivs.append(b.fast_iv(box))
                    except Exception:
                        pass
            if not rivs:
                return IV(-1e18, 1e18)
            return IV(min(r.lo for r in rivs), max(r.hi for r in rivs))

        mc.fast_iv = _or_iv
        return mc

    syms = {v: sp.Symbol(v) for v in variables}
    try:
        parsed = parse_expr(expr_str, local_dict=syms) if kind != "or_eq" else None
        # Compare against None explicitly: truth-testing an unevaluated SymPy
        # relational raises TypeError, which would skip the conversion below.
        if parsed is not None:
            if getattr(parsed, 'is_Equality', False) or getattr(parsed, 'is_Relational', False):
                parsed = parsed.lhs - parsed.rhs
            for s in list(parsed.free_symbols):
                if str(s) not in variables:
                    parsed = parsed.subs(s, 1.0)
            mc.parsed = parsed
            mc.syms_used = [v for v in variables if sp.Symbol(v) in parsed.free_symbols]
            mc.fast_iv = compile_iv(parsed, variables)

            pt_map = {'sin': torch.sin, 'cos': torch.cos, 'tan': torch.tan, 'exp': torch.exp,
                      'log': safe_log, 'sqrt': safe_sqrt, 'Abs': torch.abs, 'pi': math.pi, 'E': math.e}
            t_func_raw = sp.lambdify([sp.Symbol(v) for v in mc.syms_used], parsed, modules=[pt_map, "math"])

            def _t_wrapper(*args):
                # Evaluation failures and non-finite results become a large finite penalty.
                try:
                    val = t_func_raw(*args)
                    if not isinstance(val, torch.Tensor):
                        val = torch.tensor(float(val), device=DEVICE, dtype=torch.float32)
                except Exception:
                    val = torch.tensor(1e6, device=DEVICE, dtype=torch.float32)
                return torch.nan_to_num(val, posinf=1e6, neginf=-1e6, nan=1e6)

            mc.torch_func = _t_wrapper

            if kind == "equality":
                # Closed-form solutions per variable, cached by expression text.
                if expr_str not in PROJECTION_CACHE:
                    pm = {}
                    for sym in parsed.free_symbols:
                        v_str = str(sym)
                        try:
                            sols = sp.solve(parsed, sym)
                            pm[v_str] = []
                            for sol in sols:
                                fs = list(sol.free_symbols)
                                pm[v_str].append({"syms": [str(s) for s in fs],
                                                  "func": sp.lambdify(fs, sol, modules="math")})
                        except Exception:
                            # Variable not solvable in closed form; skip it.
                            pass
                    PROJECTION_CACHE[expr_str] = pm
                mc.projections = PROJECTION_CACHE.get(expr_str, {})
    except Exception:
        # Best effort: an uncompilable constraint is returned partially filled.
        pass
    return mc
|
| 178 |
+
|
| 179 |
+
@dataclass
class Problem:
    """A constraint problem: variables, their bounds, and compiled constraints.

    Derived in ``__post_init__``:
        var_idx: variable name -> column index in tensors passed to
            ``tensor_energy``.
        adjacency_list: variable -> set of variables that share at least
            one constraint with it.
        log_space_vars: non-integer variables with finite positive bounds
            whose ``hi / lo`` ratio exceeds ``LOG_SPACE_THRESHOLD``.  Any
            value passed to the constructor is overwritten.
    """

    pid: str
    variables: List[str]
    bounds: Dict[str, Tuple[float, float]]
    compiled_constraints: List["MathConstraint"]
    int_vars: Set[str] = field(default_factory=set)
    minimize_var: str = ""
    log_space_vars: Set[str] = field(default_factory=set)

    def __post_init__(self):
        self.var_idx = {v: i for i, v in enumerate(self.variables)}

        # Link every pair of distinct variables that co-occur in a constraint.
        self.adjacency_list = defaultdict(set)
        for mc in self.compiled_constraints:
            for v1 in mc.syms_used:
                for v2 in mc.syms_used:
                    if v1 != v2:
                        self.adjacency_list[v1].add(v2)

        self.log_space_vars = set()
        for v in self.variables:
            if v in self.int_vars:
                continue
            lo, hi = self.bounds.get(v, (0, 1))
            if lo > 0 and hi > 0 and math.isfinite(lo) and math.isfinite(hi):
                if hi / lo > LOG_SPACE_THRESHOLD:
                    self.log_space_vars.add(v)

    def get_markov_blanket(self, pinned_vars: Set[str], depth: int = 2) -> Set[str]:
        """Return the pinned variables plus everything within ``depth`` hops.

        Hops follow ``adjacency_list`` (breadth-first).  With no pinned
        variables, all problem variables are returned.
        """
        if not pinned_vars:
            return set(self.variables)
        visited = set(pinned_vars)
        queue = deque([(v, 0) for v in pinned_vars])
        while queue:
            curr, d = queue.popleft()
            if d < depth:
                for neighbor in self.adjacency_list.get(curr, []):
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, d + 1))
        return visited

    def tensor_energy(self, X: torch.Tensor, step_ratio: float = 1.0, is_optimizing: bool = False) -> torch.Tensor:
        """Penalty energy of candidate point(s) ``X``.

        Args:
            X: 2-D tensor with one candidate per row, or a 1-D tensor for a
                single candidate; columns follow ``self.variables``.
            step_ratio: Schedule value.  Below 1.0 it down-weights
                constraints whose text contains "sin", "cos" or "exp" and
                widens the tolerated bound margin.
            is_optimizing: When true and ``minimize_var`` is set, adds a
                small term proportional to that variable's normalised value.

        Returns:
            1-D tensor with one energy value per candidate.
        """
        is_batched = (X.dim() == 2)
        batch_size = X.shape[0] if is_batched else 1
        total = torch.zeros(batch_size, device=DEVICE, dtype=torch.float32)

        for mc in self.compiled_constraints:
            if getattr(mc, 'weight', 1.0) == 0.0:
                continue
            eff_weight = float(mc.weight)
            if step_ratio < 1.0 and any(f in mc.expr_str for f in ["sin", "cos", "exp"]):
                eff_weight *= (0.1 + 0.9 * step_ratio)

            if mc.kind == "or_eq":
                # Penalise the smallest absolute residual among the branches.
                b_vals = []
                for bmc in mc.branches:
                    if bmc.torch_func:
                        args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in bmc.syms_used]
                        b_vals.append(torch.abs(bmc.torch_func(*args)))
                if b_vals:
                    total += (torch.stack(b_vals, dim=0).min(dim=0)[0] ** 2) * eff_weight
            else:
                if mc.torch_func is None:
                    continue
                args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in mc.syms_used]
                val = mc.torch_func(*args)
                if mc.kind == "equality":
                    total += (val ** 2) * eff_weight
                elif mc.direction == "geq":
                    total += (torch.relu(-val) ** 2) * eff_weight
                else:
                    total += (torch.relu(val) ** 2) * eff_weight

        for i, v in enumerate(self.variables):
            bound = self.bounds.get(v)
            if bound is None:
                # No bounds recorded for this variable: nothing to penalise.
                # (__post_init__ and scalar_energy also tolerate missing bounds.)
                continue
            lo, hi = _c15(bound[0]), _c15(bound[1])
            col = X[:, i] if is_batched else X[i]
            margin = (hi - lo) * 0.1 * (1.0 - step_ratio)
            out_of_bounds = torch.relu(lo - margin - col) + torch.relu(col - (hi + margin))
            total += (out_of_bounds ** 2) * 10.0

        if is_optimizing and self.minimize_var and self.minimize_var in self.var_idx:
            midx = self.var_idx[self.minimize_var]
            lo, hi = _c15(self.bounds[self.minimize_var][0]), _c15(self.bounds[self.minimize_var][1])
            rng = max(hi - lo, 1e-8)
            col = X[:, midx] if is_batched else X[midx]
            normalized = (col - lo) / rng
            total += normalized * 0.05 * step_ratio

        return total.view(batch_size, -1).sum(dim=1)

    def scalar_energy(self, b: Dict[str, float]) -> float:
        """Energy of one assignment given as a name -> value mapping.

        Variables missing from ``b`` take the midpoint of their bounds
        (bounds default to (-1, 1) when absent).  Evaluated with
        ``step_ratio=1.0`` and without the minimisation term.
        """
        x_arr = [b.get(v, (_c15(self.bounds.get(v, (-1, 1))[0]) + _c15(self.bounds.get(v, (-1, 1))[1])) / 2)
                 for v in self.variables]
        X_t = torch.tensor(x_arr, device=DEVICE, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return float(self.tensor_energy(X_t, step_ratio=1.0, is_optimizing=False).item())
|
| 264 |
+
|
| 265 |
+
def algebraic_propagate_pinned(problem: "Problem", pinned_vars: Dict[str, float], timeout_secs: float = 2.0) -> Tuple[Dict[str, float], List[str]]:
    """Resolve further variables algebraically from the pinned values.

    Known values are substituted into each parsed equality constraint.
    When exactly one unresolved problem variable remains, the equation is
    solved for it.  Among the real, finite solutions within the variable's
    bounds (with a slack of 1.0 on each side) the one closest to the
    bound midpoint is adopted.  Passes repeat until nothing changes, the
    pass limit (number of variables + 1) is reached, or the time budget
    runs out.

    Args:
        problem: Object exposing ``variables``, ``bounds`` and
            ``compiled_constraints``.
        pinned_vars: Variable name -> fixed value.  Not modified.
        timeout_secs: Wall-clock budget in seconds, checked before each
            constraint is processed (a single symbolic solve is not
            interrupted).  ``None`` disables the limit.

    Returns:
        ``(resolved, log)``: the pinned values plus every value derived,
        and human-readable log lines describing the derivations.
    """
    import time  # local import: only this function needs it

    deadline = None if timeout_secs is None else time.monotonic() + timeout_secs
    resolved = dict(pinned_vars)
    log: List[str] = []
    max_passes = len(problem.variables) + 1
    passes = 0
    changed = True
    timed_out = False

    while changed and not timed_out and passes < max_passes:
        changed = False
        passes += 1
        for mc in problem.compiled_constraints:
            if mc.kind != "equality" or mc.parsed is None:
                continue
            if deadline is not None and time.monotonic() > deadline:
                timed_out = True
                break

            expr = mc.parsed
            for name, value in resolved.items():
                try:
                    expr = expr.subs(sp.Symbol(name), sp.Float(value))
                except Exception:
                    pass
            try:
                expr = sp.simplify(expr)
            except Exception:
                pass  # fall back to the unsimplified expression

            free = [str(s) for s in expr.free_symbols
                    if str(s) in problem.variables and str(s) not in resolved]
            if len(free) != 1:
                continue
            target = free[0]

            try:
                solutions = sp.solve(expr, sp.Symbol(target))
                if not solutions:
                    continue
                lo, hi = problem.bounds.get(target, (-1e9, 1e9))
                mid = (lo + hi) / 2.0
                valid_sols = []
                for sol in solutions:
                    try:
                        val = complex(sol.evalf())
                        if abs(val.imag) < 1e-8:
                            rval = val.real
                            if lo - 1.0 <= rval <= hi + 1.0 and math.isfinite(rval):
                                valid_sols.append(rval)
                    except Exception:
                        pass  # non-numeric solution: ignore it
                if valid_sols:
                    best = min(valid_sols, key=lambda v: abs(v - mid))
                    resolved[target] = best
                    log.append(f" PROP [{target}] = {best:.6g} <- [{mc.expr_str[:50]}]")
                    changed = True
            except Exception:
                # Best effort: an unsolvable constraint is skipped.
                pass

    if timed_out:
        log.append(f" PROP timeout after {timeout_secs:g}s - returning partial result")
    n_new = len(resolved) - len(pinned_vars)
    if n_new > 0:
        log.insert(0, f"ALGEBRAIC PROPAGATOR: resolved {n_new} vars in {passes} pass(es)")
    return resolved, log
|