everydaytok committed on
Commit
d9d9304
·
verified ·
1 Parent(s): 78bf05b

Create practicality_axioms.py

Browse files
Files changed (1) hide show
  1. practicality_axioms.py +191 -0
practicality_axioms.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from typing import Dict, List, Tuple, Set, Optional, Any
5
+ from dataclasses import dataclass, field
6
+ from practicality_core import Problem, DEVICE, SOLVE_THRESHOLD, _c15, IV
7
+
8
@dataclass
class Hypothesis:
    """A candidate variable assignment plus the bookkeeping attached to it.

    The first eight fields are required; ``ce`` and ``is_fully_determined``
    fall back to "not evaluated yet" style defaults.
    """

    hid: str
    binding: Dict[str, float]       # variable name -> proposed value
    h_type: str
    claim: str
    derivation: List[str]
    pinned_vars: Dict[str, float]   # variable name -> value to hold fixed
    free_vars: List[str]            # variables that are still open
    confidence: float
    ce: float = math.inf            # infinite until an energy is assigned
    is_fully_determined: bool = False
20
+
21
@dataclass
class L9Certificate:
    """Residual energy together with the variables and expressions named as dominant."""

    residual_ce: float
    dominant_vars: List[str]
    dominant_exprs: List[str]
    tension_class: str = "unknown"
24
+
25
@dataclass
class Baton:
    """A binding and its energy, optionally tagged with the ray that produced it."""

    binding: Dict[str, float]
    ce: float
    ray_id: str = ""
    depth: int = 0
    l9: Optional[L9Certificate] = None
28
+
29
@dataclass
class AxiomRay:
    """An ordered sequence of axiom names with lineage metadata."""

    sequence: List[str]
    ray_id: str
    ray_type: str = "seed"
    depth: int = 0
    parent_id: Optional[str] = None
    baton: Optional[Baton] = None
    ce_prior: float = float("inf")

    @property
    def name(self) -> str:
        """Return the three-character prefixes of the sequence joined by arrows."""
        prefixes = [axiom[:3] for axiom in self.sequence]
        return "→".join(prefixes)

    def extend(self, axiom, rtype, new_prior=float('inf')):
        """Return a new ray one level deeper with ``axiom`` appended.

        The child records this ray as its parent and takes ``rtype`` and
        ``new_prior`` as its type and prior energy. Its ``baton`` is left at
        the default (``None``).
        """
        child_id = f"{self.ray_id}+{axiom[:3]}"
        child_sequence = self.sequence + [axiom]
        return AxiomRay(
            sequence=child_sequence,
            ray_id=child_id,
            ray_type=rtype,
            depth=self.depth + 1,
            parent_id=self.ray_id,
            ce_prior=new_prior,
        )
39
+
40
class Axiom:
    """Namespace of axiom identifiers; every constant's value is its own name."""

    CONTINUOUS = "CONTINUOUS"
    DISCRETE = "DISCRETE"
    QUADRATIC = "QUADRATIC"
    BILINEAR = "BILINEAR"
    METRIC = "METRIC"
    SYMMETRIC = "SYMMETRIC"
    MUTABLE = "MUTABLE"
    EXTREMAL = "EXTREMAL"
    ENTROPY = "ENTROPY"
    ATOMIC = "ATOMIC"
    PARSIMONY = "PARSIMONY"
    DUALITY = "DUALITY"
45
+
46
# Values of every non-dunder attribute of Axiom, in the sorted order dir() yields.
ALL_AXIOMS = [
    getattr(Axiom, attr_name)
    for attr_name in dir(Axiom)
    if attr_name[:2] != "__"
]
47
+
48
class UCB1BanditSeeder:
    """Bandit over axioms that ranks them by a UCB1-style score.

    Each axiom in ``ALL_AXIOMS`` has a ``{"tries", "reward"}`` record. The
    reward of a trial is the (non-negative) energy drop it produced.
    """

    def __init__(self):
        self.stats = {a: {"tries": 0, "reward": 0.0} for a in ALL_AXIOMS}
        self.total_tries = 0

    def record_reward(self, axiom, ce_before, ce_after):
        """Credit ``axiom`` with the energy reduction ``ce_before - ce_after``.

        Negative improvements count as zero. A non-finite improvement (for
        example an infinite ``ce_before`` because no baseline was ever
        evaluated) also counts as zero: storing ``inf`` would make the arm's
        average reward infinite for the rest of the run.

        Raises:
            KeyError: if ``axiom`` has no entry in ``self.stats``.
        """
        improvement = ce_before - ce_after
        reward = max(0.0, improvement) if math.isfinite(improvement) else 0.0
        entry = self.stats[axiom]
        entry["tries"] += 1
        entry["reward"] += reward
        self.total_tries += 1

    def intelligent_branch(self, ray, out_baton, remaining_axioms, branch_width):
        """Extend ``ray`` with up to ``branch_width`` of ``remaining_axioms``.

        Untried axioms get a fixed score of 999.0; tried ones get
        ``mean_reward + 0.5 * sqrt(log(total_tries + 1) / tries)``. Each score
        is multiplied by a fresh ``random.random()`` before sorting, so the
        ranking is deliberately noisy. Every child receives the same
        ``out_baton`` object and ``out_baton.ce`` as its prior.

        Returns:
            List of child rays produced by ``ray.extend(axiom, "branch", ...)``.
        """
        # Invariant across axioms, so compute it once.
        log_total = math.log(self.total_tries + 1)

        scored = []
        for ax in remaining_axioms:
            entry = self.stats[ax]
            tries = entry["tries"]
            if tries == 0:
                ucb = 999.0
            else:
                avg_reward = entry["reward"] / tries
                exploration = math.sqrt(log_total / tries)
                ucb = avg_reward + 0.5 * exploration
            scored.append((ucb, ax))

        scored.sort(key=lambda item: item[0] * random.random(), reverse=True)

        children = []
        for _, axiom in scored[:branch_width]:
            child = ray.extend(axiom, "branch", new_prior=out_baton.ce)
            if not child:
                continue
            child.baton = out_baton
            children.append(child)
        return children
77
+
78
+ def _batched_deduce_and_evaluate(problem: Problem, hyps: List[Hypothesis], steps: int=80) -> List[Tuple[Dict, float, List[str], str]]:
79
+ if not hyps: return []
80
+
81
+ skip_indices = [i for i, h in enumerate(hyps) if getattr(h, 'is_fully_determined', False) or len(h.free_vars) == 0]
82
+ solve_indices = [i for i, h in enumerate(hyps) if i not in skip_indices]
83
+ results = [None] * len(hyps)
84
+
85
+ for i in skip_indices:
86
+ hyp = hyps[i]
87
+ try:
88
+ ce = problem.scalar_energy(hyp.binding)
89
+ dom_vars = list(hyp.pinned_vars.keys())[:3]
90
+ results[i] = (hyp.binding, ce, dom_vars, "algebraic")
91
+ except: results[i] = (hyp.binding, float('inf'), [], "algebraic_error")
92
+
93
+ if not solve_indices: return [r for r in results if r is not None]
94
+
95
+ adam_hyps = [hyps[i] for i in solve_indices]
96
+ V = len(problem.variables)
97
+
98
+ log_mask = []
99
+ log_lo, log_hi = [], []
100
+ for v in problem.variables:
101
+ lo, hi = _c15(problem.bounds[v][0]), _c15(problem.bounds[v][1])
102
+ if v in problem.log_space_vars and lo > 0:
103
+ log_mask.append(True)
104
+ log_lo.append(math.log10(max(lo, 1e-30)))
105
+ log_hi.append(math.log10(max(hi, 1e-30)))
106
+ else:
107
+ log_mask.append(False)
108
+ log_lo.append(lo)
109
+ log_hi.append(hi)
110
+
111
+ log_mask_t = torch.tensor(log_mask, device=DEVICE, dtype=torch.bool)
112
+ lo_param_t = torch.tensor(log_lo, device=DEVICE, dtype=torch.float32)
113
+ hi_param_t = torch.tensor(log_hi, device=DEVICE, dtype=torch.float32)
114
+
115
+ def _param_to_orig(P):
116
+ orig = P.clone()
117
+ if log_mask_t.any(): orig[:, log_mask_t] = torch.pow(10.0, P[:, log_mask_t])
118
+ return orig
119
+
120
+ def _orig_to_param(x_val, j):
121
+ if log_mask[j] and x_val > 0: return math.log10(max(x_val, 1e-30))
122
+ return x_val
123
+
124
+ x_data_p, mask_data, target_data_p = [], [], []
125
+ for hyp in adam_hyps:
126
+ xr, mr, tr = [], [], []
127
+ active_vars = problem.get_markov_blanket(set(hyp.pinned_vars.keys()), depth=2)
128
+ for j, v in enumerate(problem.variables):
129
+ lo, hi = _c15(problem.bounds[v][0]), _c15(problem.bounds[v][1])
130
+ if v in hyp.pinned_vars:
131
+ p_val = _orig_to_param(_c15(hyp.pinned_vars[v]), j)
132
+ xr.append(p_val); mr.append(0.0); tr.append(p_val)
133
+ else:
134
+ p_val = _orig_to_param(_c15(hyp.binding.get(v, (lo+hi)/2)), j)
135
+ is_active = (v in active_vars) or (len(hyp.pinned_vars) == 0)
136
+ xr.append(p_val); mr.append(1.0 if is_active else 0.0); tr.append(0.0)
137
+ x_data_p.append(xr); mask_data.append(mr); target_data_p.append(tr)
138
+
139
+ P = torch.tensor(x_data_p, device=DEVICE, dtype=torch.float32, requires_grad=True)
140
+ mask = torch.tensor(mask_data, device=DEVICE, dtype=torch.float32)
141
+ target = torch.tensor(target_data_p, device=DEVICE, dtype=torch.float32)
142
+
143
+ optimizer = torch.optim.Adam([P], lr=0.01)
144
+
145
+ for step in range(steps):
146
+ optimizer.zero_grad()
147
+ step_ratio = min(1.0, step / (steps * 0.8))
148
+ X_orig = _param_to_orig(P)
149
+ ce = problem.tensor_energy(X_orig, step_ratio, is_optimizing=True)
150
+ if isinstance(ce, torch.Tensor) and (ce < SOLVE_THRESHOLD).all() and step_ratio == 1.0: break
151
+ ce.sum().backward()
152
+ with torch.no_grad():
153
+ P.grad.clamp_(-10.0, 10.0)
154
+ P.grad *= mask
155
+ optimizer.step()
156
+ P.data = torch.where(mask == 0.0, target, P.data)
157
+ margin = 0.1 * (1.0 - step_ratio)
158
+ lo_m = lo_param_t - (hi_param_t - lo_param_t) * margin
159
+ hi_m = hi_param_t + (hi_param_t - lo_param_t) * margin
160
+ P.data = torch.clamp(P.data, lo_m.unsqueeze(0), hi_m.unsqueeze(0))
161
+
162
+ X_orig_final = _param_to_orig(P)
163
+ final_ce = problem.tensor_energy(X_orig_final, 1.0, is_optimizing=False).view(-1)
164
+ ce_vals = final_ce.detach().cpu().numpy()
165
+ X_vals = X_orig_final.detach().cpu().numpy()
166
+
167
+ for b_idx, orig_idx in enumerate(solve_indices):
168
+ final_b = {problem.variables[j]: float(X_vals[b_idx, j]) for j in range(V)}
169
+ results[orig_idx] = (final_b, float(ce_vals[b_idx]), [], "systemic")
170
+
171
+ return [r for r in results if r is not None]
172
+
173
def _mprt_sample(problem: Problem, N: int):
    """Draw ``N`` random points inside the variable bounds and return the best.

    Variables are sampled uniformly between their bounds (``(-10, 10)`` when
    a variable has no entry in ``problem.bounds``). Variables listed in
    ``problem.log_space_vars`` whose bounds are positive and non-degenerate
    are sampled log-uniformly instead.

    Args:
        problem: Object providing ``variables``, ``bounds``, ``var_idx``,
            ``log_space_vars`` and ``tensor_energy``.
        N: Number of candidate points to draw.

    Returns:
        ``(binding, energy)`` for the sampled point with the lowest energy,
        where ``binding`` maps each variable name to a float.
    """
    var_list = problem.variables
    V = len(var_list)
    lo_t = torch.tensor([_c15(problem.bounds.get(v, (-10.0, 10.0))[0]) for v in var_list], device=DEVICE, dtype=torch.float32)
    hi_t = torch.tensor([_c15(problem.bounds.get(v, (-10.0, 10.0))[1]) for v in var_list], device=DEVICE, dtype=torch.float32)

    # Widen empty or inverted intervals to a tiny window around their midpoint.
    degenerate = lo_t >= hi_t
    if degenerate.any():
        mid = (lo_t + hi_t) / 2
        lo_t = torch.where(degenerate, mid - 1e-6, lo_t)
        hi_t = torch.where(degenerate, mid + 1e-6, hi_t)

    rand_base = torch.rand((N, V), device=DEVICE)
    X = lo_t.unsqueeze(0) + (hi_t - lo_t).unsqueeze(0) * rand_base

    lsv_indices = [problem.var_idx[v] for v in problem.log_space_vars if v in problem.var_idx]
    for idx in lsv_indices:
        lo_v, hi_v = lo_t[idx].item(), hi_t[idx].item()
        if lo_v > 0 and hi_v > lo_v:
            log_lo, log_hi = math.log10(max(lo_v, 1e-30)), math.log10(max(hi_v, 1e-30))
            exponents = torch.rand(N, device=DEVICE) * (log_hi - log_lo) + log_lo
            # Write the log-uniform draw straight into the column. Dividing it
            # by hi_v and feeding it through the linear map above produced
            # s * (1 - lo/hi) + lo, which never reaches lo and is not
            # log-uniform. The clamp only guards against float rounding.
            X[:, idx] = torch.pow(10.0, exponents).clamp(lo_v, hi_v)

    ce_batch = problem.tensor_energy(X, 1.0, is_optimizing=False).view(-1)
    best_idx = torch.argmin(ce_batch).item()
    best_binding = {v: float(X[best_idx, i].item()) for i, v in enumerate(var_list)}
    return best_binding, ce_batch[best_idx].item()