everydaytok commited on
Commit
ae2bf63
Β·
verified Β·
1 Parent(s): 0a0ada9

Update practicality_axioms.py

Browse files
Files changed (1) hide show
  1. practicality_axioms.py +119 -5
practicality_axioms.py CHANGED
@@ -325,6 +325,124 @@ def generate_templates(problem: core.Problem, sketch: Dict[str, Any]) -> List[Hy
325
  )
326
  return templates
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  # ══════════════════════════════════════════════════════════════════════
329
  # THE DECOUPLED SEQUENCE RAY TRACER
330
  # ══════════════════════════════════════════════════════════════════════
@@ -466,8 +584,4 @@ class UCB1BanditSeeder:
466
  if child:
467
  child.baton = out_baton
468
  children.append(child)
469
- return children
470
-
471
- # Re-expose batch optimization so axioms engine remains self-contained
472
- _batched_deduce_and_evaluate = axioms._batched_deduce_and_evaluate
473
- _mprt_sample = axioms._mprt_sample
 
325
  )
326
  return templates
327
 
328
+ # ══════════════════════════════════════════════════════════════════════
329
+ # MATH ENGINE METHODS (Explicitly implemented inside local scope)
330
+ # ══════════════════════════════════════════════════════════════════════
331
+ def _batched_deduce_and_evaluate(problem: core.Problem, hyps: List[Hypothesis], steps: int=80) -> List[Tuple[Dict, float, List[str], str]]:
332
+ if not hyps: return []
333
+
334
+ skip_indices = [i for i, h in enumerate(hyps) if getattr(h, 'is_fully_determined', False) or len(h.free_vars) == 0]
335
+ solve_indices = [i for i, h in enumerate(hyps) if i not in skip_indices]
336
+ results = [None] * len(hyps)
337
+
338
+ for i in skip_indices:
339
+ hyp = hyps[i]
340
+ try:
341
+ ce = problem.scalar_energy(hyp.binding)
342
+ dom_vars = list(hyp.pinned_vars.keys())[:3]
343
+ results[i] = (hyp.binding, ce, dom_vars, "algebraic")
344
+ except: results[i] = (hyp.binding, float('inf'), [], "algebraic_error")
345
+
346
+ if not solve_indices: return [r for r in results if r is not None]
347
+
348
+ adam_hyps = [hyps[i] for i in solve_indices]
349
+ V = len(problem.variables)
350
+
351
+ log_mask = []
352
+ log_lo, log_hi = [], []
353
+ for v in problem.variables:
354
+ lo, hi = core._c15(problem.bounds[v][0]), core._c15(problem.bounds[v][1])
355
+ if v in problem.log_space_vars and lo > 0:
356
+ log_mask.append(True)
357
+ log_lo.append(math.log10(max(lo, 1e-30)))
358
+ log_hi.append(math.log10(max(hi, 1e-30)))
359
+ else:
360
+ log_mask.append(False)
361
+ log_lo.append(lo)
362
+ log_hi.append(hi)
363
+
364
+ log_mask_t = torch.tensor(log_mask, device=core.DEVICE, dtype=torch.bool)
365
+ lo_param_t = torch.tensor(log_lo, device=core.DEVICE, dtype=torch.float32)
366
+ hi_param_t = torch.tensor(log_hi, device=core.DEVICE, dtype=torch.float32)
367
+
368
+ def _param_to_orig(P):
369
+ orig = P.clone()
370
+ if log_mask_t.any(): orig[:, log_mask_t] = torch.pow(10.0, P[:, log_mask_t])
371
+ return orig
372
+
373
+ def _orig_to_param(x_val, j):
374
+ if log_mask[j] and x_val > 0: return math.log10(max(x_val, 1e-30))
375
+ return x_val
376
+
377
+ x_data_p, mask_data, target_data_p = [], [], []
378
+ for hyp in adam_hyps:
379
+ xr, mr, tr = [], [], []
380
+ active_vars = problem.get_markov_blanket(set(hyp.pinned_vars.keys()), depth=2)
381
+ for j, v in enumerate(problem.variables):
382
+ lo, hi = core._c15(problem.bounds[v][0]), core._c15(problem.bounds[v][1])
383
+ if v in hyp.pinned_vars:
384
+ p_val = _orig_to_param(core._c15(hyp.pinned_vars[v]), j)
385
+ xr.append(p_val); mr.append(0.0); tr.append(p_val)
386
+ else:
387
+ p_val = _orig_to_param(core._c15(hyp.binding.get(v, (lo+hi)/2)), j)
388
+ is_active = (v in active_vars) or (len(hyp.pinned_vars) == 0)
389
+ xr.append(p_val); mr.append(1.0 if is_active else 0.0); tr.append(0.0)
390
+ x_data_p.append(xr); mask_data.append(mr); target_data_p.append(tr)
391
+
392
+ P = torch.tensor(x_data_p, device=core.DEVICE, dtype=torch.float32, requires_grad=True)
393
+ mask = torch.tensor(mask_data, device=core.DEVICE, dtype=torch.float32)
394
+ target = torch.tensor(target_data_p, device=core.DEVICE, dtype=torch.float32)
395
+
396
+ optimizer = torch.optim.Adam([P], lr=0.01)
397
+
398
+ for step in range(steps):
399
+ optimizer.zero_grad()
400
+ step_ratio = min(1.0, step / (steps * 0.8))
401
+ X_orig = _param_to_orig(P)
402
+ ce = problem.tensor_energy(X_orig, step_ratio, is_optimizing=True)
403
+ if isinstance(ce, torch.Tensor) and (ce < core.SOLVE_THRESHOLD).all() and step_ratio == 1.0: break
404
+ ce.sum().backward()
405
+ with torch.no_grad():
406
+ P.grad.clamp_(-10.0, 10.0)
407
+ P.grad *= mask
408
+ optimizer.step()
409
+ P.data = torch.where(mask == 0.0, target, P.data)
410
+ margin = 0.1 * (1.0 - step_ratio)
411
+ lo_m = lo_param_t - (hi_param_t - lo_param_t) * margin
412
+ hi_m = hi_param_t + (hi_param_t - lo_param_t) * margin
413
+ P.data = torch.clamp(P.data, lo_m.unsqueeze(0), hi_m.unsqueeze(0))
414
+
415
+ X_orig_final = _param_to_orig(P)
416
+ final_ce = problem.tensor_energy(X_orig_final, 1.0, is_optimizing=False).view(-1)
417
+ ce_vals = final_ce.detach().cpu().numpy()
418
+ X_vals = X_orig_final.detach().cpu().numpy()
419
+
420
+ for b_idx, orig_idx in enumerate(solve_indices):
421
+ final_b = {problem.variables[j]: float(X_vals[b_idx, j]) for j in range(V)}
422
+ results[orig_idx] = (final_b, float(ce_vals[b_idx]), [], "systemic")
423
+
424
+ return [r for r in results if r is not None]
425
+
426
+ def _mprt_sample(problem: core.Problem, N: int):
427
+ var_list = problem.variables; V = len(var_list)
428
+ lo_t = torch.tensor([core._c15(problem.bounds.get(v, (-10.0, 10.0))[0]) for v in var_list], device=core.DEVICE, dtype=torch.float32)
429
+ hi_t = torch.tensor([core._c15(problem.bounds.get(v, (-10.0, 10.0))[1]) for v in var_list], device=core.DEVICE, dtype=torch.float32)
430
+ for i in range(V):
431
+ if lo_t[i] >= hi_t[i]: m = (lo_t[i]+hi_t[i])/2; lo_t[i] = m - 1e-6; hi_t[i] = m + 1e-6
432
+
433
+ rand_base = torch.rand((N, V), device=core.DEVICE)
434
+ lsv_indices = [problem.var_idx[v] for v in problem.log_space_vars if v in problem.var_idx]
435
+ for idx in lsv_indices:
436
+ lo_v, hi_v = lo_t[idx].item(), hi_t[idx].item()
437
+ if lo_v > 0 and hi_v > lo_v:
438
+ log_lo, log_hi = math.log10(max(lo_v, 1e-30)), math.log10(max(hi_v, 1e-30))
439
+ rand_base[:, idx] = torch.pow(10.0, torch.rand(N, device=core.DEVICE)*(log_hi-log_lo)+log_lo) / hi_v
440
+
441
+ X = lo_t.unsqueeze(0) + (hi_t - lo_t).unsqueeze(0) * rand_base
442
+ ce_batch = problem.tensor_energy(X, 1.0, is_optimizing=False).view(-1)
443
+ best_idx = torch.argmin(ce_batch).item()
444
+ return {v: float(X[best_idx, i].item()) for i, v in enumerate(var_list)}, ce_batch[best_idx].item()
445
+
446
  # ══════════════════════════════════════════════════════════════════════
447
  # THE DECOUPLED SEQUENCE RAY TRACER
448
  # ══════════════════════════════════════════════════════════════════════
 
584
  if child:
585
  child.baton = out_baton
586
  children.append(child)
587
+ return children