# testing_space / practicality_core.py
# Uploaded by everydaytok; commit 1121f09 ("Update practicality_core.py", verified).
import math
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from functools import reduce
from typing import Dict, List, Tuple, Set, Optional, Callable

import sympy as sp
import torch
from sympy.parsing.sympy_parser import parse_expr
# Device on which this module creates tensors: the default CUDA device when
# PyTorch can see one, otherwise the CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_GPU else torch.device("cpu")

# Not referenced in this section; presumably a residual/energy tolerance used
# elsewhere -- TODO confirm.
SOLVE_THRESHOLD = 1e-3

# Problem.__post_init__ puts a continuous variable with strictly positive,
# finite bounds into log space when hi / lo exceeds this ratio.
LOG_SPACE_THRESHOLD = 1e3
def safe_round(val, ndigits=8):
    """Round ``val`` to ``ndigits`` decimal places when that makes sense.

    Finite numbers are passed to ``round``.  Non-finite floats (``inf``,
    ``nan``) and anything ``math.isfinite`` rejects (strings, ``None``, ...)
    are returned unchanged.

    Args:
        val: Value to round.
        ndigits: Decimal places to keep (default 8).

    Returns:
        The rounded number, or ``val`` itself when it cannot be rounded.
    """
    try:
        return round(val, ndigits) if math.isfinite(val) else val
    except Exception:
        # ``except Exception`` rather than a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still escape this best-effort helper.
        return val
def _c15(v):
try:
if not math.isfinite(v): return 1e15 if v > 0 else -1e15
return max(-1e15, min(1e15, float(v)))
except: return 0.0
def safe_log(x):
    """Natural log that returns a finite value instead of -inf/NaN for x <= 0.

    Non-tensor input is converted with ``float`` into a float32 tensor on
    ``DEVICE``; tensors are used as given.  Every element is clamped to at
    least ``1e-7`` before the log is taken.
    """
    if isinstance(x, torch.Tensor):
        t = x
    else:
        t = torch.tensor(float(x), device=DEVICE, dtype=torch.float32)
    return t.clamp(min=1e-7).log()
def safe_sqrt(x):
    """Square root that maps negative input to 0 instead of NaN.

    Non-tensor input is converted with ``float`` into a float32 tensor on
    ``DEVICE``; tensors are used as given.  Elements are clamped to at least
    ``0.0`` before the square root.
    """
    t = x if isinstance(x, torch.Tensor) else torch.tensor(
        float(x), device=DEVICE, dtype=torch.float32
    )
    return t.clamp(min=0.0).sqrt()
class IV:
    """Closed real interval ``[lo, hi]`` with simple interval arithmetic.

    Endpoints are stored as plain floats and no directed rounding is applied,
    so enclosures are only approximate at the last ulp.  ``IV(-1e18, 1e18)``
    is used as the "unbounded / unknown" interval (e.g. division by an
    interval that contains zero).  Operators accept another ``IV`` or a
    Python ``int``/``float`` on either side.
    """

    __slots__ = ("lo", "hi")

    def __init__(self, lo, hi):
        self.lo = float(lo)
        self.hi = float(hi)

    def __repr__(self):
        return f"IV({self.lo!r}, {self.hi!r})"

    def __add__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo + o, self.hi + o)
        return IV(self.lo + o.lo, self.hi + o.hi)

    __radd__ = __add__

    def __sub__(self, o):
        if isinstance(o, (int, float)):
            return IV(self.lo - o, self.hi - o)
        return IV(self.lo - o.hi, self.hi - o.lo)

    def __rsub__(self, o):
        if isinstance(o, (int, float)):
            return IV(o - self.hi, o - self.lo)
        return o.__sub__(self)

    def __mul__(self, o):
        if isinstance(o, (int, float)):
            a, b = self.lo * o, self.hi * o
            return IV(min(a, b), max(a, b))
        p = (self.lo * o.lo, self.lo * o.hi, self.hi * o.lo, self.hi * o.hi)
        return IV(min(p), max(p))

    __rmul__ = __mul__

    def __truediv__(self, o):
        if isinstance(o, (int, float)):
            if abs(o) < 1e-15:
                return IV(-1e18, 1e18)
            a, b = self.lo / o, self.hi / o
            return IV(min(a, b), max(a, b))
        if o.lo <= 0 <= o.hi:
            # Divisor straddles zero: the quotient is unbounded.
            return IV(-1e18, 1e18)
        return self * IV(1.0 / o.hi, 1.0 / o.lo)

    def __rtruediv__(self, o):
        # ``number / IV``: promote the number to a point interval.
        return IV(o, o) / self

    def __neg__(self):
        return IV(-self.hi, -self.lo)

    def __pow__(self, n):
        """Raise the interval to a real constant power ``n``.

        Integral exponents -- including integral floats such as ``3.0``,
        which is the form ``compile_iv`` passes -- use integer rules: even
        powers fold at zero, odd powers are monotone, and negative powers are
        ``1 / self**-n`` (unbounded when the interval contains 0).
        Fractional exponents are only real on a non-negative base; a
        negative base falls back to the ``|x|**n`` hull (positive ``n``) or
        to the unbounded interval (negative ``n``).
        """
        if isinstance(n, float) and n.is_integer():
            n = int(n)
        if isinstance(n, int):
            if n == 0:
                return IV(1.0, 1.0)
            if n < 0:
                # Previously computed as IV(lo**n, hi**n), which is inverted
                # for n < 0 (e.g. [1, 2]**-1 gave [1.0, 0.5]).
                return IV(1.0, 1.0) / (self ** -n)
            if n % 2 == 0:
                if self.lo >= 0:
                    return IV(self.lo ** n, self.hi ** n)
                if self.hi <= 0:
                    return IV(self.hi ** n, self.lo ** n)
                return IV(0.0, max(abs(self.lo) ** n, abs(self.hi) ** n))
            # Odd powers are monotone increasing: endpoints map to endpoints.
            lo = self.lo ** n if self.lo >= 0 else -((-self.lo) ** n)
            hi = self.hi ** n if self.hi >= 0 else -((-self.hi) ** n)
            return IV(lo, hi)
        if n > 0:
            if self.lo < 0:
                # abs() keeps the result real when hi < 0 too (hi ** 0.5 on a
                # negative float would be complex).
                return IV(0.0, max(abs(self.lo) ** n, abs(self.hi) ** n))
            return IV(self.lo ** n, self.hi ** n)
        # Negative fractional power: decreasing on (0, inf), unbounded at 0.
        if self.lo > 0:
            return IV(self.hi ** n, self.lo ** n)
        return IV(-1e18, 1e18)

    def contains_zero(self):
        return self.lo <= 0.0 <= self.hi

    def width(self):
        return max(0.0, self.hi - self.lo)

    def mid(self):
        return (self.lo + self.hi) * 0.5
def compile_iv(expr, variables):
    """Compile a SymPy expression into an interval evaluator ``f(box) -> IV``.

    ``box`` maps variable names to ``IV`` ranges; a symbol missing from the
    box is treated as unbounded.  Supported nodes:

    * numeric constants (including ``pi``, ``E``, ``sqrt(2)``) -> point
      intervals; complex constants -> unbounded;
    * symbols, ``Add`` and ``Mul``;
    * ``Pow`` with a constant exponent (integral exponents use integer rules,
      negative exponents are computed as a reciprocal) or with a variable
      exponent (sound corner enclosure on a strictly positive base).

    Anything else (``sin``, ``exp``, ``Abs``, ...) evaluates to the
    uninformative ``IV(-1e18, 1e18)``.

    ``variables`` is accepted for interface compatibility but is not used.
    """

    def _wide(box):
        return IV(-1e18, 1e18)

    def _int_pow(bc, k):
        # Integer power: non-negative handled by IV's integer rules, negative
        # as a reciprocal so the result never comes out as an inverted interval.
        if k >= 0:
            return lambda box, _bc=bc, _k=k: _bc(box) ** _k
        return lambda box, _bc=bc, _k=-k: IV(1.0, 1.0) / (_bc(box) ** _k)

    def _c_pow(e):
        bc = _c(e.args[0])
        ex = e.args[1]
        if ex.is_Integer:
            return _int_pow(bc, int(ex))
        if ex.is_Number:
            p = float(ex)
            if p.is_integer():
                return _int_pow(bc, int(p))
            if p >= 0:
                return lambda box, _bc=bc, _p=p: _bc(box) ** _p
            return lambda box, _bc=bc, _p=-p: IV(1.0, 1.0) / (_bc(box) ** _p)
        exc = _c(ex)

        def _var_pow(box, _bc=bc, _exc=exc):
            # b**p = exp(p * ln b) is monotone in p and in ln b separately,
            # so on a strictly positive base the extremes lie at the four
            # corners.  (Evaluating at the exponent's midpoint, as before,
            # produced a point interval and could prune feasible boxes.)
            b, p = _bc(box), _exc(box)
            if b.lo <= 0.0:
                return IV(-1e18, 1e18)
            try:
                c = (b.lo ** p.lo, b.lo ** p.hi, b.hi ** p.lo, b.hi ** p.hi)
            except OverflowError:
                return IV(-1e18, 1e18)
            return IV(min(c), max(c))

        return _var_pow

    def _c(e):
        if e.is_number:
            try:
                v = float(e)
            except (TypeError, ValueError, OverflowError):
                return _wide
            return lambda box, _v=v: IV(_v, _v)
        if e.is_Symbol:
            n = str(e)
            return lambda box, _n=n: box.get(_n, IV(-1e18, 1e18))
        if e.is_Add:
            fs = [_c(a) for a in e.args]
            return lambda box, _fs=fs: reduce(lambda a, b: a + b, (f(box) for f in _fs))
        if e.is_Mul:
            fs = [_c(a) for a in e.args]
            return lambda box, _fs=fs: reduce(lambda a, b: a * b, (f(box) for f in _fs))
        if e.is_Pow:
            return _c_pow(e)
        return _wide

    return _c(expr)
def _hc4(box, constraints):
    """Interval feasibility test of ``box`` against ``constraints``.

    Despite the name, no HC4-style contraction is performed: the box is only
    tested, never narrowed, and a shallow copy is returned when it is not
    provably infeasible.

    Each constraint's ``fast_iv`` evaluator gives a residual interval on the
    box:

    * ``equality``: infeasible when the residual interval excludes 0.
    * ``inequality`` with ``direction == "geq"`` (expr >= 0): infeasible when
      the residual's upper end is below ``-1e-10``; ``"leq"``: when its lower
      end is above ``1e-10``.
    * ``or_eq``: infeasible only when every branch's residual excludes 0.  An
      ``or_eq`` without branches is skipped, matching the energy function,
      which adds nothing for it.

    Constraints with weight 0, without an evaluator, whose evaluator raises,
    or whose interval has a NaN endpoint (e.g. ``0 * inf`` from infinite
    bounds) give no pruning evidence and are skipped.

    Returns:
        ``None`` if some constraint is provably violated on the box,
        otherwise a copy of ``box``.
    """
    cur = dict(box)

    def _residual(c):
        # Interval residual of ``c`` on ``cur``, or None when it is unusable.
        if c.fast_iv is None:
            return None
        try:
            r = c.fast_iv(cur)
        except Exception:
            return None
        if math.isnan(r.lo) or math.isnan(r.hi):
            return None
        return r

    for mc in constraints:
        if getattr(mc, "weight", 1.0) == 0.0:
            continue
        if mc.kind == "or_eq":
            if not mc.branches:
                continue
            for bmc in mc.branches:
                r = _residual(bmc)
                if r is None or r.contains_zero():
                    break  # this branch may hold
            else:
                return None
            continue
        riv = _residual(mc)
        if riv is None:
            continue
        if mc.kind == "equality" and not riv.contains_zero():
            return None
        if mc.kind == "inequality":
            direction = getattr(mc, "direction", None)
            if direction == "geq" and riv.hi < -1e-10:
                return None
            if direction == "leq" and riv.lo > 1e-10:
                return None
    return cur
@dataclass
class MathConstraint:
    """One compiled constraint together with its evaluators.

    Instances are normally built by ``compile_mc``.  Only ``kind``,
    ``expr_str`` and ``direction`` are required; the remaining fields default
    to empty/neutral values.
    """

    # Constraint family; this module uses "equality", "inequality" and "or_eq".
    kind: str
    # Expression text (compile_mc stores it with "^" rewritten to "**").
    expr_str: str
    # Constraint sense; this module uses "eq", "geq" and "leq".
    direction: str
    # Penalty multiplier; 0.0 makes the interval check and energy skip it.
    weight: float = 1.0
    # Interval evaluator: box of IV ranges -> IV residual.
    fast_iv: Optional[Callable] = field(default=None, repr=False)
    # Torch evaluator: takes one tensor per name in ``syms_used``.
    torch_func: Optional[Callable] = field(default=None, repr=False)
    # Decision-variable names the expression depends on.
    syms_used: List[str] = field(default_factory=list)
    # SymPy residual expression (relationals normalised to lhs - rhs).
    parsed: Optional[sp.Expr] = field(default=None, repr=False)
    # Scope label; not interpreted in this section.
    scope: str = "root"
    # Sub-constraints of an "or_eq" disjunction.
    branches: List['MathConstraint'] = field(default_factory=list)
    # Closed-form solutions per variable: name -> list of {"syms", "func"}.
    projections: Dict[str, List[Dict]] = field(default_factory=dict)
# Process-wide memo that compile_mc fills for equality constraints.  Each
# cached value maps a variable name to the list of closed-form projections
# ({"syms": [...], "func": callable}) solved for that variable.  Nothing in
# this section evicts entries.
PROJECTION_CACHE = {}
def compile_mc(kind, expr_str, direction, variables, weight=1.0, scope="root", branches=None):
    """Compile one constraint string into a :class:`MathConstraint`.

    Args:
        kind: ``"equality"``, ``"inequality"`` or ``"or_eq"`` (a disjunction of
            the equalities listed in ``branches``).
        expr_str: Constraint expression; ``^`` is accepted as power.  A
            relational such as ``"x >= 3"`` or ``"Eq(x, 3)"`` is normalised to
            the residual ``lhs - rhs``.
        direction: Stored as-is (this module uses ``"eq"``/``"geq"``/``"leq"``).
        variables: Names of the decision variables.  Any other symbol in the
            expression is replaced by the constant ``1.0``.
        weight, scope: Stored on the constraint.
        branches: Branch expressions for ``or_eq``.

    Returns:
        The constraint, with ``parsed``, ``syms_used``, ``fast_iv`` (interval
        evaluator), ``torch_func`` (torch evaluator; evaluation errors yield
        1e6, NaN/+inf become 1e6 and -inf becomes -1e6) and, for equalities,
        ``projections`` (per-variable closed-form solutions).  Compilation is
        best-effort: if a step raises, the fields filled in so far are kept
        and the rest stay at their defaults.
    """
    expr_str = expr_str.replace("^", "**")
    mc = MathConstraint(kind=kind, expr_str=expr_str, direction=direction, weight=weight, scope=scope)

    if kind == "or_eq":
        if branches:
            for b_str in branches:
                b_mc = compile_mc("equality", b_str, "eq", variables, weight, scope)
                mc.branches.append(b_mc)
                mc.syms_used.extend(b_mc.syms_used)
            mc.syms_used = list(dict.fromkeys(mc.syms_used))

            def _or_iv(box, _mcs=mc.branches):
                # Hull of the branch residual intervals; branches that cannot
                # be evaluated are left out.
                rivs = []
                for b in _mcs:
                    if b.fast_iv:
                        try:
                            rivs.append(b.fast_iv(box))
                        except Exception:
                            pass
                if not rivs:
                    return IV(-1e18, 1e18)
                return IV(min(r.lo for r in rivs), max(r.hi for r in rivs))

            mc.fast_iv = _or_iv
        return mc

    syms = {v: sp.Symbol(v) for v in variables}
    var_names = set(variables)
    try:
        # NOTE(security): parse_expr evaluates its input with eval(); only pass
        # trusted constraint strings.
        parsed = parse_expr(expr_str, local_dict=syms)
        # Check the relational flag directly: bool() of a SymPy relational
        # raises TypeError, so a truthiness test here would abort compilation
        # of every "a >= b" / Eq(a, b) input.
        if getattr(parsed, "is_Equality", False) or getattr(parsed, "is_Relational", False):
            parsed = parsed.lhs - parsed.rhs
        extras = {s: 1.0 for s in parsed.free_symbols if str(s) not in var_names}
        if extras:
            parsed = parsed.subs(extras)
        mc.parsed = parsed
        mc.syms_used = [v for v in variables if sp.Symbol(v) in parsed.free_symbols]
        mc.fast_iv = compile_iv(parsed, variables)

        pt_map = {'sin': torch.sin, 'cos': torch.cos, 'tan': torch.tan, 'exp': torch.exp,
                  'log': safe_log, 'sqrt': safe_sqrt, 'Abs': torch.abs, 'pi': math.pi, 'E': math.e}
        t_func_raw = sp.lambdify([sp.Symbol(v) for v in mc.syms_used], parsed, modules=[pt_map, "math"])

        def _t_wrapper(*args):
            try:
                val = t_func_raw(*args)
                if not isinstance(val, torch.Tensor):
                    val = torch.tensor(float(val), device=DEVICE, dtype=torch.float32)
            except Exception:
                # Any evaluation failure is scored as a large residual.
                val = torch.tensor(1e6, device=DEVICE, dtype=torch.float32)
            return torch.nan_to_num(val, posinf=1e6, neginf=-1e6, nan=1e6)

        mc.torch_func = _t_wrapper

        if kind == "equality":
            # Key on the substituted expression, not the raw text: the same
            # text compiled against different variable lists yields different
            # expressions (unknown symbols become 1.0), so a text key could
            # hand one problem another problem's projections.
            cache_key = sp.srepr(parsed)
            if cache_key not in PROJECTION_CACHE:
                pm = {}
                for sym in parsed.free_symbols:
                    v_str = str(sym)
                    try:
                        sols = sp.solve(parsed, sym)
                        pm[v_str] = []
                        for sol in sols:
                            fs = list(sol.free_symbols)
                            pm[v_str].append({"syms": [str(s) for s in fs],
                                              "func": sp.lambdify(fs, sol, modules="math")})
                    except Exception:
                        pass  # no closed form for this variable; skip it
                PROJECTION_CACHE[cache_key] = pm
            mc.projections = PROJECTION_CACHE.get(cache_key, {})
    except Exception:
        # Best-effort compilation (see Returns): keep whatever was filled in.
        pass
    return mc
@dataclass
class Problem:
    """A constraint problem over named variables.

    The caller supplies ``pid``, ``variables``, ``bounds`` and
    ``compiled_constraints``.  ``__post_init__`` then derives ``var_idx``,
    ``adjacency_list``, ``log_space_vars``, extra ``bilinear_pairs`` and
    extra ``monotone_targets`` from them.  Any ``adjacency_list`` or
    ``log_space_vars`` passed to the constructor is overwritten.
    """

    pid: str
    variables: List[str]
    bounds: Dict[str, Tuple[float, float]]
    compiled_constraints: List[MathConstraint]
    # Presumably integer-valued variables; here only used to keep them out of
    # log space.
    int_vars: Set[str] = field(default_factory=set)
    # Variable whose normalised value is added to the energy when optimizing;
    # "" disables that term.
    minimize_var: str = ""
    log_space_vars: Set[str] = field(default_factory=set)
    adjacency_list: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
    bilinear_pairs: List[Tuple[str, str]] = field(default_factory=list)
    # Not used in this section.
    ordering_pairs: List[Tuple[str, str]] = field(default_factory=list)
    monotone_targets: List[Tuple[str, str, float]] = field(default_factory=list)

    def __post_init__(self):
        self.var_idx = {v: i for i, v in enumerate(self.variables)}
        var_names = set(self.variables)

        # Co-occurrence graph: two variables are adjacent when some constraint
        # uses both.
        self.adjacency_list = defaultdict(set)
        for mc in self.compiled_constraints:
            for v1 in mc.syms_used:
                for v2 in mc.syms_used:
                    if v1 != v2:
                        self.adjacency_list[v1].add(v2)

        # Continuous variables with finite, strictly positive bounds whose
        # ratio exceeds LOG_SPACE_THRESHOLD are flagged for log space.
        self.log_space_vars = set()
        for v in self.variables:
            if v in self.int_vars:
                continue
            lo, hi = self.bounds.get(v, (0, 1))
            if lo > 0 and hi > 0 and math.isfinite(lo) and math.isfinite(hi):
                if hi / lo > LOG_SPACE_THRESHOLD:
                    self.log_space_vars.add(v)

        # Bilinear pairs: product terms of a sum that involve exactly two
        # decision variables.
        for mc in self.compiled_constraints:
            if mc.parsed is None or not mc.parsed.is_Add:
                continue
            for term in mc.parsed.args:
                if not term.is_Mul:
                    continue
                # sorted() makes the pair orientation reproducible; set order
                # of SymPy symbols depends on string-hash seeding.
                syms_in = sorted(str(s) for s in term.free_symbols if str(s) in var_names)
                if len(syms_in) == 2:
                    va, vb = syms_in
                    if (va, vb) not in self.bilinear_pairs and (vb, va) not in self.bilinear_pairs:
                        self.bilinear_pairs.append((va, vb))

        # Monotone targets: equalities of the exact shape va*vb - k with a
        # numeric k > 0 (anything else leaves a symbolic remainder, which
        # float() rejects).
        for mc in self.compiled_constraints:
            if mc.kind != "equality" or mc.parsed is None:
                continue
            free = mc.parsed.free_symbols
            for va, vb in self.bilinear_pairs:
                sa, sb = sp.Symbol(va), sp.Symbol(vb)
                if sa not in free or sb not in free:
                    continue
                try:
                    k = float((-(mc.parsed - sa * sb)).evalf())
                except (TypeError, ValueError):
                    continue
                if k > 0:
                    target = (va, vb, k)
                    if target not in self.monotone_targets:
                        self.monotone_targets.append(target)

    def get_markov_blanket(self, pinned_vars: Set[str], depth: int = 2) -> Set[str]:
        """Return ``pinned_vars`` plus every variable within ``depth`` hops in
        ``adjacency_list`` (breadth-first).  An empty ``pinned_vars`` returns
        all variables."""
        if not pinned_vars:
            return set(self.variables)
        visited = set(pinned_vars)
        queue = deque((v, 0) for v in pinned_vars)
        while queue:
            curr, d = queue.popleft()
            if d >= depth:
                continue
            for neighbor in self.adjacency_list.get(curr, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, d + 1))
        return visited

    def tensor_energy(self, X: torch.Tensor, step_ratio: float = 1.0, is_optimizing: bool = False) -> torch.Tensor:
        """Differentiable penalty energy of one point or a batch of points.

        ``X`` is either 1-D (one value per variable, in ``self.variables``
        order) or 2-D with variables along dim 1.  Contributions:

        * each constraint with non-zero weight adds ``weight * r**2``, where
          ``r`` is the value for equalities, ``relu(-value)`` for
          ``direction == "geq"``, ``relu(value)`` otherwise, and for ``or_eq``
          the smallest ``|value|`` over its branches;
        * while ``step_ratio < 1``, constraints whose text contains
          ``sin``/``cos``/``exp`` are down-weighted by ``0.1 + 0.9*step_ratio``;
        * ``10 * d**2`` per variable, ``d`` being the distance outside its
          bounds widened by ``10% * (1 - step_ratio)`` of the range;
        * with ``is_optimizing`` and a known ``minimize_var``,
          ``0.05 * step_ratio`` times its bound-normalised value.

        Returns:
            Tensor of shape ``(batch,)``; ``(1,)`` for 1-D input.
        """
        is_batched = X.dim() == 2
        batch_size = X.shape[0] if is_batched else 1
        total = torch.zeros(batch_size, device=DEVICE, dtype=torch.float32)

        def _column(i):
            return X[:, i] if is_batched else X[i]

        def _args(names):
            return [_column(self.var_idx[v]) for v in names]

        for mc in self.compiled_constraints:
            if getattr(mc, "weight", 1.0) == 0.0:
                continue
            eff_weight = float(mc.weight)
            if step_ratio < 1.0 and any(f in mc.expr_str for f in ["sin", "cos", "exp"]):
                eff_weight *= (0.1 + 0.9 * step_ratio)
            if mc.kind == "or_eq":
                b_vals = [torch.abs(bmc.torch_func(*_args(bmc.syms_used)))
                          for bmc in mc.branches if bmc.torch_func]
                if b_vals:
                    # A constant or failed branch yields a 0-dim tensor while
                    # others are (batch,); broadcast so torch.stack accepts them.
                    b_vals = torch.broadcast_tensors(*b_vals)
                    total += (torch.stack(b_vals, dim=0).min(dim=0)[0] ** 2) * eff_weight
                continue
            if mc.torch_func is None:
                continue
            val = mc.torch_func(*_args(mc.syms_used))
            if mc.kind == "equality":
                total += (val ** 2) * eff_weight
            elif mc.direction == "geq":
                total += (torch.relu(-val) ** 2) * eff_weight
            else:
                total += (torch.relu(val) ** 2) * eff_weight

        for i, v in enumerate(self.variables):
            lo, hi = _c15(self.bounds[v][0]), _c15(self.bounds[v][1])
            col = _column(i)
            margin = (hi - lo) * 0.1 * (1.0 - step_ratio)
            out_of_bounds = torch.relu(lo - margin - col) + torch.relu(col - (hi + margin))
            total += (out_of_bounds ** 2) * 10.0

        if is_optimizing and self.minimize_var and self.minimize_var in self.var_idx:
            midx = self.var_idx[self.minimize_var]
            lo, hi = _c15(self.bounds[self.minimize_var][0]), _c15(self.bounds[self.minimize_var][1])
            rng = max(hi - lo, 1e-8)
            normalized = (_column(midx) - lo) / rng
            total += normalized * 0.05 * step_ratio

        return total.view(batch_size, -1).sum(dim=1)

    def scalar_energy(self, b: Dict[str, float]) -> float:
        """Energy of the single assignment ``b`` (no annealing, no objective).

        Variables missing from ``b`` default to the midpoint of their clamped
        bounds, or of ``(-1, 1)`` when they have no bounds entry.
        """
        x_arr = []
        for v in self.variables:
            bnd = self.bounds.get(v, (-1, 1))
            x_arr.append(b.get(v, (_c15(bnd[0]) + _c15(bnd[1])) / 2))
        X_t = torch.tensor(x_arr, device=DEVICE, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            return float(self.tensor_energy(X_t, step_ratio=1.0, is_optimizing=False).item())
def algebraic_propagate_pinned(problem: Problem, pinned_vars: Dict[str, float], timeout_secs: float = 2.0) -> Tuple[Dict[str, float], List[str]]:
    """Propagate pinned values through equality constraints with one unknown.

    On each pass every resolved value is substituted into each equality
    constraint.  When exactly one problem variable remains free, the
    constraint is solved for it with ``sympy.solve``.  Real solutions inside
    the variable's bounds (widened by 1.0 on each side; ``(-1e9, 1e9)`` when
    unbounded) are kept, and the one closest to the bound midpoint wins.
    Passes repeat while something new resolves, up to
    ``len(problem.variables) + 1`` passes.

    Args:
        problem: Problem whose equality constraints are used.
        pinned_vars: Initial variable values.
        timeout_secs: Soft wall-clock budget in seconds, checked before each
            constraint is processed; a single slow simplify/solve can still
            overrun it.  ``None`` disables the limit.

    Returns:
        ``(resolved, log)``.  ``resolved`` holds the pinned values plus every
        derived one.  ``log`` has one line per derived value, preceded by a
        summary line when anything new was resolved, and ends with a notice
        if the time budget ran out.
    """
    deadline = None if timeout_secs is None else time.monotonic() + timeout_secs
    resolved = dict(pinned_vars)
    log = []
    var_names = set(problem.variables)
    max_passes = len(problem.variables) + 1
    passes = 0
    changed = True
    timed_out = False
    while changed and passes < max_passes and not timed_out:
        changed = False
        passes += 1
        for mc in problem.compiled_constraints:
            if deadline is not None and time.monotonic() > deadline:
                timed_out = True
                break
            if mc.kind != "equality" or mc.parsed is None:
                continue
            expr = mc.parsed
            for name, value in resolved.items():
                try:
                    expr = expr.subs(sp.Symbol(name), sp.Float(value))
                except Exception:
                    pass  # value SymPy cannot represent; leave that symbol free
            try:
                expr = sp.simplify(expr)
            except Exception:
                pass  # simplification is optional
            free = [str(s) for s in expr.free_symbols if str(s) in var_names and str(s) not in resolved]
            if len(free) != 1:
                continue
            target = free[0]
            try:
                # SymPy's solver raises a variety of exception types.
                solutions = sp.solve(expr, sp.Symbol(target))
                lo, hi = problem.bounds.get(target, (-1e9, 1e9))
                mid = (lo + hi) / 2.0
            except Exception:
                continue
            candidates = []
            for sol in solutions:
                try:
                    z = complex(sol.evalf())
                except (TypeError, ValueError):
                    continue  # symbolic or otherwise non-numeric solution
                if abs(z.imag) < 1e-8 and math.isfinite(z.real) and lo - 1.0 <= z.real <= hi + 1.0:
                    candidates.append(z.real)
            if not candidates:
                continue
            best = min(candidates, key=lambda r: abs(r - mid))
            resolved[target] = best
            log.append(f" PROP [{target}] = {best:.6g} <- [{mc.expr_str[:50]}]")
            changed = True
    n_new = len(resolved) - len(pinned_vars)
    if n_new > 0:
        log.insert(0, f"ALGEBRAIC PROPAGATOR: resolved {n_new} vars in {passes} pass(es)")
    if timed_out:
        log.append(f"ALGEBRAIC PROPAGATOR: time budget of {timeout_secs:g}s exhausted; stopped early")
    return resolved, log
def _global_hc4_tighten_bounds(problem: Problem) -> Dict[str, Tuple[float, float]]:
    """Intersect ``problem.bounds`` with the box that ``_hc4`` returns.

    The root box is built from ``problem.bounds``.  If ``_hc4`` reports it
    infeasible (``None``), a plain copy of the original bounds is returned.
    Otherwise each variable present in the returned box gets the
    intersection of its original bounds and its returned interval.  As
    currently written, ``_hc4`` returns an unmodified copy of its input, so
    the result normally equals the original bounds.
    """
    root = {}
    for name, (lo, hi) in problem.bounds.items():
        root[name] = IV(lo, hi)
    narrowed = _hc4(root, problem.compiled_constraints)
    if narrowed is None:
        return dict(problem.bounds)
    tightened = {}
    for name in problem.bounds:
        if name not in narrowed:
            continue
        orig_bounds = problem.bounds[name]
        box_iv = narrowed[name]
        tightened[name] = (max(orig_bounds[0], box_iv.lo), min(orig_bounds[1], box_iv.hi))
    return tightened