Spaces:
Sleeping
Sleeping
Update practicality_axioms.py
Browse files- practicality_axioms.py +119 -5
practicality_axioms.py
CHANGED
|
@@ -325,6 +325,124 @@ def generate_templates(problem: core.Problem, sketch: Dict[str, Any]) -> List[Hy
|
|
| 325 |
)
|
| 326 |
return templates
|
| 327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 329 |
# THE DECOUPLED SEQUENCE RAY TRACER
|
| 330 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -466,8 +584,4 @@ class UCB1BanditSeeder:
|
|
| 466 |
if child:
|
| 467 |
child.baton = out_baton
|
| 468 |
children.append(child)
|
| 469 |
-
return children
|
| 470 |
-
|
| 471 |
-
# Re-expose batch optimization so axioms engine remains self-contained
|
| 472 |
-
_batched_deduce_and_evaluate = axioms._batched_deduce_and_evaluate
|
| 473 |
-
_mprt_sample = axioms._mprt_sample
|
|
|
|
| 325 |
)
|
| 326 |
return templates
|
| 327 |
|
| 328 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 329 |
+
# MATH ENGINE METHODS (Explicitly implemented inside local scope)
|
| 330 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 331 |
+
def _batched_deduce_and_evaluate(problem: core.Problem, hyps: List[Hypothesis], steps: int=80) -> List[Tuple[Dict, float, List[str], str]]:
|
| 332 |
+
if not hyps: return []
|
| 333 |
+
|
| 334 |
+
skip_indices = [i for i, h in enumerate(hyps) if getattr(h, 'is_fully_determined', False) or len(h.free_vars) == 0]
|
| 335 |
+
solve_indices = [i for i, h in enumerate(hyps) if i not in skip_indices]
|
| 336 |
+
results = [None] * len(hyps)
|
| 337 |
+
|
| 338 |
+
for i in skip_indices:
|
| 339 |
+
hyp = hyps[i]
|
| 340 |
+
try:
|
| 341 |
+
ce = problem.scalar_energy(hyp.binding)
|
| 342 |
+
dom_vars = list(hyp.pinned_vars.keys())[:3]
|
| 343 |
+
results[i] = (hyp.binding, ce, dom_vars, "algebraic")
|
| 344 |
+
except: results[i] = (hyp.binding, float('inf'), [], "algebraic_error")
|
| 345 |
+
|
| 346 |
+
if not solve_indices: return [r for r in results if r is not None]
|
| 347 |
+
|
| 348 |
+
adam_hyps = [hyps[i] for i in solve_indices]
|
| 349 |
+
V = len(problem.variables)
|
| 350 |
+
|
| 351 |
+
log_mask = []
|
| 352 |
+
log_lo, log_hi = [], []
|
| 353 |
+
for v in problem.variables:
|
| 354 |
+
lo, hi = core._c15(problem.bounds[v][0]), core._c15(problem.bounds[v][1])
|
| 355 |
+
if v in problem.log_space_vars and lo > 0:
|
| 356 |
+
log_mask.append(True)
|
| 357 |
+
log_lo.append(math.log10(max(lo, 1e-30)))
|
| 358 |
+
log_hi.append(math.log10(max(hi, 1e-30)))
|
| 359 |
+
else:
|
| 360 |
+
log_mask.append(False)
|
| 361 |
+
log_lo.append(lo)
|
| 362 |
+
log_hi.append(hi)
|
| 363 |
+
|
| 364 |
+
log_mask_t = torch.tensor(log_mask, device=core.DEVICE, dtype=torch.bool)
|
| 365 |
+
lo_param_t = torch.tensor(log_lo, device=core.DEVICE, dtype=torch.float32)
|
| 366 |
+
hi_param_t = torch.tensor(log_hi, device=core.DEVICE, dtype=torch.float32)
|
| 367 |
+
|
| 368 |
+
def _param_to_orig(P):
|
| 369 |
+
orig = P.clone()
|
| 370 |
+
if log_mask_t.any(): orig[:, log_mask_t] = torch.pow(10.0, P[:, log_mask_t])
|
| 371 |
+
return orig
|
| 372 |
+
|
| 373 |
+
def _orig_to_param(x_val, j):
|
| 374 |
+
if log_mask[j] and x_val > 0: return math.log10(max(x_val, 1e-30))
|
| 375 |
+
return x_val
|
| 376 |
+
|
| 377 |
+
x_data_p, mask_data, target_data_p = [], [], []
|
| 378 |
+
for hyp in adam_hyps:
|
| 379 |
+
xr, mr, tr = [], [], []
|
| 380 |
+
active_vars = problem.get_markov_blanket(set(hyp.pinned_vars.keys()), depth=2)
|
| 381 |
+
for j, v in enumerate(problem.variables):
|
| 382 |
+
lo, hi = core._c15(problem.bounds[v][0]), core._c15(problem.bounds[v][1])
|
| 383 |
+
if v in hyp.pinned_vars:
|
| 384 |
+
p_val = _orig_to_param(core._c15(hyp.pinned_vars[v]), j)
|
| 385 |
+
xr.append(p_val); mr.append(0.0); tr.append(p_val)
|
| 386 |
+
else:
|
| 387 |
+
p_val = _orig_to_param(core._c15(hyp.binding.get(v, (lo+hi)/2)), j)
|
| 388 |
+
is_active = (v in active_vars) or (len(hyp.pinned_vars) == 0)
|
| 389 |
+
xr.append(p_val); mr.append(1.0 if is_active else 0.0); tr.append(0.0)
|
| 390 |
+
x_data_p.append(xr); mask_data.append(mr); target_data_p.append(tr)
|
| 391 |
+
|
| 392 |
+
P = torch.tensor(x_data_p, device=core.DEVICE, dtype=torch.float32, requires_grad=True)
|
| 393 |
+
mask = torch.tensor(mask_data, device=core.DEVICE, dtype=torch.float32)
|
| 394 |
+
target = torch.tensor(target_data_p, device=core.DEVICE, dtype=torch.float32)
|
| 395 |
+
|
| 396 |
+
optimizer = torch.optim.Adam([P], lr=0.01)
|
| 397 |
+
|
| 398 |
+
for step in range(steps):
|
| 399 |
+
optimizer.zero_grad()
|
| 400 |
+
step_ratio = min(1.0, step / (steps * 0.8))
|
| 401 |
+
X_orig = _param_to_orig(P)
|
| 402 |
+
ce = problem.tensor_energy(X_orig, step_ratio, is_optimizing=True)
|
| 403 |
+
if isinstance(ce, torch.Tensor) and (ce < core.SOLVE_THRESHOLD).all() and step_ratio == 1.0: break
|
| 404 |
+
ce.sum().backward()
|
| 405 |
+
with torch.no_grad():
|
| 406 |
+
P.grad.clamp_(-10.0, 10.0)
|
| 407 |
+
P.grad *= mask
|
| 408 |
+
optimizer.step()
|
| 409 |
+
P.data = torch.where(mask == 0.0, target, P.data)
|
| 410 |
+
margin = 0.1 * (1.0 - step_ratio)
|
| 411 |
+
lo_m = lo_param_t - (hi_param_t - lo_param_t) * margin
|
| 412 |
+
hi_m = hi_param_t + (hi_param_t - lo_param_t) * margin
|
| 413 |
+
P.data = torch.clamp(P.data, lo_m.unsqueeze(0), hi_m.unsqueeze(0))
|
| 414 |
+
|
| 415 |
+
X_orig_final = _param_to_orig(P)
|
| 416 |
+
final_ce = problem.tensor_energy(X_orig_final, 1.0, is_optimizing=False).view(-1)
|
| 417 |
+
ce_vals = final_ce.detach().cpu().numpy()
|
| 418 |
+
X_vals = X_orig_final.detach().cpu().numpy()
|
| 419 |
+
|
| 420 |
+
for b_idx, orig_idx in enumerate(solve_indices):
|
| 421 |
+
final_b = {problem.variables[j]: float(X_vals[b_idx, j]) for j in range(V)}
|
| 422 |
+
results[orig_idx] = (final_b, float(ce_vals[b_idx]), [], "systemic")
|
| 423 |
+
|
| 424 |
+
return [r for r in results if r is not None]
|
| 425 |
+
|
| 426 |
+
def _mprt_sample(problem: core.Problem, N: int):
|
| 427 |
+
var_list = problem.variables; V = len(var_list)
|
| 428 |
+
lo_t = torch.tensor([core._c15(problem.bounds.get(v, (-10.0, 10.0))[0]) for v in var_list], device=core.DEVICE, dtype=torch.float32)
|
| 429 |
+
hi_t = torch.tensor([core._c15(problem.bounds.get(v, (-10.0, 10.0))[1]) for v in var_list], device=core.DEVICE, dtype=torch.float32)
|
| 430 |
+
for i in range(V):
|
| 431 |
+
if lo_t[i] >= hi_t[i]: m = (lo_t[i]+hi_t[i])/2; lo_t[i] = m - 1e-6; hi_t[i] = m + 1e-6
|
| 432 |
+
|
| 433 |
+
rand_base = torch.rand((N, V), device=core.DEVICE)
|
| 434 |
+
lsv_indices = [problem.var_idx[v] for v in problem.log_space_vars if v in problem.var_idx]
|
| 435 |
+
for idx in lsv_indices:
|
| 436 |
+
lo_v, hi_v = lo_t[idx].item(), hi_t[idx].item()
|
| 437 |
+
if lo_v > 0 and hi_v > lo_v:
|
| 438 |
+
log_lo, log_hi = math.log10(max(lo_v, 1e-30)), math.log10(max(hi_v, 1e-30))
|
| 439 |
+
rand_base[:, idx] = torch.pow(10.0, torch.rand(N, device=core.DEVICE)*(log_hi-log_lo)+log_lo) / hi_v
|
| 440 |
+
|
| 441 |
+
X = lo_t.unsqueeze(0) + (hi_t - lo_t).unsqueeze(0) * rand_base
|
| 442 |
+
ce_batch = problem.tensor_energy(X, 1.0, is_optimizing=False).view(-1)
|
| 443 |
+
best_idx = torch.argmin(ce_batch).item()
|
| 444 |
+
return {v: float(X[best_idx, i].item()) for i, v in enumerate(var_list)}, ce_batch[best_idx].item()
|
| 445 |
+
|
| 446 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 447 |
# THE DECOUPLED SEQUENCE RAY TRACER
|
| 448 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 584 |
if child:
|
| 585 |
child.baton = out_baton
|
| 586 |
children.append(child)
|
| 587 |
+
return children
|
|
|
|
|
|
|
|
|
|
|
|