Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import sympy as sp | |
| from sympy.parsing.sympy_parser import parse_expr | |
| from typing import Dict, List, Tuple, Set, Optional, Callable | |
| from dataclasses import dataclass, field | |
| from functools import reduce | |
| from collections import defaultdict, deque | |
# Tensor work runs on the GPU when torch can see one, otherwise on the CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_GPU else torch.device("cpu")
# Not referenced in this part of the file; presumably the energy level below
# which a candidate counts as solved — TODO confirm against its users.
SOLVE_THRESHOLD = 0.001
# A continuous variable with strictly positive, finite bounds whose hi/lo ratio
# exceeds this value is put in Problem.log_space_vars.
LOG_SPACE_THRESHOLD = 1000.0
def safe_round(val, ndigits=8):
    """Round ``val`` to ``ndigits`` decimals when it is a finite number.

    Non-finite values (inf/nan) are returned unchanged, as is anything that
    ``math.isfinite`` or ``round`` rejects (non-numbers, ints too large to
    convert to a float, ...).
    """
    try:
        return round(val, ndigits) if math.isfinite(val) else val
    except Exception:
        # Best-effort: never fail on odd input, but do not swallow
        # KeyboardInterrupt/SystemExit as the former bare ``except:`` did.
        return val
| def _c15(v): | |
| try: | |
| if not math.isfinite(v): return 1e15 if v > 0 else -1e15 | |
| return max(-1e15, min(1e15, float(v))) | |
| except: return 0.0 | |
def safe_log(x):
    """Natural logarithm whose argument is clamped to at least 1e-7.

    A non-tensor ``x`` is first converted to a float32 scalar tensor on
    DEVICE. Zero and negative inputs all yield log(1e-7), so the result is
    always finite.
    """
    t = x if isinstance(x, torch.Tensor) else torch.tensor(float(x), device=DEVICE, dtype=torch.float32)
    return t.clamp(min=1e-7).log()
def safe_sqrt(x, eps=1e-12):
    """Square root with a finite gradient for every input.

    A non-tensor ``x`` is first converted to a float32 scalar tensor on
    DEVICE. The input is clamped to at least ``eps`` before ``torch.sqrt``.
    Clamping to exactly 0 leaves ``sqrt``'s backward at x == 0 equal to
    1 / (2 * 0) = inf, and that inf flows into any energy gradient built on
    it. With the default ``eps``, inputs <= 1e-12 map to about 1e-6.

    Args:
        x: tensor or number.
        eps: smallest value passed to ``torch.sqrt``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(float(x), device=DEVICE, dtype=torch.float32)
    return torch.sqrt(torch.clamp(x, min=eps))
class IV:
    """Closed real interval [lo, hi] with the arithmetic used by compile_iv/_hc4.

    Operands may be another IV or a plain int/float. Results that cannot be
    bounded (e.g. division by an interval containing zero) are represented
    by the "unbounded" interval IV(-1e18, 1e18).
    """

    __slots__ = ("lo", "hi")

    def __init__(self, lo, hi):
        self.lo = float(lo)
        self.hi = float(hi)

    def __repr__(self):
        return f"IV({self.lo!r}, {self.hi!r})"

    def __add__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo + o, self.hi + o)
        return IV(self.lo + o.lo, self.hi + o.hi)

    __radd__ = __add__

    def __sub__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo - o, self.hi - o)
        return IV(self.lo - o.hi, self.hi - o.lo)

    def __rsub__(self, o):
        if isinstance(o, (int, float)):
            return IV(o - self.hi, o - self.lo)
        return o.__sub__(self)

    def __mul__(self, o):
        if isinstance(o, (int, float)):
            a, b = self.lo * o, self.hi * o
            return IV(min(a, b), max(a, b))
        p = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return IV(min(p), max(p))

    __rmul__ = __mul__

    def __truediv__(self, o):
        if isinstance(o, (int, float)):
            if abs(o) < 1e-15:
                return IV(-1e18, 1e18)
            a, b = self.lo / o, self.hi / o
            return IV(min(a, b), max(a, b))
        if o.lo <= 0 <= o.hi:
            return IV(-1e18, 1e18)
        return self * IV(1.0 / o.hi, 1.0 / o.lo)

    def __neg__(self):
        return IV(-self.hi, -self.lo)

    def __pow__(self, n):
        """Enclose ``x**n`` for every x in this interval.

        Integer-valued exponents use exact parity/monotonicity rules. This
        includes floats such as ``3.0``, which is how compile_iv passes numeric
        exponents. Negative integer powers are computed as reciprocals.
        Non-integer powers are real only for non-negative bases, so a negative
        part of the base is enclosed through ``|x|**n``. A negative non-integer
        exponent on a base that reaches zero yields the unbounded interval.
        """
        if isinstance(n, float) and n.is_integer():
            n = int(n)
        if isinstance(n, int):
            if n == 0:
                return IV(1.0, 1.0)
            if n < 0:
                # x**-k == 1 / x**k; __truediv__ goes unbounded if x**k spans 0.
                return IV(1.0, 1.0) / (self ** (-n))
            if n % 2 == 0:
                if self.lo >= 0:
                    return IV(self.lo ** n, self.hi ** n)
                if self.hi <= 0:
                    return IV(self.hi ** n, self.lo ** n)
                return IV(0.0, max(abs(self.lo) ** n, abs(self.hi) ** n))
            # Odd positive powers are increasing on the whole real line.
            return IV(self.lo ** n, self.hi ** n)
        if n > 0:
            if self.lo >= 0:
                return IV(self.lo ** n, self.hi ** n)
            # abs() keeps a negative hi from producing a complex result.
            return IV(0.0, max(abs(self.lo) ** n, abs(self.hi) ** n))
        if self.lo > 0:
            # Negative exponent: decreasing on (0, inf).
            return IV(self.hi ** n, self.lo ** n)
        return IV(-1e18, 1e18)

    def contains_zero(self):
        return self.lo <= 0.0 <= self.hi

    def width(self):
        return max(0.0, self.hi - self.lo)

    def mid(self):
        return (self.lo + self.hi) * 0.5
def compile_iv(expr, variables):
    """Compile a SymPy expression into an interval-enclosure evaluator.

    The returned callable maps ``box`` (variable name -> IV) to an IV meant to
    contain every value ``expr`` takes over the box. Symbols missing from
    ``box`` and any sub-expression kind not handled here evaluate to the
    unbounded interval IV(-1e18, 1e18). That interval contains zero, so it
    cannot make _hc4 reject a box.

    Args:
        expr: SymPy expression; free symbols are looked up by name in ``box``.
        variables: unused; kept for interface compatibility.

    Returns:
        Callable ``box -> IV``.
    """
    def _unbounded(box):
        return IV(-1e18, 1e18)

    def _pow(base, ex):
        # Enclose base**ex for an interval base and an interval exponent.
        if ex.width() == 0.0:
            e = ex.lo
            if e.is_integer():
                k = int(e)
                # Negative integer powers go through a reciprocal so the result
                # stays ordered and a base containing 0 becomes unbounded.
                return IV(1.0, 1.0) / (base ** -k) if k < 0 else base ** k
            return base ** e
        if base.lo > 0.0:
            # For a positive base, x**y = exp(y * ln x) is bilinear in (y, ln x)
            # under a monotone exp, so its extremes sit at the box corners.
            # (Evaluating at the exponent's midpoint, as before, was unsound.)
            try:
                corners = (base.lo ** ex.lo, base.lo ** ex.hi,
                           base.hi ** ex.lo, base.hi ** ex.hi)
            except OverflowError:
                return IV(-1e18, 1e18)
            return IV(min(corners), max(corners))
        # Varying exponent on a base that may be <= 0: no cheap sound bound.
        return IV(-1e18, 1e18)

    def _c(e):
        if e.is_number:
            # Numeric constants (integers, rationals, pi, E, sqrt(2), ...) become
            # point intervals; complex or non-finite values stay unbounded.
            try:
                v = float(e)
            except Exception:
                return _unbounded
            if not math.isfinite(v):
                return _unbounded
            return lambda box, _v=v: IV(_v, _v)
        if e.is_Symbol:
            n = str(e)
            return lambda box, _n=n: box.get(_n, IV(-1e18, 1e18))
        if e.is_Add:
            fs = [_c(a) for a in e.args]
            return lambda box, _fs=fs: reduce(lambda a, b: a + b, (_f(box) for _f in _fs))
        if e.is_Mul:
            fs = [_c(a) for a in e.args]
            return lambda box, _fs=fs: reduce(lambda a, b: a * b, (_f(box) for _f in _fs))
        if e.is_Pow:
            base_f, ex_f = _c(e.args[0]), _c(e.args[1])
            return lambda box, _b=base_f, _x=ex_f: _pow(_b(box), _x(box))
        # Functions (sin, exp, log, ...) and anything else: no enclosure.
        return _unbounded

    return _c(expr)
| def _hc4(box, constraints): | |
| cur = dict(box) | |
| for mc in constraints: | |
| if getattr(mc, 'weight', 1.0) == 0.0: continue | |
| if mc.kind == "or_eq": | |
| valid = False | |
| for bmc in mc.branches: | |
| if bmc.fast_iv is None: valid = True; break | |
| try: | |
| if bmc.fast_iv(cur).contains_zero(): valid = True; break | |
| except: valid = True; break | |
| if not valid: return None | |
| else: | |
| if mc.fast_iv is None: continue | |
| try: | |
| riv = mc.fast_iv(cur) | |
| if ((mc.kind == "equality" and not riv.contains_zero()) or | |
| (mc.kind == "inequality" and ((mc.direction == "geq" and riv.hi < -1e-10) or | |
| (mc.direction == "leq" and riv.lo > 1e-10)))): return None | |
| except: pass | |
| return cur | |
@dataclass
class MathConstraint:
    """One compiled constraint, as produced by ``compile_mc``.

    Holds the source text plus the compiled forms used by the interval
    back-end (``fast_iv``), the gradient/energy back-end (``torch_func``) and
    the symbolic propagator (``parsed``/``projections``). The ``@dataclass``
    decorator is required: fields use ``field(...)`` defaults and instances
    are built with keyword arguments.
    """

    kind: str  # "equality", "inequality" or "or_eq"
    expr_str: str  # source text, with "^" already rewritten to "**" by compile_mc
    direction: str  # "eq", "geq" (expr >= 0) or "leq" (expr <= 0)
    weight: float = 1.0  # energy multiplier; 0.0 disables the constraint
    # box (name -> IV) -> IV enclosure of the expression
    fast_iv: Optional[Callable] = field(default=None, repr=False)
    # called with one tensor per name in syms_used, in that order
    torch_func: Optional[Callable] = field(default=None, repr=False)
    # problem variables occurring in the expression
    syms_used: List[str] = field(default_factory=list)
    # SymPy form; relationals stored as lhs - rhs, non-variable symbols set to 1.0
    parsed: Optional[sp.Expr] = field(default=None, repr=False)
    scope: str = "root"  # free-form tag stored by compile_mc
    # for "or_eq": the compiled alternative equalities
    branches: List['MathConstraint'] = field(default_factory=list)
    # variable -> list of {"syms": [names], "func": callable} solutions for it
    projections: Dict[str, List[Dict]] = field(default_factory=dict)
# Module-wide memo of symbolic projections for equality constraints, filled by
# compile_mc: maps a cache key to {variable name: [solution dicts]}. It is
# never evicted, so it grows with the number of distinct constraints compiled.
PROJECTION_CACHE = dict()
def compile_mc(kind, expr_str, direction, variables, weight=1.0, scope="root", branches=None):
    """Compile one textual constraint into a MathConstraint.

    ``"^"`` in ``expr_str`` is rewritten to ``"**"`` first.

    * ``kind == "or_eq"`` with non-empty ``branches``: each branch string is
      compiled recursively as an ``"equality"``. ``syms_used`` becomes the
      union of the branches' variables, and ``fast_iv`` becomes the hull of
      the branch enclosures.
    * Otherwise ``expr_str`` is parsed with SymPy. A relational such as
      ``"x >= 3"`` or ``Eq(x, 3)`` becomes ``lhs - rhs``, and symbols that are
      not in ``variables`` are replaced by 1.0. Then ``parsed``,
      ``syms_used``, ``fast_iv`` and ``torch_func`` are attached. For
      equalities, ``projections`` (symbolic solutions for each variable) are
      attached too.

    Compilation is best-effort. If any step raises, the constraint is returned
    with whatever was attached before the failure, possibly nothing.

    Args:
        kind: "equality", "inequality" or "or_eq".
        expr_str: expression text.
        direction: "eq", "geq" or "leq".
        variables: names of the problem's variables.
        weight: energy weight (0.0 disables the constraint).
        scope: free-form tag stored on the result.
        branches: for "or_eq", the alternative equality expressions.

    Returns:
        The compiled MathConstraint.
    """
    expr_str = expr_str.replace("^", "**")
    mc = MathConstraint(kind=kind, expr_str=expr_str, direction=direction, weight=weight, scope=scope)

    if kind == "or_eq" and branches:
        for b_str in branches:
            b_mc = compile_mc("equality", b_str, "eq", variables, weight, scope)
            mc.branches.append(b_mc)
            mc.syms_used.extend(b_mc.syms_used)
        mc.syms_used = list(dict.fromkeys(mc.syms_used))  # dedupe, keep first-seen order

        def _or_iv(box, _mcs=mc.branches):
            # Hull of the branch enclosures; branches that fail to evaluate are skipped.
            rivs = []
            for b in _mcs:
                if b.fast_iv:
                    try:
                        rivs.append(b.fast_iv(box))
                    except Exception:
                        pass
            if not rivs:
                return IV(-1e18, 1e18)
            return IV(min(r.lo for r in rivs), max(r.hi for r in rivs))

        mc.fast_iv = _or_iv
        return mc

    syms = {v: sp.Symbol(v) for v in variables}
    try:
        parsed = parse_expr(expr_str, local_dict=syms) if kind != "or_eq" else None
        # bool() of a SymPy Relational raises TypeError, so a plain `if parsed:`
        # aborted compilation of every relational input. Test for relationals
        # explicitly; other values keep the truthiness check (e.g. a literal 0
        # is still skipped).
        is_rel = bool(getattr(parsed, 'is_Equality', False) or getattr(parsed, 'is_Relational', False))
        if parsed is not None and (is_rel or parsed):
            if is_rel:
                parsed = parsed.lhs - parsed.rhs
            # Symbols that are not problem variables are treated as the constant 1.0.
            for s in list(parsed.free_symbols):
                if str(s) not in variables:
                    parsed = parsed.subs(s, 1.0)
            mc.parsed = parsed
            mc.syms_used = [v for v in variables if sp.Symbol(v) in parsed.free_symbols]
            mc.fast_iv = compile_iv(parsed, variables)

            # Names lambdify resolves to torch ops first, then to the math module.
            pt_map = {'sin': torch.sin, 'cos': torch.cos, 'tan': torch.tan, 'exp': torch.exp,
                      'log': safe_log, 'sqrt': safe_sqrt, 'Abs': torch.abs, 'pi': math.pi, 'E': math.e}
            t_func_raw = sp.lambdify([sp.Symbol(v) for v in mc.syms_used], parsed, modules=[pt_map, "math"])

            def _t_wrapper(*args):
                # Any evaluation failure becomes a large constant residual (1e6);
                # inf/nan are mapped to +/-1e6 so energies stay finite.
                try:
                    val = t_func_raw(*args)
                    if not isinstance(val, torch.Tensor):
                        val = torch.tensor(float(val), device=DEVICE, dtype=torch.float32)
                except Exception:
                    val = torch.tensor(1e6, device=DEVICE, dtype=torch.float32)
                return torch.nan_to_num(val, posinf=1e6, neginf=-1e6, nan=1e6)

            mc.torch_func = _t_wrapper

            if kind == "equality":
                # The solutions depend on which symbols are variables (the rest
                # were replaced by 1.0), so key on both the text and syms_used;
                # keying on expr_str alone served stale projections across problems.
                cache_key = (expr_str, tuple(mc.syms_used))
                if cache_key not in PROJECTION_CACHE:
                    pm = {}
                    for sym in parsed.free_symbols:
                        v_str = str(sym)
                        try:
                            sols = sp.solve(parsed, sym)
                            pm[v_str] = []
                            for sol in sols:
                                fs = list(sol.free_symbols)
                                pm[v_str].append({"syms": [str(s) for s in fs],
                                                  "func": sp.lambdify(fs, sol, modules="math")})
                        except Exception:
                            pass
                    PROJECTION_CACHE[cache_key] = pm
                mc.projections = PROJECTION_CACHE.get(cache_key, {})
    except Exception:
        # Best-effort: keep whatever was attached before the failure.
        pass
    return mc
@dataclass
class Problem:
    """Variables, bounds and compiled constraints of one problem, plus derived structure.

    ``__post_init__`` derives ``var_idx`` (variable -> column index of the
    state tensor), ``adjacency_list`` (variables sharing a constraint),
    ``log_space_vars``, ``bilinear_pairs`` and ``monotone_targets``. The
    ``@dataclass`` decorator is required: fields use ``field(...)`` defaults
    and ``__post_init__`` is a dataclass hook.
    """

    pid: str
    variables: List[str]
    bounds: Dict[str, Tuple[float, float]]
    compiled_constraints: List[MathConstraint]
    int_vars: Set[str] = field(default_factory=set)
    minimize_var: str = ""  # pushed downward by tensor_energy(is_optimizing=True)
    log_space_vars: Set[str] = field(default_factory=set)  # overwritten by __post_init__
    adjacency_list: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))  # overwritten by __post_init__
    bilinear_pairs: List[Tuple[str, str]] = field(default_factory=list)  # extended by __post_init__
    ordering_pairs: List[Tuple[str, str]] = field(default_factory=list)  # not read within this class
    monotone_targets: List[Tuple[str, str, float]] = field(default_factory=list)  # extended by __post_init__

    def __post_init__(self):
        """Build the index map, co-occurrence graph, log-space set and product structure."""
        self.var_idx = {v: i for i, v in enumerate(self.variables)}

        # Two variables are adjacent when they occur in the same constraint.
        self.adjacency_list = defaultdict(set)
        for mc in self.compiled_constraints:
            for v1 in mc.syms_used:
                for v2 in mc.syms_used:
                    if v1 != v2:
                        self.adjacency_list[v1].add(v2)

        # Continuous variables with finite, strictly positive bounds whose
        # hi/lo ratio exceeds LOG_SPACE_THRESHOLD.
        self.log_space_vars = set()
        for v in self.variables:
            if v in self.int_vars:
                continue
            lo, hi = self.bounds.get(v, (0, 1))
            if lo > 0 and hi > 0 and math.isfinite(lo) and math.isfinite(hi):
                if hi / lo > LOG_SPACE_THRESHOLD:
                    self.log_space_vars.add(v)

        # Product terms of a sum that involve exactly two problem variables.
        for mc in self.compiled_constraints:
            if mc.parsed and mc.parsed.is_Add:
                for term in mc.parsed.args:
                    if term.is_Mul:
                        syms_in = [str(s) for s in term.free_symbols if str(s) in self.variables]
                        if len(syms_in) == 2:
                            va, vb = syms_in
                            if (va, vb) not in self.bilinear_pairs and (vb, va) not in self.bilinear_pairs:
                                self.bilinear_pairs.append((va, vb))

        # Equalities of the exact form va*vb - k (k a positive number) give (va, vb, k).
        for mc in self.compiled_constraints:
            if mc.kind == "equality" and mc.parsed is not None:
                for va, vb in self.bilinear_pairs:
                    sa, sb = sp.Symbol(va), sp.Symbol(vb)
                    if sa in mc.parsed.free_symbols and sb in mc.parsed.free_symbols:
                        try:
                            k_expr = -(mc.parsed - sa * sb)
                            k = float(k_expr.evalf())  # TypeError unless k_expr is numeric
                            if k > 0:
                                target = (va, vb, k)
                                if target not in self.monotone_targets:
                                    self.monotone_targets.append(target)
                        except Exception:
                            pass

    def get_markov_blanket(self, pinned_vars: Set[str], depth: int = 2) -> Set[str]:
        """Return ``pinned_vars`` plus every variable within ``depth`` hops in the
        constraint co-occurrence graph (breadth-first); all variables if none are pinned."""
        if not pinned_vars:
            return set(self.variables)
        visited = set(pinned_vars)
        queue = deque([(v, 0) for v in pinned_vars])
        while queue:
            curr, d = queue.popleft()
            if d < depth:
                for neighbor in self.adjacency_list.get(curr, []):
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, d + 1))
        return visited

    def tensor_energy(self, X: torch.Tensor, step_ratio: float = 1.0, is_optimizing: bool = False) -> torch.Tensor:
        """Penalty energy for one point (1-D ``X``) or a batch (2-D ``X``, one row per point).

        Column ``self.var_idx[v]`` of ``X`` holds variable ``v``. ``X`` is
        presumably float32 on DEVICE, matching the accumulator — TODO confirm
        with callers. Terms:

        * equality: w * val**2; "geq": w * relu(-val)**2; otherwise: w * relu(val)**2;
        * or_eq: w * (min over branches of |val|)**2;
        * constraints whose text mentions sin/cos/exp have w scaled by
          0.1 + 0.9 * step_ratio while step_ratio < 1;
        * bounds: 10 * squared violation, with bounds widened on each side by
          10% of the range times (1 - step_ratio);
        * if ``is_optimizing``: + 0.05 * step_ratio * (minimize_var normalised to its bounds).

        Returns:
            Tensor of shape (batch_size,), with batch_size 1 for 1-D ``X``.
        """
        is_batched = (X.dim() == 2)
        batch_size = X.shape[0] if is_batched else 1
        total = torch.zeros(batch_size, device=DEVICE, dtype=torch.float32)
        for mc in self.compiled_constraints:
            if getattr(mc, 'weight', 1.0) == 0.0:
                continue
            eff_weight = float(mc.weight)
            if step_ratio < 1.0 and any(f in mc.expr_str for f in ["sin", "cos", "exp"]):
                eff_weight *= (0.1 + 0.9 * step_ratio)
            if mc.kind == "or_eq":
                b_vals = []
                for bmc in mc.branches:
                    if bmc.torch_func:
                        args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in bmc.syms_used]
                        b_vals.append(torch.abs(bmc.torch_func(*args)))
                if b_vals:
                    # A branch without variables yields a 0-d tensor; broadcast
                    # first so torch.stack accepts mixed shapes.
                    stacked = torch.stack(torch.broadcast_tensors(*b_vals), dim=0)
                    total += (stacked.min(dim=0)[0] ** 2) * eff_weight
            else:
                if mc.torch_func is None:
                    continue
                args = [X[:, self.var_idx[v]] if is_batched else X[self.var_idx[v]] for v in mc.syms_used]
                val = mc.torch_func(*args)
                if mc.kind == "equality":
                    total += (val ** 2) * eff_weight
                elif mc.direction == "geq":
                    total += (torch.relu(-val) ** 2) * eff_weight
                else:
                    total += (torch.relu(val) ** 2) * eff_weight
        for i, v in enumerate(self.variables):
            lo, hi = _c15(self.bounds[v][0]), _c15(self.bounds[v][1])
            col = X[:, i] if is_batched else X[i]
            margin = (hi - lo) * 0.1 * (1.0 - step_ratio)
            out_of_bounds = torch.relu(lo - margin - col) + torch.relu(col - (hi + margin))
            total += (out_of_bounds ** 2) * 10.0
        if is_optimizing and self.minimize_var and self.minimize_var in self.var_idx:
            midx = self.var_idx[self.minimize_var]
            lo, hi = _c15(self.bounds[self.minimize_var][0]), _c15(self.bounds[self.minimize_var][1])
            rng = max(hi - lo, 1e-8)
            col = X[:, midx] if is_batched else X[midx]
            normalized = (col - lo) / rng
            total += normalized * 0.05 * step_ratio
        return total.view(batch_size, -1).sum(dim=1)

    def scalar_energy(self, b: Dict[str, float]) -> float:
        """Energy (step_ratio=1, no objective term) of a single point given as a dict.

        A variable missing from ``b`` takes the midpoint of its _c15-clamped
        bounds, or of (-1, 1) when it has no bounds entry.
        """
        x_arr = [b.get(v, (_c15(self.bounds.get(v, (-1, 1))[0]) + _c15(self.bounds.get(v, (-1, 1))[1])) / 2)
                 for v in self.variables]
        X_t = torch.tensor(x_arr, device=DEVICE, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return float(self.tensor_energy(X_t, step_ratio=1.0, is_optimizing=False).item())
def algebraic_propagate_pinned(problem: Problem, pinned_vars: Dict[str, float], timeout_secs: float = 2.0) -> Tuple[Dict[str, float], List[str]]:
    """Resolve extra variables by solving equality constraints symbolically.

    Starting from ``pinned_vars``, each pass substitutes every resolved value
    into each equality constraint and simplifies the result. When exactly one
    unresolved problem variable remains, the constraint is solved for it with
    ``sp.solve``. Among the real, finite solutions that lie within the
    variable's bounds widened by 1.0 on each side (bounds default to
    (-1e9, 1e9)), the one closest to the bounds' midpoint is kept. Passes
    repeat while something new was resolved, for at most
    ``len(problem.variables) + 1`` passes.

    Args:
        problem: source of ``compiled_constraints``, ``variables`` and ``bounds``.
        pinned_vars: variable name -> fixed value; copied, not mutated.
        timeout_secs: wall-clock budget in seconds (``None`` = no limit). It is
            checked before each equality constraint is processed, so a single
            slow ``simplify``/``solve`` call can still overrun it. When it is
            exceeded, propagation stops and a timeout line is logged.

    Returns:
        ``(resolved, log)``: the pinned values plus every newly resolved value,
        and human-readable log lines. A summary line is inserted first when
        anything new was resolved.
    """
    import time  # only this function needs a clock

    deadline = None if timeout_secs is None else time.monotonic() + timeout_secs
    resolved = dict(pinned_vars)
    log = []
    changed = True
    max_passes = len(problem.variables) + 1
    passes = 0
    timed_out = False
    while changed and passes < max_passes and not timed_out:
        changed = False
        passes += 1
        for mc in problem.compiled_constraints:
            if mc.kind != "equality" or mc.parsed is None:
                continue
            if deadline is not None and time.monotonic() > deadline:
                timed_out = True
                break
            expr = mc.parsed
            for v, val in resolved.items():
                try:
                    expr = expr.subs(sp.Symbol(v), sp.Float(val))
                except Exception:
                    pass
            try:
                expr = sp.simplify(expr)
            except Exception:
                pass
            free = [str(s) for s in expr.free_symbols if str(s) in problem.variables and str(s) not in resolved]
            if len(free) != 1:
                continue
            target = free[0]
            try:
                solutions = sp.solve(expr, sp.Symbol(target))
                if not solutions:
                    continue
                lo, hi = problem.bounds.get(target, (-1e9, 1e9))
                mid = (lo + hi) / 2.0
                valid_sols = []
                for sol in solutions:
                    try:
                        val = complex(sol.evalf())
                        if abs(val.imag) < 1e-8:
                            rval = val.real
                            if lo - 1.0 <= rval <= hi + 1.0 and math.isfinite(rval):
                                valid_sols.append(rval)
                    except Exception:
                        pass
                if valid_sols:
                    best = min(valid_sols, key=lambda s: abs(s - mid))
                    resolved[target] = best
                    log.append(f" PROP [{target}] = {best:.6g} <- [{mc.expr_str[:50]}]")
                    changed = True
            except Exception:
                pass
    if timed_out:
        log.append(f"ALGEBRAIC PROPAGATOR: timeout after {timeout_secs:g}s ({passes} pass(es))")
    n_new = len(resolved) - len(pinned_vars)
    if n_new > 0:
        log.insert(0, f"ALGEBRAIC PROPAGATOR: resolved {n_new} vars in {passes} pass(es)")
    return resolved, log
def _global_hc4_tighten_bounds(problem: Problem) -> Dict[str, Tuple[float, float]]:
    """Intersect ``problem.bounds`` with the box that ``_hc4`` returns for them.

    The bounds are wrapped as IV intervals and checked against all compiled
    constraints. If ``_hc4`` rejects that full box (returns ``None``), a plain
    copy of ``problem.bounds`` is returned unchanged. Otherwise each variable
    present in both maps gets (max of lower bounds, min of upper bounds).
    """
    bounds = problem.bounds
    initial_box = {}
    for name, (lower, upper) in bounds.items():
        initial_box[name] = IV(lower, upper)
    narrowed = _hc4(initial_box, problem.compiled_constraints)
    if narrowed is None:
        return dict(bounds)
    tightened = {}
    for name in bounds:
        if name not in narrowed:
            continue
        box_iv = narrowed[name]
        tightened[name] = (max(bounds[name][0], box_iv.lo), min(bounds[name][1], box_iv.hi))
    return tightened