everydaytok commited on
Commit
722604c
Β·
verified Β·
1 Parent(s): 3740d38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +543 -265
app.py CHANGED
@@ -1,5 +1,31 @@
1
- import numpy as np
2
- import time, collections, threading, json, random, math, os, pathlib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from fastapi import FastAPI
4
  from fastapi.responses import HTMLResponse, FileResponse
5
  from fastapi.middleware.cors import CORSMiddleware
@@ -7,299 +33,551 @@ from fastapi.middleware.cors import CORSMiddleware
7
  app = FastAPI()
8
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
9
 
10
- DIM = 8
11
- LR = 0.02 # Slower, highly stable learning rate
12
- EWC_LAMBDA = 0.8
13
- FISHER_DECAY = 0.95
14
- DAMPING = 0.6 # Increased damping for visual stability
15
- DT = 0.2
16
-
17
- # --- AUTO DATA GENERATOR ---
18
- def ensure_data():
19
- out = pathlib.Path('data')
20
- out.mkdir(exist_ok=True)
21
- if os.path.exists('data/train.json') and os.path.exists('data/test.json'):
22
- return
23
-
24
- print("Generating N-Dim Dataset...")
25
- rng = np.random.default_rng(42)
26
- data = []
27
- for _ in range(1000):
28
- a, b = rng.uniform(0.1, 0.9, DIM), rng.uniform(0.1, 0.9, DIM)
29
- data.append({'a': a.tolist(), 'b': b.tolist(), 'c': (0.7 * a + 0.3 * b).tolist(), 'type': 'blend'})
30
- data.append({'a': a.tolist(), 'b': b.tolist(), 'c': (0.5 + 0.4 * (a - b)).tolist(), 'type': 'diff'})
31
- data.append({'a': a.tolist(), 'b': b.tolist(), 'c': (0.5 * np.roll(a, 1) + 0.5 * np.roll(b, -1)).tolist(), 'type': 'route'})
32
-
33
- random.shuffle(data)
34
- split = int(len(data) * 0.9)
35
- with open('data/train.json', 'w') as f: json.dump(data[:split], f)
36
- with open('data/test.json', 'w') as f: json.dump(data[split:], f)
37
-
38
- ensure_data()
39
- # ---------------------------
40
-
41
- class PristineMesh:
42
- def __init__(self, n_dim=DIM):
43
- self.n_dim = n_dim
44
- self.nodes = {}
45
- self.springs = {}
46
- self.fisher = {}
47
- self.anchor_k = {}
48
- self._build_lattice()
49
-
50
- def _build_lattice(self):
51
- self.row_widths = [
52
- self.n_dim, self.n_dim+1, self.n_dim+2, self.n_dim+1,
53
- self.n_dim,
54
- self.n_dim+1, self.n_dim+2, self.n_dim+1, self.n_dim
55
- ]
56
- y_spacing = 0.866
57
-
58
- for r, w in enumerate(self.row_widths):
59
- y = -r * y_spacing
60
- x_offset = -(w - 1) / 2.0
61
-
62
- kind = 'H'
63
- if r == 0: kind = 'A'
64
- elif r == len(self.row_widths)-1: kind = 'B'
65
- elif r == len(self.row_widths)//2: kind = 'C'
66
-
67
- for c in range(w):
68
- nid = f"{kind}_r{r}_c{c}"
69
  self.nodes[nid] = {
70
- 'x': 0.0, 'vel': 0.0, 'kind': kind, 'row': r, 'col': c,
71
- 'pos': (x_offset + c, y), 'anchored': kind in ['A', 'B']
 
 
72
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- node_ids = list(self.nodes.keys())
75
- for i in range(len(node_ids)):
76
- for j in range(i + 1, len(node_ids)):
77
- n1, n2 = node_ids[i], node_ids[j]
78
- p1, p2 = self.nodes[n1]['pos'], self.nodes[n2]['pos']
79
- if math.hypot(p1[0]-p2[0], p1[1]-p2[1]) < 1.05:
80
- key = tuple(sorted([n1, n2]))
81
- self.springs[key] = random.uniform(0.1, 0.4)
82
- self.fisher[key] = 0.0
83
- self.anchor_k[key] = self.springs[key]
84
-
85
- self.c_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'C'], key=lambda k: self.nodes[k]['col'])
86
- self.a_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'A'], key=lambda k: self.nodes[k]['col'])
87
- self.b_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'B'], key=lambda k: self.nodes[k]['col'])
88
-
89
- def set_inputs(self, a_vec, b_vec):
90
- for i, nid in enumerate(self.a_nodes): self.nodes[nid]['x'] = a_vec[i]
91
- for i, nid in enumerate(self.b_nodes): self.nodes[nid]['x'] = b_vec[i]
92
- for nid, data in self.nodes.items():
93
- if data['kind'] not in ['A', 'B']:
94
- data['x'] = 0.0
95
- data['vel'] = 0.0
96
-
97
- def settle(self, steps=30):
98
- for _ in range(steps):
99
- forces = {n: 0.0 for n in self.nodes}
100
- for (u, v), K in self.springs.items():
101
  f = K * (self.nodes[v]['x'] - self.nodes[u]['x'])
102
- forces[u] += f
103
- forces[v] -= f
104
-
105
- for nid, data in self.nodes.items():
106
- if not data['anchored']:
107
- f = forces[nid] - (0.05 * data['x']) # Weak grounding
108
- data['vel'] = data['vel'] * DAMPING + f * DT
109
- data['x'] += data['vel'] * DT
110
-
111
- # MAGICAL FIX: Bounding prevents mathematical explosion
112
- data['x'] = max(-2.0, min(2.0, data['x']))
113
-
114
- def get_predictions(self):
115
- return [self.nodes[n]['x'] for n in self.c_nodes]
116
-
117
- def lms_update(self, target_vec, mode='train'):
118
- errors = {n: 0.0 for n in self.nodes}
119
- for i, nid in enumerate(self.c_nodes):
120
- errors[nid] = self.nodes[nid]['x'] - target_vec[i]
121
-
122
- # Deep Error Diffusion
123
- for _ in range(5):
124
- next_err = dict(errors)
125
- for (u, v), K in self.springs.items():
126
- weight = min(abs(K) * 0.1, 0.4)
127
- next_err[u] += weight * errors[v]
128
- next_err[v] += weight * errors[u]
129
- for n in errors: next_err[n] *= 0.85 # Decay to prevent error explosion
130
- errors = next_err
131
-
132
- for key in self.springs:
133
- u, v = key
134
- # True gradient formulation
135
- err_gradient = (errors[u] - errors[v]) * (self.nodes[u]['x'] - self.nodes[v]['x'])
136
-
137
- if mode == 'train':
138
- # FIXED: MUST BE += TO MINIMIZE ERROR!
139
- self.springs[key] += LR * err_gradient
140
- self.fisher[key] = FISHER_DECAY * self.fisher[key] + (1 - FISHER_DECAY) * (err_gradient ** 2)
141
- elif mode == 'infer':
142
- penalty = EWC_LAMBDA * self.fisher[key] * (self.springs[key] - self.anchor_k[key])
143
- self.springs[key] += (LR * 0.5 * err_gradient) - penalty
144
-
145
- # Stable K bounds
146
- self.springs[key] = max(-1.0, min(3.0, self.springs[key]))
147
-
148
- def save_anchors(self):
149
- self.anchor_k = dict(self.springs)
 
 
150
 
151
  class Engine:
152
  def __init__(self):
153
- self.mesh = PristineMesh()
154
- self.mode = 'idle'
155
- self.running = False
156
- self.queue = collections.deque()
157
- self.logs = []
158
- self.iter = 0
159
- self.train_data = []
160
- self.test_data = []
161
- self.error_hist = []
162
- self.current_err = 0.0
163
- self.current_type = 'β€”'
164
- self.test_results = []
165
-
166
- def add_log(self, msg):
167
- self.logs.insert(0, f"[{self.iter:05d}] {msg}")
168
- if len(self.logs) > 40: self.logs.pop()
169
-
170
- def run_step(self):
171
- if not self.queue:
172
- self.running = False
173
- self.add_log("Queue empty. Standing by.")
174
- return
175
-
176
- sample = self.queue.popleft()
177
- self.current_type = sample['type']
178
-
179
- self.mesh.set_inputs(sample['a'], sample['b'])
180
- self.mesh.settle(steps=25)
181
- preds = self.mesh.get_predictions()
182
-
183
- if sample['type'] != 'manual':
184
- # Clean Mean Absolute Error
185
- err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
186
- if math.isnan(err): err = 1.0 # Safety catch
187
-
188
- self.current_err = err
189
- self.error_hist.append(err)
190
- if len(self.error_hist) > 100: self.error_hist.pop(0)
191
-
192
- if self.mode == 'infer':
193
- self.test_results.append({'type': self.current_type, 'err': err})
194
-
195
- # Apply correct LMS + EWC
196
- self.mesh.lms_update(sample['c'], mode=self.mode)
197
  else:
198
- self.current_err = 0.0
199
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  self.iter += 1
201
- if self.iter % 5 == 0 or sample['type'] == 'manual':
202
- self.add_log(f"[{self.current_type}] err: {self.current_err:.4f}")
203
 
204
- def train_offline(self, epochs):
 
 
 
 
205
  self.running = False
206
- self.mode = 'train'
207
- self.add_log(f"⚑ Offline Training: {epochs} epochs...")
208
- for ep in range(epochs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  random.shuffle(self.train_data)
210
- errs = []
211
- for sample in self.train_data:
212
- self.mesh.set_inputs(sample['a'], sample['b'])
213
- self.mesh.settle(20)
214
- self.mesh.lms_update(sample['c'], mode='train')
215
- preds = self.mesh.get_predictions()
216
- e = np.mean(np.abs(np.array(preds) - np.array(sample['c'])))
217
- if not math.isnan(e): errs.append(e)
218
-
219
- avg_e = np.mean(errs) if errs else 0.0
220
- self.add_log(f"Ep {ep+1} | Avg Err: {avg_e:.4f}")
221
- print(f"Epoch {ep+1} | Avg Err: {avg_e:.4f}")
222
-
223
- self.mesh.save_anchors()
224
- self.add_log("βœ“ Training Complete. EWC Anchors saved.")
225
  self.mode = 'idle'
226
 
227
- def get_accuracy_summary(self):
228
- acc = {}
229
- for r in self.test_results:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  t = r['type']
231
- if t not in acc: acc[t] = {'n': 0, 'sum_e': 0.0}
232
- acc[t]['n'] += 1
233
- acc[t]['sum_e'] += r['err']
234
- return {t: {'n': v['n'], 'avg_err': round(v['sum_e']/v['n'], 4)} for t, v in acc.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  engine = Engine()
237
- try:
238
- with open('data/train.json') as f: engine.train_data = json.load(f)
239
- with open('data/test.json') as f: engine.test_data = json.load(f)
240
- engine.add_log("Data loaded successfully.")
241
- except Exception as e:
242
- engine.add_log(f"Error loading data: {str(e)}")
243
-
244
- def loop():
245
  while True:
246
- if engine.running: engine.run_step()
247
- time.sleep(0.06)
248
- threading.Thread(target=loop, daemon=True).start()
 
 
249
 
250
  @app.get("/", response_class=HTMLResponse)
251
  async def ui(): return FileResponse("index.html")
252
 
253
  @app.get("/state")
254
- async def state():
255
- return {
256
- 'nodes': engine.mesh.nodes,
257
- 'springs': {f"{u}|{v}": k for (u, v), k in engine.mesh.springs.items()},
258
- 'error': engine.current_err,
259
- 'hist': engine.error_hist,
260
- 'mode': engine.mode,
261
- 'running': engine.running,
262
- 'logs': engine.logs,
263
- 'current_type': engine.current_type,
264
- 'queue_size': len(engine.queue),
265
- 'type_acc': engine.get_accuracy_summary(),
266
- 'dim': DIM
267
- }
268
-
269
- @app.post("/train")
270
- async def train(data: dict):
271
- ep = int(data.get('epochs', 5))
272
  threading.Thread(target=engine.train_offline, args=(ep,), daemon=True).start()
273
  return {"ok": True}
274
 
 
 
 
 
275
  @app.post("/infer")
276
- async def infer(data: dict):
277
- n = int(data.get('n', 200))
278
- engine.mode = 'infer'
279
- engine.test_results = []
280
- engine.queue.clear()
281
- engine.queue.extend(engine.test_data[:n])
282
- engine.running = True
283
- return {"ok": True}
284
 
285
- @app.post("/manual")
286
- async def manual(data: dict):
287
- try:
288
- a_vec = [float(x.strip()) for x in data.get('a', '').split(',')]
289
- b_vec = [float(x.strip()) for x in data.get('b', '').split(',')]
290
- if len(a_vec) != DIM or len(b_vec) != DIM:
291
- return {"ok": False, "error": f"Vectors must be exactly length {DIM}"}
292
-
293
- engine.mode = 'manual'
294
- engine.queue.clear()
295
- engine.queue.append({'a': a_vec, 'b': b_vec, 'c': [0]*DIM, 'type': 'manual'})
296
- engine.running = True
297
- return {"ok": True}
298
- except Exception as e:
299
- return {"ok": False, "error": str(e)}
300
 
301
  @app.post("/halt")
302
- async def halt(): engine.running = False; return {"ok": True}
 
 
 
 
 
 
303
 
304
  if __name__ == "__main__":
305
  import uvicorn
 
1
+ """
2
+ main.py v5 β€” Scalar Triangulated Hourglass Mesh
3
+
4
+ TOPOLOGY (for n input dimensions, 9 rows):
5
+
6
+ A0 … An-1 row 0 width n ← anchored inputs
7
+ Β· Β· Β· Β· row 1 width n+1
8
+ Β· Β· Β· Β· Β· row 2 width n+2 ← widest upper bulge
9
+ Β· Β· Β· Β· row 3 width n+1
10
+ C0 … Cn-1 row 4 width n ← output waist
11
+ Β· Β· Β· Β· row 5 width n+1
12
+ Β· Β· Β· Β· Β· row 6 width n+2 ← widest lower bulge
13
+ Β· Β· Β· Β· row 7 width n+1
14
+ B0 … Bn-1 row 8 width n ← anchored inputs
15
+
16
+ Between any two adjacent rows the width changes by exactly Β±1,
17
+ producing a perfectly triangulated grid (no irregular fans).
18
+
19
+ LEARNING (LMS):
20
+ Training : C anchored β†’ settle hidden β†’ backprop error β†’ update inter-row K
21
+ Inference: C free β†’ settle to equilibrium β†’ optionally update K via EWC
22
+
23
+ INFERENCE K-UPDATE:
24
+ On each new problem the mesh adapts its springs in real-time with a smaller
25
+ learning rate, protected by the EWC Fisher diagonal so past knowledge is not lost.
26
+ """
27
+
28
+ import numpy as np, time, collections, threading, json, random
29
  from fastapi import FastAPI
30
  from fastapi.responses import HTMLResponse, FileResponse
31
  from fastapi.middleware.cors import CORSMiddleware
 
33
  app = FastAPI()
34
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
35
 
36
+ # ── CONSTANTS ─────────────────────────────────────────────────────────────────
37
+ DT = 0.08
38
+ DAMP = 0.62
39
+ GND = 0.012 # soft restore toward 0.5 (keeps display bounded)
40
+ SETTLE = 80 # physics steps per CHL/display phase
41
+ CONV = 0.025 # |error| < this β†’ converged
42
+ MAXS = 500 # hard cap per sample
43
+ LR = 0.018 # training learning rate
44
+ LR_I = 0.004 # inference K update rate (smaller β†’ conservative)
45
+ KCLIP = 12.0
46
+ EWC_L = 0.8 # EWC penalty strength
47
+ FD = 0.95 # Fisher EMA decay
48
+ MICRO = 6 # display physics steps per server tick
49
+
50
+
51
+ # ── MESH ──────────────────────────────────────────────────────────────────────
52
+
53
+ def _widths(n):
54
+ return [n, n+1, n+2, n+1, n, n+1, n+2, n+1, n]
55
+
56
+ def _xpos(w):
57
+ """Half-integer node positions for a row of width w.
58
+ Produces x ∈ {-(w-1)/2, ..., (w-1)/2} with step 1.
59
+ Adjacent rows interleave perfectly β†’ clean triangulation."""
60
+ return [(2*i - (w-1)) / 2.0 for i in range(w)]
61
+
62
+
63
+ class Mesh:
64
+ def __init__(self, n=1):
65
+ self.n = n
66
+ self._build()
67
+
68
+ # ── TOPOLOGY ──────────────────────────────────────────────────────────────
69
+
70
+ def _build(self):
71
+ n = self.n
72
+ W = _widths(n)
73
+ NR = len(W) # 9
74
+ NK = 4 # neck row = C nodes
75
+
76
+ self.W = W; self.NR = NR; self.NK = NK
77
+
78
+ self.nodes = {} # id β†’ attrs
79
+ self.layers = [] # list of rows; each row = list of node ids
80
+
81
+ for ri, w in enumerate(W):
82
+ kind = 'A' if ri==0 else 'B' if ri==NR-1 else 'C' if ri==NK else 'H'
83
+ xs = _xpos(w)
84
+ y = 1.0 - 2.0*ri/(NR-1)
85
+ row = []
86
+ for ci in range(w):
87
+ nid = f"{ri}_{ci}"
 
 
 
 
 
 
 
88
  self.nodes[nid] = {
89
+ 'x':0.5, 'vel':0.0,
90
+ 'anchored': kind in ('A','B'),
91
+ 'row':ri, 'col':ci, 'kind':kind,
92
+ 'px':float(xs[ci]), 'py':float(y),
93
  }
94
+ row.append(nid)
95
+ self.layers.append(row)
96
+
97
+ self.A = self.layers[0]
98
+ self.C = self.layers[NK]
99
+ self.B = self.layers[-1]
100
+
101
+ # Springs ─────────────────────────────────────────────────────────────
102
+ self.K = {} # (u,v) β†’ float (u<v lex)
103
+ self.adj = {nid:[] for nid in self.nodes}
104
+ self.vert = set() # inter-row spring keys (learned)
105
+ self.horiz = set() # same-row spring keys (lateral flow)
106
+
107
+ for ri in range(NR):
108
+ row = self.layers[ri]
109
+ # Horizontal
110
+ for ci in range(len(row)-1):
111
+ k = self._add(row[ci], row[ci+1])
112
+ self.horiz.add(k)
113
+ # Inter-row (width changes by exactly Β±1)
114
+ if ri < NR-1:
115
+ up, dn = self.layers[ri], self.layers[ri+1]
116
+ wu, wd = len(up), len(dn)
117
+ if wd == wu+1: # expanding: up[i]β†’dn[i], up[i]β†’dn[i+1]
118
+ for i in range(wu):
119
+ self.vert.add(self._add(up[i], dn[i]))
120
+ self.vert.add(self._add(up[i], dn[i+1]))
121
+ elif wd == wu-1: # contracting: up[i]β†’dn[i], up[i+1]β†’dn[i]
122
+ for i in range(wd):
123
+ self.vert.add(self._add(up[i], dn[i]))
124
+ self.vert.add(self._add(up[i+1], dn[i]))
125
+ else: # same width (unused in our topology)
126
+ for i in range(min(wu,wd)):
127
+ self.vert.add(self._add(up[i], dn[i]))
128
+
129
+ # Init springs to small positive values
130
+ for k in self.K: self.K[k] = random.uniform(0.20, 0.50)
131
+
132
+ # EWC state
133
+ self.fisher = {k: 0.0 for k in self.K}
134
+ self.K0 = dict(self.K) # anchor for EWC
135
+
136
+ # Triangles (precomputed for visualisation)
137
+ self.tris = self._find_tris()
138
+
139
+ def _ek(self, a, b): return (a,b) if a<b else (b,a)
140
+
141
+ def _add(self, a, b):
142
+ k = self._ek(a,b)
143
+ if k not in self.K:
144
+ self.K[k] = 0.35
145
+ self.adj[a].append(b)
146
+ self.adj[b].append(a)
147
+ return k
148
+
149
+ def _find_tris(self):
150
+ as_ = {n: set(ns) for n, ns in self.adj.items()}
151
+ seen, tris = set(), []
152
+ for u, v in self.K:
153
+ for w in as_[u] & as_[v]:
154
+ key = tuple(sorted([u,v,w]))
155
+ if key not in seen:
156
+ tris.append(key); seen.add(key)
157
+ return tris
158
+
159
+ # ── FORWARD PASS (layer-by-layer, LMS backbone) ───────────────────────────
160
+
161
+ def _fwd(self, a_vals, b_vals):
162
+ """
163
+ Signed weighted-average forward pass.
164
+
165
+ x[v] = Ξ£ K_uv * x[u] / (Ξ£|K_uv| + Ξ΅)
166
+
167
+ Uses signed K so negative springs CAN pull output away from a neighbor
168
+ (needed for the 'diff' dataset). Output clamped to [-0.5, 1.5] for safety.
169
+ """
170
+ x = {}
171
+ for i, nid in enumerate(self.A): x[nid] = float(a_vals[min(i, len(a_vals)-1)])
172
+ for i, nid in enumerate(self.B): x[nid] = float(b_vals[min(i, len(b_vals)-1)])
173
+
174
+ def _agg(nid, upstream_rows):
175
+ nb = [n for n in self.adj[nid]
176
+ if self.nodes[n]['row'] in upstream_rows
177
+ and self._ek(nid,n) in self.vert]
178
+ if not nb: return 0.5
179
+ ws = [self.K[self._ek(nid,n)] for n in nb]
180
+ wab = sum(abs(w) for w in ws) + 1e-8
181
+ val = sum(w*x.get(n, 0.5) for w,n in zip(ws,nb)) / wab
182
+ return max(-0.5, min(1.5, val))
183
+
184
+ # Upper: rows 1β†’2β†’3β†’4(C)
185
+ for ri in range(1, self.NK+1):
186
+ for nid in self.layers[ri]:
187
+ if self.nodes[nid]['kind'] in ('A','B'): continue
188
+ x[nid] = _agg(nid, set(range(ri)))
189
+
190
+ # Lower: rows 7β†’6β†’5, then contributes into C (row 4)
191
+ for ri in range(self.NR-2, self.NK, -1):
192
+ for nid in self.layers[ri]:
193
+ if self.nodes[nid]['kind'] in ('A','B'): continue
194
+ x[nid] = _agg(nid, set(range(ri+1, self.NR)))
195
+
196
+ # C aggregates from row NK-1 (upper) AND row NK+1 (lower)
197
+ for nid in self.C:
198
+ nb = [n for n in self.adj[nid] if self._ek(nid,n) in self.vert]
199
+ ws = [self.K[self._ek(nid,n)] for n in nb]
200
+ wab = sum(abs(w) for w in ws) + 1e-8
201
+ x[nid] = max(-0.5, min(1.5,
202
+ sum(w*x.get(n,0.5) for w,n in zip(ws,nb)) / wab))
203
+
204
+ return x
205
+
206
+ # ── LMS UPDATE ────────────────────────────────────────────────────────────
207
+
208
+ def lms_update(self, a_vals, b_vals, c_target, ewc=False):
209
+ """
210
+ 1. Run forward pass.
211
+ 2. Compute error at C.
212
+ 3. Backprop delta signal through vertical springs only (horizontal = lateral flow, not learned here).
213
+ 4. Update each inter-row spring via Widrow-Hoff rule with optimal step size.
214
+ 5. Accumulate Fisher diagonal for EWC.
215
+ """
216
+ x = self._fwd(a_vals, b_vals)
217
+
218
+ ct = [float(c_target[min(i, len(c_target)-1)]) for i in range(self.n)]
219
+ errs = [x[nid] - ct[i] for i, nid in enumerate(self.C)]
220
+ total_e = float(np.sqrt(sum(e**2 for e in errs)))
221
+
222
+ # ── Backprop deltas ───────────────────────────────────────────────────
223
+ delta = {nid: 0.0 for nid in self.nodes}
224
+ for i, nid in enumerate(self.C): delta[nid] = errs[i]
225
+
226
+ def _prop_up(ri):
227
+ """Propagate delta from row ri+1 back into row ri."""
228
+ for nid in self.layers[ri]:
229
+ dn_nb = [n for n in self.adj[nid]
230
+ if self.nodes[n]['row'] == ri+1
231
+ and self._ek(nid,n) in self.vert]
232
+ for nb in dn_nb:
233
+ # How much does K(nid,nb) contribute to nb's total weight?
234
+ up_of_nb = [n2 for n2 in self.adj[nb]
235
+ if self.nodes[n2]['row'] < self.nodes[nb]['row']
236
+ and self._ek(nb,n2) in self.vert]
237
+ w_self = abs(self.K[self._ek(nid,nb)])
238
+ w_sum = sum(abs(self.K[self._ek(nb,n2)]) for n2 in up_of_nb) + 1e-8
239
+ delta[nid] += delta[nb] * w_self / w_sum
240
+
241
+ def _prop_dn(ri):
242
+ """Propagate delta from row ri-1 back into row ri (lower half)."""
243
+ for nid in self.layers[ri]:
244
+ up_nb = [n for n in self.adj[nid]
245
+ if self.nodes[n]['row'] == ri-1
246
+ and self._ek(nid,n) in self.vert]
247
+ for nb in up_nb:
248
+ dn_of_nb = [n2 for n2 in self.adj[nb]
249
+ if self.nodes[n2]['row'] > self.nodes[nb]['row']
250
+ and self._ek(nb,n2) in self.vert]
251
+ w_self = abs(self.K[self._ek(nid,nb)])
252
+ w_sum = sum(abs(self.K[self._ek(nb,n2)]) for n2 in dn_of_nb) + 1e-8
253
+ delta[nid] += delta[nb] * w_self / w_sum
254
+
255
+ for ri in range(self.NK-1, -1, -1): _prop_up(ri)
256
+ for ri in range(self.NK+1, self.NR): _prop_dn(ri)
257
+
258
+ # ── Widrow-Hoff update on inter-row springs ───────────────────────────
259
+ eps = 1e-8
260
+ for (u,v) in self.vert:
261
+ ru, rv = self.nodes[u]['row'], self.nodes[v]['row']
262
+ up_, dn_ = (u,v) if ru<rv else (v,u)
263
+ x_up = x.get(up_, 0.5)
264
+ d_dn = delta[dn_]
265
+ grad = d_dn * x_up
266
+ lr = LR_I / (1.0 + EWC_L * self.fisher[(u,v)]) if ewc else LR
267
+ new_k = self.K[(u,v)] - lr * grad / (x_up**2 + eps)
268
+ self.K[(u,v)] = max(-KCLIP, min(KCLIP, new_k))
269
+ # Fisher
270
+ g2 = (grad / (x_up**2 + eps))**2
271
+ self.fisher[(u,v)] = FD*self.fisher[(u,v)] + (1-FD)*g2
272
+
273
+ return total_e, [round(x[nid], 4) for nid in self.C]
274
 
275
+ # ── HOOKE DISPLAY PHYSICS ─────────────────────────────────────────────────
276
+
277
+ def set_inputs(self, a_vals, b_vals, c_tgt=None, anchor_c=False):
278
+ for i, nid in enumerate(self.A):
279
+ self.nodes[nid]['x'] = float(a_vals[min(i, len(a_vals)-1)])
280
+ for i, nid in enumerate(self.B):
281
+ self.nodes[nid]['x'] = float(b_vals[min(i, len(b_vals)-1)])
282
+ fixed = set(self.A) | set(self.B) | set(self.C)
283
+ for nid in self.nodes:
284
+ if nid not in fixed:
285
+ self.nodes[nid]['x'] = 0.5; self.nodes[nid]['vel'] = 0.0
286
+ for i, nid in enumerate(self.C):
287
+ self.nodes[nid]['vel'] = 0.0
288
+ if anchor_c and c_tgt is not None:
289
+ self.nodes[nid]['x'] = float(c_tgt[min(i, len(c_tgt)-1)])
290
+ self.nodes[nid]['anchored'] = True
291
+ else:
292
+ self.nodes[nid]['anchored'] = False
293
+ self.nodes[nid]['x'] = 0.5
294
+
295
+ def phys_step(self, ns=MICRO, anchor_c=True):
296
+ cs = set(self.C)
297
+ for _ in range(ns):
298
+ F = {nid: 0.0 for nid in self.nodes}
299
+ for (u,v), K in self.K.items():
 
 
300
  f = K * (self.nodes[v]['x'] - self.nodes[u]['x'])
301
+ F[u] += f; F[v] -= f
302
+ for nid, nd in self.nodes.items():
303
+ if nd['anchored'] or (nid in cs and anchor_c): continue
304
+ nd['vel'] = nd['vel']*DAMP + (F[nid] - GND*(nd['x']-0.5))*DT
305
+ nd['x'] += nd['vel']*DT
306
+
307
+ def c_phys(self): return [self.nodes[nid]['x'] for nid in self.C]
308
+
309
+ # ── STATE ─────────────────────────────────────────────────────────────────
310
+
311
+ def node_state(self):
312
+ return {nid: {
313
+ 'x': round(nd['x'], 4),
314
+ 'vel': round(abs(nd['vel']), 4),
315
+ 'anchored': nd['anchored'],
316
+ 'row': nd['row'],
317
+ 'kind': nd['kind'],
318
+ 'px': nd['px'],
319
+ 'py': nd['py'],
320
+ } for nid, nd in self.nodes.items()}
321
+
322
+ def spring_state(self):
323
+ # Return keyed by "ri_ci|ri_ci" for display
324
+ out = {}
325
+ for (u,v), K in self.K.items():
326
+ out[f"{u}|{v}"] = {
327
+ 'k': round(K, 4),
328
+ 'fish': round(self.fisher[(u,v)], 5),
329
+ 'vert': (u,v) in self.vert,
330
+ 'u_px': self.nodes[u]['px'], 'u_py': self.nodes[u]['py'],
331
+ 'v_px': self.nodes[v]['px'], 'v_py': self.nodes[v]['py'],
332
+ }
333
+ return out
334
+
335
+ def tri_state(self):
336
+ out = []
337
+ for (u,v,w) in self.tris:
338
+ nu, nv, nw = self.nodes[u], self.nodes[v], self.nodes[w]
339
+ ku = self.K.get(self._ek(u,v), 0)
340
+ kv = self.K.get(self._ek(v,w), 0)
341
+ kw = self.K.get(self._ek(u,w), 0)
342
+ out.append({
343
+ 'pos': [[nu['px'],nu['py']], [nv['px'],nv['py']], [nw['px'],nw['py']]],
344
+ 'avg_k': round((ku+kv+kw)/3, 4),
345
+ 'stress': round(abs(nu['x']-nv['x'])+abs(nv['x']-nw['x'])+abs(nu['x']-nw['x']), 4),
346
+ })
347
+ return out
348
+
349
+
350
+ # ── ENGINE ────────────────────────────────────────────────────────────────────
351
 
352
  class Engine:
353
  def __init__(self):
354
+ self.mesh = Mesh(n=1)
355
+ self.mode = 'idle'
356
+ self.running = False
357
+ self.q = collections.deque()
358
+ self.logs = []
359
+ self.iter = 0
360
+ self.step_cnt = 0
361
+ self.error = 0.0
362
+ self.c_pred = [0.5]
363
+ self.c_tgt = None
364
+ self.cur_type = 'β€”'
365
+ self.history = []
366
+ self.train_data = []
367
+ self.test_data = []
368
+ self.test_res = []
369
+
370
+ def log(self, msg):
371
+ self.logs.insert(0, f"[{self.iter:06d}] {msg}")
372
+ if len(self.logs) > 60: self.logs.pop()
373
+
374
+ # ── DATA ──────────────────────────────────────────────────────────────────
375
+
376
+ def load_data(self, tr='data/train.json', te='data/test.json'):
377
+ with open(tr) as f: self.train_data = json.load(f)
378
+ with open(te) as f: self.test_data = json.load(f)
379
+ # Detect n from data
380
+ n = len(self.train_data[0]['A'])
381
+ if n != self.mesh.n:
382
+ self.mesh = Mesh(n=n)
383
+ self.log(f"Mesh rebuilt for n={n}")
384
+ ood = sum(1 for d in self.test_data if d['type']=='heavy_b')
385
+ self.log(f"Data: {len(self.train_data)} train | {len(self.test_data)} test ({ood} OOD) | n={n}")
386
+
387
+ # ── VISUAL STEP ───────────────────────────────────────────────────────────
388
+
389
+ def phys_tick(self):
390
+ anchor = (self.mode == 'training')
391
+ self.mesh.phys_step(MICRO, anchor_c=anchor)
392
+
393
+ if self.c_tgt is not None:
394
+ c_p = self.mesh.c_phys()
395
+ errs = [c_p[i] - float(self.c_tgt[i]) for i in range(self.mesh.n)]
396
+ self.error = float(np.sqrt(sum(e**2 for e in errs)))
397
+ self.c_pred = [round(v,4) for v in c_p]
398
  else:
399
+ self.error = 0.0
400
+
401
+ self.history.append(round(abs(self.error), 5))
402
+ if len(self.history) > 200: self.history.pop(0)
403
+
404
+ self.step_cnt += 1
405
+ conv = abs(self.error) < CONV
406
+ timeout = self.step_cnt >= MAXS
407
+
408
+ if conv or timeout:
409
+ tag = 'βœ“' if conv else '⚠'
410
+ ood = self.cur_type == 'heavy_b'
411
+ self.log(f"{tag}{'[OOD]' if ood else ''} [{self.cur_type}] err={self.error:.4f}")
412
+ if self.mode == 'inference' and self.c_tgt is not None:
413
+ self.test_res.append({
414
+ 'type': self.cur_type, 'abs': round(abs(self.error),5),
415
+ 'ok': conv, 'steps': self.step_cnt, 'ood': ood,
416
+ })
417
+ return self._next()
418
+
419
+ # Inference-time K update (EWC protected)
420
+ if self.mode == 'inference' and self.c_tgt is not None:
421
+ self.mesh.lms_update(
422
+ [self.mesh.nodes[nid]['x'] for nid in self.mesh.A],
423
+ [self.mesh.nodes[nid]['x'] for nid in self.mesh.B],
424
+ self.c_tgt, ewc=True
425
+ )
426
+
427
  self.iter += 1
428
+ return True
 
429
 
430
+ def _next(self):
431
+ if self.q:
432
+ p = self.q.popleft()
433
+ self._load(p)
434
+ return True
435
  self.running = False
436
+ self.log("β—Ό Queue done.")
437
+ return False
438
+
439
+ def _load(self, p):
440
+ self.cur_type = p['type']
441
+ self.c_tgt = p.get('C')
442
+ self.step_cnt = 0
443
+ anchor = (self.mode == 'training')
444
+ self.mesh.set_inputs(p['A'], p['B'], p.get('C'), anchor_c=anchor)
445
+
446
+ def _fill(self, data):
447
+ self.q.clear()
448
+ for d in data: self.q.append(d)
449
+ if self.q: self._load(self.q.popleft())
450
+
451
+ # ── OFFLINE TRAINING (fast, no display) ───────────────────────────────────
452
+
453
+ def train_offline(self, epochs=5):
454
+ self.running = False; self.mode = 'training'
455
+ self.log(f"⚑ Offline LMS: {epochs} epochs…")
456
+ for ep in range(1, epochs+1):
457
  random.shuffle(self.train_data)
458
+ tot, conv = 0.0, 0
459
+ for s in self.train_data:
460
+ for _ in range(MAXS):
461
+ e, _ = self.mesh.lms_update(s['A'], s['B'], s['C'], ewc=False)
462
+ if e < CONV: conv += 1; break
463
+ tot += e
464
+ avg = tot / max(len(self.train_data), 1)
465
+ pct = 100*conv/max(len(self.train_data), 1)
466
+ msg = f" Ep {ep}/{epochs}: avg|e|={avg:.4f} conv={pct:.1f}%"
467
+ self.log(msg); print(msg)
468
+ self.mesh.K0 = dict(self.mesh.K) # save EWC anchors
469
+ self.log("βœ“ Done. EWC anchors saved.")
 
 
 
470
  self.mode = 'idle'
471
 
472
+ # ── START HELPERS ─────────────────────────────────────────────────────────
473
+
474
+ def start_visual(self, n=None):
475
+ data = random.sample(self.train_data, min(n or len(self.train_data), len(self.train_data)))
476
+ self._fill(data); self.mode='training'; self.running=True
477
+ self.log(f"β–Ά Visual train: {len(data)}")
478
+
479
+ def start_infer(self, n=None):
480
+ data = self.test_data[:n] if n else self.test_data
481
+ self.test_res=[]; self._fill(data); self.mode='inference'; self.running=True
482
+ ood = sum(1 for d in data if d['type']=='heavy_b')
483
+ self.log(f"β–Ά Inference: {len(data)} ({ood} OOD)")
484
+
485
+ # ── ACCURACY ──────────────────────────────────────────────────────────────
486
+
487
+ def acc(self):
488
+ out = {}
489
+ for r in self.test_res:
490
  t = r['type']
491
+ if t not in out: out[t] = {'n':0,'ok':0,'se':0.0,'ss':0,'ood':r['ood']}
492
+ out[t]['n']+=1; out[t]['ok']+=int(r['ok'])
493
+ out[t]['se']+=r['abs']; out[t]['ss']+=r['steps']
494
+ return {t:{
495
+ 'n':v['n'], 'acc':round(100*v['ok']/max(v['n'],1),1),
496
+ 'avg_err':round(v['se']/max(v['n'],1),4),
497
+ 'avg_steps':round(v['ss']/max(v['n'],1),1),
498
+ 'ood':v['ood'],
499
+ } for t,v in out.items()}
500
+
501
+ # ── STATE ─────────────────────────────────────────────────────────────────
502
+
503
+ def state(self):
504
+ return {
505
+ 'nodes': self.mesh.node_state(),
506
+ 'springs': self.mesh.spring_state(),
507
+ 'triangles': self.mesh.tri_state(),
508
+ 'widths': self.mesh.W,
509
+ 'n': self.mesh.n,
510
+ 'n_springs': len(self.mesh.K),
511
+ 'n_vert': len(self.mesh.vert),
512
+ 'error': round(self.error, 5),
513
+ 'c_pred': self.c_pred,
514
+ 'c_tgt': [round(v,4) for v in self.c_tgt] if self.c_tgt else None,
515
+ 'iter': self.iter,
516
+ 'step_cnt': self.step_cnt,
517
+ 'logs': self.logs,
518
+ 'history': self.history[-120:],
519
+ 'running': self.running,
520
+ 'mode': self.mode,
521
+ 'cur_type': self.cur_type,
522
+ 'q_size': len(self.q),
523
+ 'train_size': len(self.train_data),
524
+ 'test_size': len(self.test_data),
525
+ 'type_acc': self.acc(),
526
+ 'n_done': len(self.test_res),
527
+ }
528
+
529
+
530
+ # ── SERVER ────────────────────────────────────────────────────────────────────
531
 
532
  engine = Engine()
533
+ try: engine.load_data()
534
+ except Exception as e: engine.log(f"⚠ No data β€” run: python data_gen.py ({e})")
535
+
536
+
537
+ def _loop():
 
 
 
538
  while True:
539
+ if engine.running: engine.phys_tick()
540
+ time.sleep(0.028)
541
+
542
+ threading.Thread(target=_loop, daemon=True).start()
543
+
544
 
545
  @app.get("/", response_class=HTMLResponse)
546
  async def ui(): return FileResponse("index.html")
547
 
548
  @app.get("/state")
549
+ async def get_state(): return engine.state()
550
+
551
+ @app.post("/train_offline")
552
+ async def t_offline(d: dict = {}):
553
+ ep = int(d.get('epochs', 5))
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  threading.Thread(target=engine.train_offline, args=(ep,), daemon=True).start()
555
  return {"ok": True}
556
 
557
+ @app.post("/train_visual")
558
+ async def t_visual(d: dict = {}):
559
+ engine.start_visual(); return {"ok": True}
560
+
561
  @app.post("/infer")
562
+ async def infer(d: dict = {}):
563
+ engine.start_infer(n=d.get('n')); return {"ok": True}
 
 
 
 
 
 
564
 
565
+ @app.post("/set_n")
566
+ async def set_n(d: dict):
567
+ engine.running = False
568
+ n = max(1, min(8, int(d.get('n', 1))))
569
+ engine.mesh = Mesh(n=n)
570
+ engine.log(f"Mesh rebuilt: n={n} rows={len(engine.mesh.W)} springs={len(engine.mesh.K)}")
571
+ return {"ok": True}
 
 
 
 
 
 
 
 
572
 
573
  @app.post("/halt")
574
+ async def halt(): engine.running=False; return {"ok":True}
575
+
576
+ @app.post("/reset")
577
+ async def reset():
578
+ engine.running=False
579
+ engine.mesh=Mesh(n=engine.mesh.n)
580
+ engine.log("Springs re-initialised."); return {"ok":True}
581
 
582
  if __name__ == "__main__":
583
  import uvicorn