"""Constraint compilation, interval feasibility checks and penalty energies.

Constraint strings are parsed with SymPy and compiled into two evaluators: an
interval-arithmetic function (``fast_iv``) used by ``_hc4`` to reject boxes
that cannot satisfy a constraint, and a PyTorch function (``torch_func``) used
by ``Problem.tensor_energy`` to build a squared-violation penalty.
"""
import math
import operator
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, Dict, List, Optional, Set, Tuple

import sympy as sp
import torch
from sympy.parsing.sympy_parser import parse_expr

USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
SOLVE_THRESHOLD = 0.001
# Finite positive bounds with hi / lo above this ratio mark a variable as log-space.
LOG_SPACE_THRESHOLD = 1000.0

# Half-width of the "no information" interval returned when no useful bound exists.
_IV_INF = 1e18
# Slack allowed by _hc4 before a box is declared to violate a constraint; the
# interval arithmetic below uses plain float operations (no outward rounding).
_HC4_TOL = 1e-10


def safe_round(val, ndigits=8):
    """Return ``round(val, ndigits)`` for finite numbers, *val* unchanged otherwise.

    Non-finite values (inf/NaN) and inputs that ``math.isfinite``/``round``
    reject are passed through untouched.
    """
    try:
        return round(val, ndigits) if math.isfinite(val) else val
    except Exception:
        return val


def _c15(v):
    """Clamp *v* to a float in [-1e15, 1e15].

    ``+inf`` maps to ``1e15``; ``-inf`` and NaN map to ``-1e15``; values that
    cannot be tested/converted as floats map to ``0.0``.
    """
    try:
        if not math.isfinite(v):
            return 1e15 if v > 0 else -1e15
        return max(-1e15, min(1e15, float(v)))
    except Exception:
        return 0.0


def _as_tensor(x):
    """Return *x* if it is a tensor, else a float32 scalar tensor on DEVICE."""
    if isinstance(x, torch.Tensor):
        return x
    return torch.tensor(float(x), device=DEVICE, dtype=torch.float32)


def safe_log(x):
    """Natural log with the argument clamped to >= 1e-7 (no -inf/NaN for x <= 0)."""
    return torch.log(torch.clamp(_as_tensor(x), min=1e-7))


def safe_sqrt(x):
    """Square root with the argument clamped to >= 0 (negative inputs give 0)."""
    return torch.sqrt(torch.clamp(_as_tensor(x), min=0.0))


class IV:
    """Closed real interval ``[lo, hi]`` with conservative arithmetic.

    Each operation returns an interval meant to contain every value the
    operation can take for operands drawn from its inputs. Where no useful
    bound exists (division by an interval containing 0, real powers that are
    undefined, float overflow in ``**``) the wide interval [-1e18, 1e18] is
    returned so that callers never prune on it. Rounding is not directed, so
    enclosures are only exact up to float error.
    """

    __slots__ = ("lo", "hi")

    def __init__(self, lo, hi):
        self.lo = float(lo)
        self.hi = float(hi)

    def __repr__(self):
        return f"IV({self.lo!r}, {self.hi!r})"

    def __add__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo + o, self.hi + o)
        return IV(self.lo + o.lo, self.hi + o.hi)

    __radd__ = __add__

    def __sub__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo - o, self.hi - o)
        return IV(self.lo - o.hi, self.hi - o.lo)

    def __rsub__(self, o):
        if isinstance(o, (int, float)):
            return IV(o - self.hi, o - self.lo)
        return o.__sub__(self)

    def __mul__(self, o):
        if isinstance(o, (int, float)):
            a, b = self.lo * o, self.hi * o
            return IV(min(a, b), max(a, b))
        p = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return IV(min(p), max(p))

    __rmul__ = __mul__

    def __truediv__(self, o):
        if isinstance(o, (int, float)):
            if abs(o) < 1e-15:
                return _wide_iv()
            a, b = self.lo / o, self.hi / o
            return IV(min(a, b), max(a, b))
        if o.lo <= 0 <= o.hi:
            return _wide_iv()
        return self * IV(1.0 / o.hi, 1.0 / o.lo)

    def __rtruediv__(self, o):
        # scalar / interval: reuse the interval-divisor path of __truediv__.
        return IV(o, o) / self

    def __neg__(self):
        return IV(-self.hi, -self.lo)

    def __pow__(self, n):
        """Enclose ``x ** n`` for x in this interval and a numeric exponent *n*.

        Integral floats (``compile_iv`` always passes floats) are treated as
        integers so odd/even/negative powers get exact monotonicity handling.
        """
        try:
            return self._pow(n)
        except OverflowError:
            return _wide_iv()

    def _pow(self, n):
        """Implementation of ``__pow__``; may raise OverflowError."""
        if isinstance(n, float) and n.is_integer():
            n = int(n)
        lo, hi = self.lo, self.hi
        if isinstance(n, int):
            if n == 0:
                return IV(1.0, 1.0)
            if n < 0:
                # x**-k == 1 / x**k; the division widens fully if x**k can be 0.
                return IV(1.0, 1.0) / self._pow(-n)
            if n % 2 == 0:
                if lo >= 0:
                    return IV(lo ** n, hi ** n)
                if hi <= 0:
                    return IV(hi ** n, lo ** n)
                return IV(0.0, max(abs(lo), abs(hi)) ** n)
            # Odd positive powers are monotone increasing on the whole real line.
            return IV(lo ** n, hi ** n)
        # Non-integer exponent: real-valued only for x >= 0.
        if hi < 0:
            return _wide_iv()
        if n < 0:
            if lo <= 0:
                return _wide_iv()
            # Decreasing on (0, inf).
            return IV(hi ** n, lo ** n)
        if lo < 0:
            return IV(0.0, max(abs(lo) ** n, hi ** n))
        return IV(lo ** n, hi ** n)

    def contains_zero(self):
        return self.lo <= 0.0 <= self.hi

    def width(self):
        return max(0.0, self.hi - self.lo)

    def mid(self):
        return (self.lo + self.hi) * 0.5


def _wide_iv():
    """Return the "no information" interval [-1e18, 1e18]."""
    return IV(-_IV_INF, _IV_INF)


def _pow_iv(base, expo):
    """Enclose ``base ** expo`` when the exponent is itself an interval.

    A point exponent defers to ``IV.__pow__``. For a strictly positive base,
    ``x**y = exp(y * ln x)`` and ``y * ln x`` is bilinear, so the extremes lie
    at the four corners. Otherwise nothing useful is known and the wide
    interval is returned.
    """
    if expo.lo == expo.hi:
        return base ** expo.lo
    if base.lo > 0:
        try:
            corners = [b ** p for b in (base.lo, base.hi) for p in (expo.lo, expo.hi)]
        except OverflowError:
            return _wide_iv()
        return IV(min(corners), max(corners))
    return _wide_iv()


def compile_iv(expr, variables):
    """Compile a SymPy expression into a function ``box -> IV``.

    ``box`` maps variable names to ``IV``; a symbol missing from it is treated
    as unbounded. Numbers, symbols, sums, products and powers are evaluated
    with interval arithmetic; every other node (``sin``, ``exp``, ``pi``, ...)
    yields the wide interval, which never causes pruning. ``variables`` is
    accepted for interface compatibility and is not used.
    """
    def wide(box):
        return _wide_iv()

    def _c(e):
        if e.is_Number:
            try:
                v = float(e)
            except (TypeError, ValueError):  # e.g. complex infinity
                return wide
            return lambda box, _v=v: IV(_v, _v)
        if e.is_Symbol:
            name = str(e)
            return lambda box, _n=name: box.get(_n, _wide_iv())
        if e.is_Add or e.is_Mul:
            fs = [_c(a) for a in e.args]
            op = operator.add if e.is_Add else operator.mul
            return lambda box, _fs=fs, _op=op: reduce(_op, (f(box) for f in _fs))
        if e.is_Pow:
            base_f = _c(e.args[0])
            ex = e.args[1]
            if ex.is_Number:
                try:
                    p = float(ex)
                except (TypeError, ValueError):
                    return wide
                return lambda box, _b=base_f, _p=p: _b(box) ** _p
            exp_f = _c(ex)
            # A symbolic exponent ranges over an interval: enclose, don't sample.
            return lambda box, _b=base_f, _e=exp_f: _pow_iv(_b(box), _e(box))
        return wide

    return _c(expr)


def _iv_excludes_zero(riv, tol):
    """True when *riv* provably stays more than *tol* away from 0.

    NaN bounds (e.g. produced by ``inf - inf``) mean "unknown" and never exclude.
    """
    if math.isnan(riv.lo) or math.isnan(riv.hi):
        return False
    return riv.lo > tol or riv.hi < -tol


def _branch_excluded(bmc, box, tol):
    """True when or_eq branch *bmc* provably cannot be 0 on *box*."""
    if bmc.fast_iv is None:
        return False
    try:
        riv = bmc.fast_iv(box)
    except Exception:
        return False
    return _iv_excludes_zero(riv, tol)


def _hc4(box, constraints, tol=_HC4_TOL):
    """Check whether *box* can possibly satisfy every constraint.

    Despite the name, no contraction is performed: the result is an unchanged
    copy of *box*, or ``None`` when interval evaluation proves some constraint
    unsatisfiable on it. Missing evaluators and evaluation errors never prune.
    An ``or_eq`` constraint prunes only when every branch excludes zero; one
    with no branches is ignored, matching ``Problem.tensor_energy`` (which adds
    no penalty for it).

    Args:
        box: variable name -> IV.
        constraints: MathConstraint objects; those with weight 0.0 are skipped.
        tol: slack before an interval bound counts as a violation.

    Returns:
        A new dict equal to *box*, or ``None`` if the box is infeasible.
    """
    cur = dict(box)
    for mc in constraints:
        if getattr(mc, "weight", 1.0) == 0.0:
            continue
        if mc.kind == "or_eq":
            if mc.branches and all(_branch_excluded(b, cur, tol) for b in mc.branches):
                return None
            continue
        if mc.fast_iv is None:
            continue
        try:
            riv = mc.fast_iv(cur)
        except Exception:
            continue
        if mc.kind == "equality":
            if _iv_excludes_zero(riv, tol):
                return None
        elif mc.kind == "inequality":
            if mc.direction == "geq" and riv.hi < -tol:
                return None
            if mc.direction == "leq" and riv.lo > tol:
                return None
    return cur


@dataclass
class MathConstraint:
    """A compiled constraint (built by ``compile_mc``)."""

    # "equality" (expr == 0), "inequality" (sign given by `direction`) or
    # "or_eq" (at least one of `branches` == 0).
    kind: str
    # Source expression with "^" already rewritten to "**".
    expr_str: str
    # "geq" (expr >= 0) or "leq" (expr <= 0) for inequalities; tensor_energy
    # treats any direction other than "geq" as "leq". Branches use "eq".
    direction: str
    # Penalty multiplier; 0.0 makes _hc4 and tensor_energy ignore the constraint.
    weight: float = 1.0
    # box (dict name -> IV) -> IV enclosing the expression's value.
    fast_iv: Optional[Callable] = field(default=None, repr=False)
    # Tensor evaluator taking the values of `syms_used`, in that order.
    torch_func: Optional[Callable] = field(default=None, repr=False)
    # Decision variables occurring in the expression (argument order of torch_func).
    syms_used: List[str] = field(default_factory=list)
    # SymPy expression normalised to `lhs - rhs`; None when parsing failed.
    parsed: Optional[sp.Expr] = field(default=None, repr=False)
    scope: str = "root"
    # Compiled alternatives of an "or_eq" constraint.
    branches: List["MathConstraint"] = field(default_factory=list)
    # Equality only: variable name -> list of closed-form solutions.
    projections: Dict[str, List[Dict]] = field(default_factory=dict)


# Closed-form per-variable solutions of equality constraints, keyed by the
# parsed SymPy expression (which already reflects which names are variables).
PROJECTION_CACHE: Dict = {}

# Name -> implementation used when lambdifying constraints for PyTorch.
_TORCH_FUNCS = {
    "sin": torch.sin, "cos": torch.cos, "tan": torch.tan, "exp": torch.exp,
    "log": safe_log, "sqrt": safe_sqrt, "Abs": torch.abs,
    "pi": math.pi, "E": math.e,
}


def _make_torch_func(parsed, syms_used):
    """Lambdify *parsed* over *syms_used* into a NaN/inf-safe tensor function.

    The returned callable always yields a tensor: scalar results are wrapped,
    evaluation errors become 1e6, and NaN/+inf/-inf map to 1e6/1e6/-1e6.
    """
    raw = sp.lambdify([sp.Symbol(v) for v in syms_used], parsed,
                      modules=[_TORCH_FUNCS, "math"])

    def _t_wrapper(*args):
        try:
            val = raw(*args)
            if not isinstance(val, torch.Tensor):
                val = torch.tensor(float(val), device=DEVICE, dtype=torch.float32)
        except Exception:
            val = torch.tensor(1e6, device=DEVICE, dtype=torch.float32)
        return torch.nan_to_num(val, posinf=1e6, neginf=-1e6, nan=1e6)

    return _t_wrapper


def _equality_projections(parsed):
    """Return (memoised in PROJECTION_CACHE) the solutions of ``parsed == 0``.

    Maps each free-symbol name to a list of
    ``{"syms": [names the solution uses], "func": math-module callable}``.
    Symbols SymPy fails to solve for are omitted; solutions that cannot be
    lambdified are skipped.
    """
    cached = PROJECTION_CACHE.get(parsed)
    if cached is not None:
        return cached
    pm = {}
    for sym in parsed.free_symbols:
        try:
            sols = sp.solve(parsed, sym)
        except Exception:
            continue
        entries = pm[str(sym)] = []
        for sol in sols:
            try:
                fs = list(sol.free_symbols)
                entries.append({"syms": [str(s) for s in fs],
                                "func": sp.lambdify(fs, sol, modules="math")})
            except Exception:
                continue
    PROJECTION_CACHE[parsed] = pm
    return pm


def _compile_or_eq(mc, branches, variables, weight, scope):
    """Fill *mc* (an or_eq constraint) from its branch expression strings."""
    for b_str in branches:
        b_mc = compile_mc("equality", b_str, "eq", variables, weight, scope)
        mc.branches.append(b_mc)
        mc.syms_used.extend(b_mc.syms_used)
    mc.syms_used = list(dict.fromkeys(mc.syms_used))  # de-duplicate, keep order

    def _or_iv(box, _mcs=mc.branches):
        # Hull of the branch intervals; branches that fail to evaluate are skipped.
        rivs = []
        for b in _mcs:
            if b.fast_iv is not None:
                try:
                    rivs.append(b.fast_iv(box))
                except Exception:
                    pass
        if not rivs:
            return _wide_iv()
        return IV(min(r.lo for r in rivs), max(r.hi for r in rivs))

    mc.fast_iv = _or_iv


def compile_mc(kind, expr_str, direction, variables, weight=1.0, scope="root", branches=None):
    """Parse and compile one constraint.

    Args:
        kind: "equality", "inequality" or "or_eq".
        expr_str: expression text; ``^`` is rewritten to ``**``. A relational
            such as ``"x >= 2"`` is converted to ``lhs - rhs``; the sense of an
            inequality comes from *direction*, not from the operator.
        direction: "geq"/"leq" for inequalities.
        variables: names treated as decision variables; any other symbol in
            the expression is replaced by 1.0.
        weight: penalty weight (0.0 disables the constraint in _hc4 and
            tensor_energy).
        scope: stored on the constraint unchanged.
        branches: for ``kind == "or_eq"``, the alternatives' expression
            strings (each compiled as an equality).

    Returns:
        A MathConstraint. Compilation is best-effort: if a step fails, the
        fields it would have set keep their defaults (e.g. ``torch_func`` None).

    NOTE(security): SymPy's ``parse_expr`` evaluates its input with ``eval``;
    only pass trusted expression strings.
    """
    expr_str = expr_str.replace("^", "**")
    mc = MathConstraint(kind=kind, expr_str=expr_str, direction=direction,
                        weight=weight, scope=scope)
    if kind == "or_eq":
        if branches:
            _compile_or_eq(mc, branches, variables, weight, scope)
        return mc

    syms = {v: sp.Symbol(v) for v in variables}
    try:
        parsed = parse_expr(expr_str, local_dict=syms)
        # Never truth-test `parsed`: bool() of a SymPy relational raises
        # TypeError, which used to drop every "a >= b"-style constraint.
        if getattr(parsed, "is_Equality", False) or getattr(parsed, "is_Relational", False):
            parsed = parsed.lhs - parsed.rhs
        for s in list(parsed.free_symbols):
            if str(s) not in variables:
                parsed = parsed.subs(s, 1.0)
        mc.parsed = parsed
        mc.syms_used = [v for v in variables if sp.Symbol(v) in parsed.free_symbols]
        mc.fast_iv = compile_iv(parsed, variables)
        mc.torch_func = _make_torch_func(parsed, mc.syms_used)
        if kind == "equality":
            mc.projections = _equality_projections(parsed)
    except Exception:
        pass  # best-effort: leave unset fields at their defaults
    return mc


@dataclass
class Problem:
    """A constraint problem over named real (optionally integer) variables.

    ``__post_init__`` derives the structural fields from ``variables``,
    ``bounds`` and ``compiled_constraints``:

    * ``var_idx``: variable name -> column index in ``X``.
    * ``adjacency_list``: variable -> variables it shares a constraint with.
    * ``log_space_vars``: non-integer variables with finite positive bounds
      whose ratio ``hi / lo`` exceeds ``LOG_SPACE_THRESHOLD``.
    * ``bilinear_pairs``: variable pairs that are the only two variables of a
      product term in a sum-shaped constraint (appended to any given list).
    * ``monotone_targets``: ``(a, b, k)`` for equalities of the exact form
      ``a*b - k == 0`` with ``k > 0`` (appended to any given list).
    """

    pid: str
    variables: List[str]
    bounds: Dict[str, Tuple[float, float]]
    compiled_constraints: List[MathConstraint]
    int_vars: Set[str] = field(default_factory=set)
    minimize_var: str = ""
    log_space_vars: Set[str] = field(default_factory=set)
    adjacency_list: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
    bilinear_pairs: List[Tuple[str, str]] = field(default_factory=list)
    ordering_pairs: List[Tuple[str, str]] = field(default_factory=list)
    monotone_targets: List[Tuple[str, str, float]] = field(default_factory=list)

    def __post_init__(self):
        self.var_idx = {v: i for i, v in enumerate(self.variables)}
        self._build_adjacency()
        self._detect_log_space_vars()
        self._detect_bilinear_pairs()
        self._detect_monotone_targets()

    def _build_adjacency(self):
        """Rebuild adjacency_list from the variables each constraint uses."""
        self.adjacency_list = defaultdict(set)
        for mc in self.compiled_constraints:
            for v1 in mc.syms_used:
                for v2 in mc.syms_used:
                    if v1 != v2:
                        self.adjacency_list[v1].add(v2)

    def _detect_log_space_vars(self):
        """Rebuild log_space_vars from the bounds (integer variables excluded)."""
        self.log_space_vars = set()
        for v in self.variables:
            if v in self.int_vars:
                continue
            lo, hi = self.bounds.get(v, (0, 1))
            if (lo > 0 and hi > 0 and math.isfinite(lo) and math.isfinite(hi)
                    and hi / lo > LOG_SPACE_THRESHOLD):
                self.log_space_vars.add(v)

    def _detect_bilinear_pairs(self):
        """Append each new two-variable product-term pair to bilinear_pairs.

        Each pair is ordered by position in ``variables`` so the result does
        not depend on (hash-randomised) set iteration order.
        """
        for mc in self.compiled_constraints:
            expr = mc.parsed
            if expr is None or not expr.is_Add:
                continue
            for term in expr.args:
                if not term.is_Mul:
                    continue
                names = sorted(
                    (str(s) for s in term.free_symbols if str(s) in self.var_idx),
                    key=self.var_idx.__getitem__,
                )
                if len(names) != 2:
                    continue
                va, vb = names
                if (va, vb) not in self.bilinear_pairs and (vb, va) not in self.bilinear_pairs:
                    self.bilinear_pairs.append((va, vb))

    def _detect_monotone_targets(self):
        """Append ``(a, b, k)`` for each equality ``a*b - k == 0`` with ``k > 0``."""
        for mc in self.compiled_constraints:
            if mc.kind != "equality" or mc.parsed is None:
                continue
            free = mc.parsed.free_symbols
            for va, vb in self.bilinear_pairs:
                sa, sb = sp.Symbol(va), sp.Symbol(vb)
                if sa not in free or sb not in free:
                    continue
                try:
                    k = float((-(mc.parsed - sa * sb)).evalf())
                except Exception:
                    continue  # residual is not constant: not of the form a*b - k
                if k > 0 and (va, vb, k) not in self.monotone_targets:
                    self.monotone_targets.append((va, vb, k))

    def get_markov_blanket(self, pinned_vars: Set[str], depth: int = 2) -> Set[str]:
        """Return *pinned_vars* plus all variables within *depth* adjacency hops.

        An empty *pinned_vars* returns every variable.
        """
        if not pinned_vars:
            return set(self.variables)
        visited = set(pinned_vars)
        queue = deque((v, 0) for v in pinned_vars)
        while queue:
            curr, d = queue.popleft()
            if d >= depth:
                continue
            for neighbor in self.adjacency_list.get(curr, ()):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, d + 1))
        return visited

    def tensor_energy(self, X: torch.Tensor, step_ratio: float = 1.0,
                      is_optimizing: bool = False) -> torch.Tensor:
        """Penalty energy of one point or a batch of points.

        ``X`` is 1-D (one point, ordered like ``variables``) or 2-D (one point
        per row). Returns a 1-D tensor of length batch size (1 if unbatched).

        Terms added per point:
        * equality: ``w * f**2``; inequality "geq": ``w * relu(-f)**2``, any
          other direction: ``w * relu(f)**2``; or_eq:
          ``w * (min over branches |f_b|)**2``.
        * bounds: ``10 * violation**2`` outside ``[lo - m, hi + m]`` with
          ``m = 0.1 * (hi - lo) * (1 - step_ratio)`` (bounds clamped by _c15;
          a variable without a bounds entry is unbounded).
        * if ``is_optimizing`` and ``minimize_var`` is a variable:
          ``0.05 * step_ratio * (x - lo) / (hi - lo)`` for that variable.
        While ``step_ratio < 1``, constraints mentioning sin/cos/exp are scaled
        by ``0.1 + 0.9 * step_ratio``.
        """
        is_batched = X.dim() == 2
        batch_size = X.shape[0] if is_batched else 1
        total = torch.zeros(batch_size, device=DEVICE, dtype=torch.float32)

        def column(idx):
            return X[:, idx] if is_batched else X[idx]

        def args_for(names):
            return [column(self.var_idx[v]) for v in names]

        for mc in self.compiled_constraints:
            if getattr(mc, "weight", 1.0) == 0.0:
                continue
            w = float(mc.weight)
            if step_ratio < 1.0 and any(f in mc.expr_str for f in ("sin", "cos", "exp")):
                w *= 0.1 + 0.9 * step_ratio
            if mc.kind == "or_eq":
                residuals = [torch.abs(b.torch_func(*args_for(b.syms_used)))
                             for b in mc.branches if b.torch_func]
                if residuals:
                    # Broadcast so constant (0-d) branches stack with batched ones.
                    stacked = torch.stack(torch.broadcast_tensors(*residuals), dim=0)
                    total += (stacked.min(dim=0)[0] ** 2) * w
                continue
            if mc.torch_func is None:
                continue
            val = mc.torch_func(*args_for(mc.syms_used))
            if mc.kind == "equality":
                total += (val ** 2) * w
            elif mc.direction == "geq":
                total += (torch.relu(-val) ** 2) * w
            else:
                total += (torch.relu(val) ** 2) * w

        for i, v in enumerate(self.variables):
            raw_lo, raw_hi = self.bounds.get(v, (-math.inf, math.inf))
            lo, hi = _c15(raw_lo), _c15(raw_hi)
            col = column(i)
            margin = (hi - lo) * 0.1 * (1.0 - step_ratio)
            out_of_bounds = torch.relu(lo - margin - col) + torch.relu(col - (hi + margin))
            total += (out_of_bounds ** 2) * 10.0

        if is_optimizing and self.minimize_var and self.minimize_var in self.var_idx:
            raw_lo, raw_hi = self.bounds.get(self.minimize_var, (-math.inf, math.inf))
            lo, hi = _c15(raw_lo), _c15(raw_hi)
            rng = max(hi - lo, 1e-8)
            normalized = (column(self.var_idx[self.minimize_var]) - lo) / rng
            total += normalized * 0.05 * step_ratio

        return total.view(batch_size, -1).sum(dim=1)

    def scalar_energy(self, b: Dict[str, float]) -> float:
        """Energy of one point given as ``{name: value}`` (no optimisation term).

        Missing variables take the midpoint of their _c15-clamped bounds,
        using ``(-1, 1)`` when a variable has no bounds entry.
        """
        def fallback(v):
            lo, hi = self.bounds.get(v, (-1, 1))
            return (_c15(lo) + _c15(hi)) / 2

        x_arr = [b[v] if v in b else fallback(v) for v in self.variables]
        X_t = torch.tensor(x_arr, device=DEVICE, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return float(self.tensor_energy(X_t, step_ratio=1.0, is_optimizing=False).item())


def _closest_real_root(expr, name, bounds):
    """Solve ``expr == 0`` for *name*; return the preferred real root or None.

    A root qualifies when its imaginary part is below 1e-8 and its real part
    is finite and inside ``[lo - 1, hi + 1]`` (``bounds.get(name)``, default
    ``(-1e9, 1e9)``). The qualifying root closest to the bound midpoint wins;
    bounds are clamped with _c15 for the midpoint so infinite bounds do not
    make it NaN. Any SymPy failure yields None.
    """
    try:
        solutions = sp.solve(expr, sp.Symbol(name))
        lo, hi = bounds.get(name, (-1e9, 1e9))
        mid = (_c15(lo) + _c15(hi)) / 2.0
        candidates = []
        for sol in solutions:
            try:
                value = complex(sol.evalf())
            except Exception:
                continue
            r = value.real
            if abs(value.imag) < 1e-8 and math.isfinite(r) and lo - 1.0 <= r <= hi + 1.0:
                candidates.append(r)
        return min(candidates, key=lambda r: abs(r - mid)) if candidates else None
    except Exception:
        return None


def algebraic_propagate_pinned(problem: Problem, pinned_vars: Dict[str, float],
                               timeout_secs: float = 2.0) -> Tuple[Dict[str, float], List[str]]:
    """Resolve more variables by solving equalities left with a single unknown.

    Each pass substitutes all resolved values into every equality constraint,
    simplifies it and, when exactly one problem variable remains unresolved,
    solves for it (see ``_closest_real_root``). Passes repeat while a pass
    resolves something, at most ``len(problem.variables) + 1`` times.

    Args:
        problem: problem whose equality constraints are used.
        pinned_vars: already-fixed values; these are never overwritten.
        timeout_secs: wall-clock budget, checked between constraints (a single
            slow ``simplify``/``solve`` call is not interrupted). ``None``
            disables the limit.

    Returns:
        ``(resolved, log)``: pinned plus propagated values, and log lines.
    """
    deadline = None if timeout_secs is None else time.monotonic() + timeout_secs
    resolved = dict(pinned_vars)
    log = []
    known = set(problem.variables)
    subs_map = {}

    def bind(name, value):
        try:
            subs_map[sp.Symbol(name)] = sp.Float(value)
        except Exception:
            pass  # value SymPy cannot represent: leave the symbol free

    for name, value in resolved.items():
        bind(name, value)

    changed = True
    max_passes = len(problem.variables) + 1
    passes = 0
    timed_out = False
    while changed and passes < max_passes and not timed_out:
        changed = False
        passes += 1
        for mc in problem.compiled_constraints:
            if deadline is not None and time.monotonic() > deadline:
                timed_out = True
                break
            if mc.kind != "equality" or mc.parsed is None:
                continue
            if all(v in resolved for v in mc.syms_used):
                continue  # nothing left to solve for in this constraint
            try:
                expr = mc.parsed.subs(subs_map)
            except Exception:
                continue
            try:
                expr = sp.simplify(expr)
            except Exception:
                pass
            free = [str(s) for s in expr.free_symbols
                    if str(s) in known and str(s) not in resolved]
            if len(free) != 1:
                continue
            target = free[0]
            best = _closest_real_root(expr, target, problem.bounds)
            if best is None:
                continue
            resolved[target] = best
            bind(target, best)
            log.append(f" PROP [{target}] = {best:.6g} <- [{mc.expr_str[:50]}]")
            changed = True

    if timed_out:
        log.append(f"ALGEBRAIC PROPAGATOR: stopped after {timeout_secs:g}s timeout")
    n_new = len(resolved) - len(pinned_vars)
    if n_new > 0:
        log.insert(0, f"ALGEBRAIC PROPAGATOR: resolved {n_new} vars in {passes} pass(es)")
    return resolved, log


def _global_hc4_tighten_bounds(problem: Problem) -> Dict[str, Tuple[float, float]]:
    """Intersect ``problem.bounds`` with the box ``_hc4`` returns for them.

    If ``_hc4`` reports the bounds box infeasible, a copy of the original
    bounds is returned. ``_hc4`` currently returns the box unchanged, so the
    result equals the input bounds.
    """
    box = {v: IV(lo, hi) for v, (lo, hi) in problem.bounds.items()}
    checked = _hc4(box, problem.compiled_constraints)
    if checked is None:
        return dict(problem.bounds)
    return {
        v: (max(problem.bounds[v][0], checked[v].lo), min(problem.bounds[v][1], checked[v].hi))
        for v in problem.bounds
        if v in checked
    }