everydaytok commited on
Commit
78bf05b
·
verified ·
1 Parent(s): b556146

Rename data_gen.py to practicality_core.py

Browse files
Files changed (2) hide show
  1. data_gen.py +0 -243
  2. practicality_core.py +307 -0
data_gen.py DELETED
@@ -1,243 +0,0 @@
1
- """
2
- data_gen.py — Training / test data for the elastic mesh.
3
-
4
- Each sample is a triple (A, B, C) where:
5
- A ∈ ℝ^DIM encodes constraints ("what must be true")
6
- B ∈ ℝ^DIM encodes objectives ("what we want")
7
- C ∈ ℝ^DIM is the analytic solution — the feasibility center the mesh must learn to produce
8
-
9
- Five problem families, each with a geometrically distinct C:
10
-
11
- 1. box_proj — clamp B into axis-aligned box defined by A
12
- 2. halfspace — project B onto hyperplane defined by A
13
- 3. sphere — project B onto sphere surface defined by A
14
- 4. simplex — project B onto probability simplex (A = uniform prior signal)
15
- 5. elastic_bal — per-dimension weighted balance between A-center and B
16
-
17
- These cover:
18
- - Bounded feasibility (box)
19
- - Equality constraints (halfspace)
20
- - Norm constraints (sphere)
21
- - Probability/sum=1 (simplex)
22
- - Soft trade-offs (elastic)
23
-
24
- The mesh sees ONLY (A, B) during inference; C is what it must reconstruct.
25
- """
26
-
27
- import numpy as np
28
- import json, pathlib, argparse
29
- from typing import List, Dict
30
-
31
- DIM = 32 # embedding dimension (set to 768 for LLM-scale)
32
- SAMPLES_PER_TYPE = 1000 # × 5 types = 5 000 total
33
-
34
-
35
- # ── UTILITIES ─────────────────────────────────────────────────────────────────
36
-
37
- def normalize(v: np.ndarray) -> np.ndarray:
38
- n = np.linalg.norm(v)
39
- return v / (n + 1e-12)
40
-
41
- def pack(*arrays: np.ndarray, dim: int) -> np.ndarray:
42
- """Concatenate + trim/pad to `dim`."""
43
- v = np.concatenate(arrays)
44
- if len(v) >= dim:
45
- return v[:dim]
46
- return np.pad(v, (0, dim - len(v)))
47
-
48
-
49
- # ── PROBLEM TYPE 1: BOX PROJECTION ────────────────────────────────────────────
50
- #
51
- # Constraint A : encodes per-dimension box [lo, hi]
52
- # A[:D/2] = lo[:D/2], A[D/2:] = hi[:D/2]
53
- # Objective B : unconstrained target point in ℝ^D
54
- # Solution C : clip(B, lo, hi) — nearest point in box to B
55
- #
56
- # Meaning: "stay within resource/capacity bounds while aiming for B"
57
-
58
- def gen_box(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
59
- data = []
60
- for _ in range(n):
61
- center = rng.uniform(-2, 2, dim)
62
- half = rng.uniform(0.3, 2.0, dim)
63
- lo, hi = center - half, center + half
64
- B = rng.uniform(-4, 4, dim)
65
- C = np.clip(B, lo, hi)
66
- A = pack(lo[:dim//2], hi[:dim//2], dim=dim)
67
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'box_proj'})
68
- return data
69
-
70
-
71
- # ── PROBLEM TYPE 2: HALFSPACE PROJECTION ──────────────────────────────────────
72
- #
73
- # Constraint A : encodes a hyperplane nᵀx = b
74
- # A = normal vector, A[0] carries the offset b
75
- # Objective B : unconstrained point in ℝ^D
76
- # Solution C : projection of B onto the hyperplane
77
- # C = B − (nᵀB − b) · n
78
- #
79
- # Meaning: "satisfy one hard equality constraint at minimum cost to B"
80
-
81
- def gen_halfspace(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
82
- data = []
83
- for _ in range(n):
84
- normal = normalize(rng.standard_normal(dim))
85
- b = float(rng.uniform(-1, 1))
86
- B = rng.uniform(-3, 3, dim)
87
- C = B - (float(np.dot(normal, B)) - b) * normal
88
- A = normal.copy()
89
- A[0] = b # offset embedded in first slot
90
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'halfspace'})
91
- return data
92
-
93
-
94
- # ── PROBLEM TYPE 3: SPHERE SURFACE ────────────────────────────────────────────
95
- #
96
- # Constraint A : encodes a sphere (center, radius)
97
- # A = center vector, A[0] overwritten with radius r
98
- # Objective B : external point
99
- # Solution C : point on sphere surface nearest to B
100
- # C = center + r · (B − center) / ‖B − center‖
101
- #
102
- # Meaning: "satisfy a norm/budget constraint, move toward B as far as allowed"
103
-
104
- def gen_sphere(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
105
- data = []
106
- for _ in range(n):
107
- center = rng.uniform(-1.5, 1.5, dim)
108
- r = float(rng.uniform(1.0, 3.0))
109
- B = rng.uniform(-4, 4, dim)
110
- diff = B - center
111
- nd = np.linalg.norm(diff)
112
- if nd < 1e-10:
113
- diff = np.ones(dim) / np.sqrt(dim)
114
- nd = 1.0
115
- C = center + r * diff / nd
116
- A = center.copy()
117
- A[0] = r # radius in first slot
118
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'sphere'})
119
- return data
120
-
121
-
122
- # ── PROBLEM TYPE 4: SIMPLEX PROJECTION ────────────────────────────────────────
123
- #
124
- # Constraint A : uniform-prior signal (all ones) → encodes simplex constraint Σxᵢ=1, xᵢ≥0
125
- # Objective B : unconstrained "belief" vector
126
- # Solution C : nearest point on probability simplex to B
127
- #
128
- # Meaning: "find a valid probability distribution closest to unconstrained belief B"
129
- # Useful for softmax-like problems.
130
-
131
- def _proj_simplex(v: np.ndarray) -> np.ndarray:
132
- n = len(v)
133
- u = np.sort(v)[::-1]
134
- cs = np.cumsum(u) - 1.0
135
- rho = int(np.where(u * np.arange(1, n + 1) > cs)[0][-1])
136
- theta = cs[rho] / (rho + 1.0)
137
- return np.maximum(v - theta, 0.0)
138
-
139
- def gen_simplex(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
140
- data = []
141
- for _ in range(n):
142
- A = np.ones(dim) # simplex constraint signal
143
- B = rng.uniform(-1.0, 3.0, dim) # unconstrained belief
144
- C = _proj_simplex(B)
145
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'simplex'})
146
- return data
147
-
148
-
149
- # ── PROBLEM TYPE 5: ELASTIC BALANCE ───────────────────────────────────────────
150
- #
151
- # Constraint A : encodes soft constraint center + per-dimension tightness weight w ∈ [0,1]
152
- # A[:D/2] = constraint centers, A[D/2:] = tightness weights
153
- # Objective B : desired goal point
154
- # Solution C : per-dimension elastic balance
155
- # C[j] = w[j] · a_center[j] + (1 − w[j]) · B[j]
156
- #
157
- # Meaning: "each dimension is pulled between constraint center and objective,
158
- # with w[j] controlling how hard the constraint is in that dimension"
159
- # This is the natural problem for the elastic mesh.
160
-
161
- def gen_elastic(n: int, dim: int, rng: np.random.Generator) -> List[Dict]:
162
- data = []
163
- for _ in range(n):
164
- a_center = rng.uniform(-2, 2, dim)
165
- w = rng.uniform(0.05, 0.95, dim) # per-dim tightness
166
- B = rng.uniform(-3, 3, dim)
167
- C = w * a_center + (1.0 - w) * B
168
- A = pack(a_center[:dim//2], w[:dim//2], dim=dim)
169
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'elastic'})
170
- return data
171
-
172
-
173
- # ── ASSEMBLY ──────────────────────────────────────────────────────────────────
174
-
175
- GENERATORS = {
176
- 'box_proj': gen_box,
177
- 'halfspace': gen_halfspace,
178
- 'sphere': gen_sphere,
179
- 'simplex': gen_simplex,
180
- 'elastic': gen_elastic,
181
- }
182
-
183
- def generate_all(n_per_type: int = SAMPLES_PER_TYPE,
184
- dim: int = DIM,
185
- seed: int = 42) -> List[Dict]:
186
- rng = np.random.default_rng(seed)
187
- data = []
188
- for fn in GENERATORS.values():
189
- data.extend(fn(n_per_type, dim, rng))
190
- idx = rng.permutation(len(data))
191
- return [data[i] for i in idx]
192
-
193
-
194
- # ── MAIN ──────────────────────────────────────────────────────────────────────
195
-
196
- if __name__ == '__main__':
197
- parser = argparse.ArgumentParser(description='Generate elastic mesh training data')
198
- parser.add_argument('--dim', type=int, default=DIM, help='embedding dimension')
199
- parser.add_argument('--n', type=int, default=SAMPLES_PER_TYPE, help='samples per problem type')
200
- parser.add_argument('--out', type=str, default='data', help='output directory')
201
- args = parser.parse_args()
202
-
203
- print(f"\n{'─'*50}")
204
- print(f" Generating {5 * args.n} samples | dim={args.dim}")
205
- print(f"{'─'*50}")
206
-
207
- data = generate_all(args.n, args.dim)
208
- split = int(len(data) * 0.9)
209
- train, test = data[:split], data[split:]
210
-
211
- out = pathlib.Path(args.out)
212
- out.mkdir(exist_ok=True)
213
- with open(out / 'train.json', 'w') as f: json.dump(train, f)
214
- with open(out / 'test.json', 'w') as f: json.dump(test, f)
215
-
216
- # Per-type statistics
217
- from collections import Counter
218
- train_types = Counter(d['type'] for d in train)
219
- test_types = Counter(d['type'] for d in test)
220
-
221
- print(f"\n Train : {len(train)}")
222
- print(f" Test : {len(test)}\n")
223
- print(f" {'Type':<14} {'Train':>8} {'Test':>7} C-norm (mean)")
224
- print(f" {'─'*14} {'─'*8} {'─'*7} {'─'*14}")
225
- for t in GENERATORS:
226
- subset = [d for d in data if d['type'] == t]
227
- norms = [np.linalg.norm(d['C']) for d in subset]
228
- print(f" {t:<14} {train_types[t]:>8} {test_types[t]:>7} "
229
- f"{np.mean(norms):.3f} ± {np.std(norms):.3f}")
230
-
231
- # Sanity check one sample per type
232
- print(f"\n Sanity check (first sample per type):")
233
- seen = set()
234
- for d in data:
235
- if d['type'] in seen: continue
236
- seen.add(d['type'])
237
- A, B, C = map(np.array, [d['A'], d['B'], d['C']])
238
- err = np.linalg.norm(A - B)
239
- print(f" [{d['type']:<12}] "
240
- f"‖A‖={np.linalg.norm(A):.2f} ‖B‖={np.linalg.norm(B):.2f} "
241
- f"‖C‖={np.linalg.norm(C):.2f} ‖A-B‖={err:.2f}")
242
-
243
- print(f"\n Saved → {out}/train.json {out}/test.json\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
practicality_core.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import sympy as sp
4
+ from sympy.parsing.sympy_parser import parse_expr
5
+ from typing import Dict, List, Tuple, Set, Optional, Callable
6
+ from dataclasses import dataclass, field
7
+ from functools import reduce
8
+ from collections import defaultdict, deque
9
+
10
+ USE_GPU = torch.cuda.is_available()
11
+ DEVICE = torch.device("cuda" if USE_GPU else "cpu")
12
+ SOLVE_THRESHOLD = 0.001
13
+ LOG_SPACE_THRESHOLD = 1000.0
14
+
15
+ def safe_round(val, ndigits=8):
16
+ try:
17
+ return round(val, ndigits) if math.isfinite(val) else val
18
+ except: return val
19
+
20
+ def _c15(v):
21
+ try:
22
+ if not math.isfinite(v): return 1e15 if v > 0 else -1e15
23
+ return max(-1e15, min(1e15, float(v)))
24
+ except: return 0.0
25
+
26
+ def safe_log(x):
27
+ if not isinstance(x, torch.Tensor): x=torch.tensor(float(x),device=DEVICE,dtype=torch.float32)
28
+ return torch.log(torch.clamp(x,min=1e-7))
29
+
30
+ def safe_sqrt(x):
31
+ if not isinstance(x, torch.Tensor): x=torch.tensor(float(x),device=DEVICE,dtype=torch.float32)
32
+ return torch.sqrt(torch.clamp(x,min=0.0))
33
+
34
+ class IV:
35
+ __slots__ = ("lo", "hi")
36
+ def __init__(self, lo, hi): self.lo = float(lo); self.hi = float(hi)
37
+ def __add__(self, o): return IV(self.lo+o, self.hi+o) if isinstance(o,(int,float)) else IV(self.lo+o.lo, self.hi+o.hi)
38
+ __radd__ = __add__
39
+ def __sub__(self, o): return IV(self.lo-o, self.hi-o) if isinstance(o,(int,float)) else IV(self.lo-o.hi, self.hi-o.lo)
40
+ def __rsub__(self, o): return IV(o-self.hi, o-self.lo) if isinstance(o,(int,float)) else o.__sub__(self)
41
+ def __mul__(self, o):
42
+ if isinstance(o,(int,float)): a,b=self.lo*o,self.hi*o; return IV(min(a,b),max(a,b))
43
+ p=(self.lo*o.lo,self.lo*o.hi,self.hi*o.lo,self.hi*o.hi); return IV(min(p),max(p))
44
+ __rmul__ = __mul__
45
+ def __truediv__(self, o):
46
+ if isinstance(o,(int,float)):
47
+ if abs(o)<1e-15: return IV(-1e18,1e18)
48
+ a,b=self.lo/o,self.hi/o; return IV(min(a,b),max(a,b))
49
+ if o.lo<=0<=o.hi: return IV(-1e18,1e18)
50
+ return self*IV(1.0/o.hi,1.0/o.lo)
51
+ def __neg__(self): return IV(-self.hi,-self.lo)
52
+ def __pow__(self, n):
53
+ if isinstance(n,int):
54
+ if n==0: return IV(1.0,1.0)
55
+ if n%2==0:
56
+ if self.lo>=0: return IV(self.lo**n,self.hi**n)
57
+ if self.hi<=0: return IV(self.hi**n,self.lo**n)
58
+ return IV(0.0,max(abs(self.lo)**n,abs(self.hi)**n))
59
+ return IV(self.lo**n if self.lo>=0 else -((-self.lo)**n),
60
+ self.hi**n if self.hi>=0 else -((-self.hi)**n))
61
+ if self.lo<0: return IV(0.0,max(abs(self.lo)**n,self.hi**n))
62
+ return IV(self.lo**n,self.hi**n)
63
+ def contains_zero(self): return self.lo<=0.0<=self.hi
64
+ def width(self): return max(0.0,self.hi-self.lo)
65
+ def mid(self): return (self.lo+self.hi)*0.5
66
+
67
+ def compile_iv(expr, variables):
68
+ def _c(e):
69
+ if e.is_Number: v=float(e); return lambda box,_v=v: IV(_v,_v)
70
+ if e.is_Symbol: n=str(e); return lambda box,_n=n: box.get(_n,IV(-1e18,1e18))
71
+ if e.is_Add: fs=[_c(a) for a in e.args]; return lambda box,_fs=fs: reduce(lambda a,b:a+b,(_f(box) for _f in _fs))
72
+ if e.is_Mul: fs=[_c(a) for a in e.args]; return lambda box,_fs=fs: reduce(lambda a,b:a*b,(_f(box) for _f in _fs))
73
+ if e.is_Pow:
74
+ bc=_c(e.args[0]); ex=e.args[1]
75
+ if ex.is_Number: return lambda box,_bc=bc,_ex=float(ex): _bc(box)**_ex
76
+ exc=_c(ex); return lambda box,_bc=bc,_exc=exc: _bc(box)**_exc(box).mid()
77
+ return lambda box: IV(-1e18,1e18)
78
+ return _c(expr)
79
+
80
+ def _hc4(box, constraints):
81
+ cur = dict(box)
82
+ for mc in constraints:
83
+ if getattr(mc,'weight',1.0)==0.0: continue
84
+ if mc.kind=="or_eq":
85
+ valid=False
86
+ for bmc in mc.branches:
87
+ if bmc.fast_iv is None: valid=True; break
88
+ try:
89
+ if bmc.fast_iv(cur).contains_zero(): valid=True; break
90
+ except: valid=True; break
91
+ if not valid: return None
92
+ else:
93
+ if mc.fast_iv is None: continue
94
+ try:
95
+ riv=mc.fast_iv(cur)
96
+ if ((mc.kind=="equality" and not riv.contains_zero()) or
97
+ (mc.kind=="inequality" and ((mc.direction=="geq" and riv.hi<-1e-10) or
98
+ (mc.direction=="leq" and riv.lo>1e-10)))): return None
99
+ except: pass
100
+ return cur
101
+
102
+ @dataclass
103
+ class MathConstraint:
104
+ kind:str; expr_str:str; direction:str; weight:float=1.0
105
+ fast_iv:Optional[Callable]=field(default=None,repr=False)
106
+ torch_func:Optional[Callable]=field(default=None,repr=False)
107
+ syms_used:List[str]=field(default_factory=list)
108
+ parsed:Optional[sp.Expr]=field(default=None,repr=False)
109
+ scope:str="root"; branches:List['MathConstraint']=field(default_factory=list)
110
+ projections:Dict[str,List[Dict]]=field(default_factory=dict)
111
+
112
+ PROJECTION_CACHE = {}
113
+
114
+ def compile_mc(kind, expr_str, direction, variables, weight=1.0, scope="root", branches=None):
115
+ expr_str = expr_str.replace("^","**")
116
+ mc = MathConstraint(kind=kind, expr_str=expr_str, direction=direction, weight=weight, scope=scope)
117
+
118
+ if kind == "or_eq" and branches:
119
+ for b_str in branches:
120
+ b_mc = compile_mc("equality", b_str, "eq", variables, weight, scope)
121
+ mc.branches.append(b_mc)
122
+ mc.syms_used.extend(b_mc.syms_used)
123
+ mc.syms_used = list(dict.fromkeys(mc.syms_used))
124
+ def _or_iv(box, _mcs=mc.branches):
125
+ rivs = []
126
+ for b in _mcs:
127
+ if b.fast_iv:
128
+ try: rivs.append(b.fast_iv(box))
129
+ except: pass
130
+ if not rivs: return IV(-1e18, 1e18)
131
+ return IV(min(r.lo for r in rivs), max(r.hi for r in rivs))
132
+ mc.fast_iv = _or_iv
133
+ return mc
134
+
135
+ syms = {v: sp.Symbol(v) for v in variables}
136
+ try:
137
+ parsed = parse_expr(expr_str, local_dict=syms) if kind != "or_eq" else None
138
+ if parsed:
139
+ if getattr(parsed,'is_Equality',False) or getattr(parsed,'is_Relational',False):
140
+ parsed = parsed.lhs - parsed.rhs
141
+ for s in list(parsed.free_symbols):
142
+ if str(s) not in variables: parsed = parsed.subs(s, 1.0)
143
+ mc.parsed = parsed
144
+ mc.syms_used = [v for v in variables if sp.Symbol(v) in parsed.free_symbols]
145
+ mc.fast_iv = compile_iv(parsed, variables)
146
+
147
+ pt_map = {'sin':torch.sin, 'cos':torch.cos, 'tan':torch.tan, 'exp':torch.exp,
148
+ 'log':safe_log, 'sqrt':safe_sqrt, 'Abs':torch.abs, 'pi':math.pi, 'E':math.e}
149
+ t_func_raw = sp.lambdify([sp.Symbol(v) for v in mc.syms_used], parsed, modules=[pt_map, "math"])
150
+
151
+ def _t_wrapper(*args):
152
+ try:
153
+ val = t_func_raw(*args)
154
+ if not isinstance(val, torch.Tensor):
155
+ val = torch.tensor(float(val), device=DEVICE, dtype=torch.float32)
156
+ except:
157
+ val = torch.tensor(1e6, device=DEVICE, dtype=torch.float32)
158
+ return torch.nan_to_num(val, posinf=1e6, neginf=-1e6, nan=1e6)
159
+
160
+ mc.torch_func = _t_wrapper
161
+
162
+ if kind == "equality":
163
+ if expr_str not in PROJECTION_CACHE:
164
+ pm = {}
165
+ for sym in parsed.free_symbols:
166
+ v_str = str(sym)
167
+ try:
168
+ sols = sp.solve(parsed, sym)
169
+ pm[v_str] = []
170
+ for sol in sols:
171
+ fs = list(sol.free_symbols)
172
+ pm[v_str].append({"syms": [str(s) for s in fs], "func": sp.lambdify(fs, sol, modules="math")})
173
+ except: pass
174
+ PROJECTION_CACHE[expr_str] = pm
175
+ mc.projections = PROJECTION_CACHE.get(expr_str, {})
176
+ except: pass
177
+ return mc
178
+
179
+ @dataclass
180
+ class Problem:
181
+ pid:str; variables:List[str]; bounds:Dict[str,Tuple[float,float]]
182
+ compiled_constraints:List[MathConstraint]
183
+ int_vars:Set[str]=field(default_factory=set)
184
+ minimize_var:str=""
185
+ log_space_vars:Set[str]=field(default_factory=set)
186
+
187
+ def __post_init__(self):
188
+ self.var_idx = {v: i for i, v in enumerate(self.variables)}
189
+ self.adjacency_list = defaultdict(set)
190
+ for mc in self.compiled_constraints:
191
+ for v1 in mc.syms_used:
192
+ for v2 in mc.syms_used:
193
+ if v1 != v2: self.adjacency_list[v1].add(v2)
194
+
195
+ self.log_space_vars = set()
196
+ for v in self.variables:
197
+ if v in self.int_vars: continue
198
+ lo, hi = self.bounds.get(v, (0, 1))
199
+ if lo > 0 and hi > 0 and math.isfinite(lo) and math.isfinite(hi):
200
+ if hi / lo > LOG_SPACE_THRESHOLD:
201
+ self.log_space_vars.add(v)
202
+
203
+ def get_markov_blanket(self, pinned_vars: Set[str], depth: int=2) -> Set[str]:
204
+ if not pinned_vars: return set(self.variables)
205
+ visited = set(pinned_vars)
206
+ queue = deque([(v, 0) for v in pinned_vars])
207
+ while queue:
208
+ curr, d = queue.popleft()
209
+ if d < depth:
210
+ for neighbor in self.adjacency_list.get(curr, []):
211
+ if neighbor not in visited:
212
+ visited.add(neighbor)
213
+ queue.append((neighbor, d+1))
214
+ return visited
215
+
216
+ def tensor_energy(self, X: torch.Tensor, step_ratio: float=1.0, is_optimizing: bool=False) -> torch.Tensor:
217
+ is_batched = (X.dim() == 2)
218
+ batch_size = X.shape[0] if is_batched else 1
219
+ total = torch.zeros(batch_size, device=DEVICE, dtype=torch.float32)
220
+
221
+ for mc in self.compiled_constraints:
222
+ if getattr(mc, 'weight', 1.0) == 0.0: continue
223
+ eff_weight = float(mc.weight)
224
+ if step_ratio < 1.0 and any(f in mc.expr_str for f in ["sin", "cos", "exp"]):
225
+ eff_weight *= (0.1 + 0.9 * step_ratio)
226
+
227
+ if mc.kind == "or_eq":
228
+ b_vals = []
229
+ for bmc in mc.branches:
230
+ if bmc.torch_func:
231
+ args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in bmc.syms_used]
232
+ b_vals.append(torch.abs(bmc.torch_func(*args)))
233
+ if b_vals: total += (torch.stack(b_vals, dim=0).min(dim=0)[0]**2) * eff_weight
234
+ else:
235
+ if mc.torch_func is None: continue
236
+ args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in mc.syms_used]
237
+ val = mc.torch_func(*args)
238
+ if mc.kind == "equality": total += (val**2) * eff_weight
239
+ elif mc.direction == "geq": total += (torch.relu(-val)**2) * eff_weight
240
+ else: total += (torch.relu(val)**2) * eff_weight
241
+
242
+ for i, v in enumerate(self.variables):
243
+ lo, hi = _c15(self.bounds[v][0]), _c15(self.bounds[v][1])
244
+ col = X[:, i] if is_batched else X[i]
245
+ margin = (hi - lo) * 0.1 * (1.0 - step_ratio)
246
+ out_of_bounds = torch.relu(lo - margin - col) + torch.relu(col - (hi + margin))
247
+ total += (out_of_bounds**2) * 10.0
248
+
249
+ if is_optimizing and self.minimize_var and self.minimize_var in self.var_idx:
250
+ midx = self.var_idx[self.minimize_var]
251
+ lo, hi = _c15(self.bounds[self.minimize_var][0]), _c15(self.bounds[self.minimize_var][1])
252
+ rng = max(hi - lo, 1e-8)
253
+ col = X[:, midx] if is_batched else X[midx]
254
+ normalized = (col - lo) / rng
255
+ total += normalized * 0.05 * step_ratio
256
+
257
+ return total.view(batch_size, -1).sum(dim=1)
258
+
259
+ def scalar_energy(self, b: Dict[str, float]) -> float:
260
+ x_arr = [b.get(v, (_c15(self.bounds.get(v,(-1,1))[0]) + _c15(self.bounds.get(v,(-1,1))[1]))/2) for v in self.variables]
261
+ X_t = torch.tensor(x_arr, device=DEVICE, dtype=torch.float32).unsqueeze(0)
262
+ with torch.no_grad():
263
+ return float(self.tensor_energy(X_t, step_ratio=1.0, is_optimizing=False).item())
264
+
265
+ def algebraic_propagate_pinned(problem: Problem, pinned_vars: Dict[str, float], timeout_secs: float=2.0) -> Tuple[Dict[str, float], List[str]]:
266
+ resolved = dict(pinned_vars)
267
+ log = []
268
+ changed = True
269
+ max_passes = len(problem.variables) + 1
270
+ passes = 0
271
+ while changed and passes < max_passes:
272
+ changed = False
273
+ passes += 1
274
+ for mc in problem.compiled_constraints:
275
+ if mc.kind != "equality" or mc.parsed is None: continue
276
+ expr = mc.parsed
277
+ for v, val in resolved.items():
278
+ try: expr = expr.subs(sp.Symbol(v), sp.Float(val))
279
+ except: pass
280
+ try: expr = sp.simplify(expr)
281
+ except: pass
282
+ free = [str(s) for s in expr.free_symbols if str(s) in problem.variables and str(s) not in resolved]
283
+ if len(free) == 1:
284
+ target_sym = sp.Symbol(free[0])
285
+ try:
286
+ solutions = sp.solve(expr, target_sym)
287
+ if not solutions: continue
288
+ lo, hi = problem.bounds.get(free[0], (-1e9, 1e9))
289
+ mid = (lo + hi) / 2.0
290
+ valid_sols = []
291
+ for sol in solutions:
292
+ try:
293
+ val = complex(sol.evalf())
294
+ if abs(val.imag) < 1e-8:
295
+ rval = val.real
296
+ if lo - 1.0 <= rval <= hi + 1.0 and math.isfinite(rval):
297
+ valid_sols.append(rval)
298
+ except: pass
299
+ if valid_sols:
300
+ best = min(valid_sols, key=lambda v: abs(v - mid))
301
+ resolved[free[0]] = best
302
+ log.append(f" PROP [{free[0]}] = {best:.6g} <- [{mc.expr_str[:50]}]")
303
+ changed = True
304
+ except: pass
305
+ n_new = len(resolved) - len(pinned_vars)
306
+ if n_new > 0: log.insert(0, f"ALGEBRAIC PROPAGATOR: resolved {n_new} vars in {passes} pass(es)")
307
+ return resolved, log