everydaytok commited on
Commit
2baf456
Β·
verified Β·
1 Parent(s): afe086f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -256
app.py CHANGED
@@ -1,27 +1,12 @@
1
  """
2
- main.py β€” Elastic Mesh Engine + FastAPI server.
3
-
4
- Architecture:
5
- Bilateral hourglass: A (top) ─[U1..Un]─┐
6
- C (center waist)
7
- B (bot) ─[L1..Ln]β”€β”˜
8
-
9
- Each node : x, vel ∈ ℝ^DIM
10
- Each spring: K ∈ ℝ^(DIMΓ—DIM) β€” full linear map per edge
11
-
12
- Forward (additive):
13
- x_Ui = K(A,Ui) @ x_A
14
- x_Li = K(B,Li) @ x_B
15
- x_C = Ξ£ K(Ui,C) @ x_Ui + Ξ£ K(Li,C) @ x_Li
16
-
17
- Training:
18
- C anchored at target β†’ K matrices update via matrix LMS
19
- one-shot zero-residual for linear problems
20
-
21
- Inference:
22
- C free β†’ elastic dynamics settle to equilibrium
23
- EWC regularisation protects weights from catastrophic forgetting
24
- Fisher diagonal accumulates during training
25
  """
26
 
27
  import numpy as np
@@ -35,45 +20,51 @@ app.add_middleware(CORSMiddleware,
35
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
36
 
37
  # ── CONSTANTS ─────────────────────────────────────────────────────────────────
38
- DIM = 32 # embedding dimension (scale to 768 for LLM integration)
39
- FWD_K = 1.5 # forward spring stiffness for elastic display
40
- BACK_A = 0.40 # backward tension (C pulls on hidden nodes)
41
- DAMPING = 0.58 # velocity retention per display micro-step
42
- DT = 0.10 # display physics time-step
43
- MICRO = 4 # display micro-steps per server tick
44
- CONV_THRESH = 0.08 # β€–errorβ€– < this β†’ sample converged
45
- MAX_STEPS = 400 # hard cap per sample (prevents infinite loops)
46
- EWC_LAMBDA = 0.6 # EWC penalty strength
47
- FISHER_DECAY= 0.97 # EMA decay for Fisher accumulation
48
-
49
-
50
- class MeshEngine:
 
 
51
  """
52
- Elastic hourglass mesh with matrix spring stiffness.
53
-
54
- The mesh learns to produce C = equilibrium(A, B) such that C lies in the
55
- feasibility space satisfying A-constraints while respecting B-objectives.
56
- This is not computed β€” it is converged to.
57
  """
 
 
 
58
 
59
- def __init__(self, dim: int = DIM, n_upper: int = 3, n_lower: int = 3):
 
 
60
  self.dim = dim
61
  self.n_upper = n_upper
62
  self.n_lower = n_lower
63
- self.mode = 'idle' # 'training' | 'inference' | 'idle'
64
  self.running = False
65
  self.batch_queue = collections.deque()
66
  self.logs = []
67
  self.iteration = 0
68
- self.step_count = 0 # steps on current sample
69
  self.error_norm = 0.0
70
  self.pred_norm = 0.0
71
  self.history = []
72
  self.train_data = []
73
  self.test_data = []
74
- self.c_target = None # ground-truth C for current sample (inference)
75
  self.current_type = 'unknown'
76
- self.test_errors = [] # list of {type, err, rel} β€” inference results
77
  self._init_mesh()
78
 
79
  # ── TOPOLOGY ──────────────────────────────────────────────────────────────
@@ -96,18 +87,14 @@ class MeshEngine:
96
  self.layers = self._layers()
97
  d = self.dim
98
 
99
- # Nodes β€” each carries a d-vector position and velocity
100
  self.nodes = {
101
- nid: {
102
- 'x': np.zeros(d),
103
- 'vel': np.zeros(d),
104
- 'anchored': nid in ('A', 'B'),
105
- }
106
  for layer in self.layers for nid in layer
107
  }
108
 
109
- # Spring matrices β€” K ∈ ℝ^(dΓ—d) per edge, Xavier init
110
- scale = np.sqrt(2.0 / (d + d))
111
  self.K = {}
112
  for i in range(1, self.n_upper + 1):
113
  uid = f'U{i}'
@@ -118,8 +105,7 @@ class MeshEngine:
118
  self.K[('B', lid)] = np.random.normal(0, scale, (d, d))
119
  self.K[(lid, 'C')] = np.random.normal(0, scale, (d, d))
120
 
121
- # EWC: Fisher diagonal (per element of each K matrix)
122
- self.fisher = {k: np.zeros((d, d)) for k in self.K}
123
  self.K_anchor = {k: v.copy() for k, v in self.K.items()}
124
 
125
  # ── PROBLEM SETUP ─────────────────────────────────────────────────────────
@@ -131,7 +117,6 @@ class MeshEngine:
131
  self.current_type = ptype
132
  self.step_count = 0
133
 
134
- # Reset free nodes for fresh elastic oscillation
135
  for layer in self.layers[1:4]:
136
  for nid in layer:
137
  if nid != 'C':
@@ -140,35 +125,36 @@ class MeshEngine:
140
 
141
  c = self.nodes['C']
142
  c['vel'] = np.zeros(d)
143
-
144
  if self.mode == 'training' and c_target is not None:
145
  c['x'] = np.asarray(c_target, dtype=float)[:d]
146
  c['anchored'] = True
147
  self.c_target = c['x'].copy()
148
  else:
149
- # Inference: C is free; store target only for accuracy measurement
150
  c['anchored'] = False
151
  c['x'] = np.zeros(d)
152
  self.c_target = (np.asarray(c_target, dtype=float)[:d]
153
  if c_target is not None else None)
154
 
155
- # ── FEEDFORWARD ───────────────────────────────────────────────────────────
156
 
157
  def _forward(self):
158
  """
159
- Exact feedforward pass (used for learning).
160
- Returns (C_pred, hidden_activations).
 
161
  """
162
  xa, xb = self.nodes['A']['x'], self.nodes['B']['x']
163
  hid = {}
164
 
165
  for i in range(1, self.n_upper + 1):
166
- uid = f'U{i}'
167
- hid[uid] = self.K[('A', uid)] @ xa # ℝ^d
 
168
 
169
  for i in range(1, self.n_lower + 1):
170
- lid = f'L{i}'
171
- hid[lid] = self.K[('B', lid)] @ xb # ℝ^d
 
172
 
173
  pred = np.zeros(self.dim)
174
  for i in range(1, self.n_upper + 1):
@@ -176,25 +162,19 @@ class MeshEngine:
176
  for i in range(1, self.n_lower + 1):
177
  pred += self.K[(f'L{i}', 'C')] @ hid[f'L{i}']
178
 
 
 
179
  return pred, hid
180
 
181
  # ── ELASTIC DISPLAY PHYSICS ───────────────────────────────────────────────
182
 
183
- def _elastic_step(self, n_steps: int = MICRO):
184
- """
185
- Damped-oscillator spring dynamics for visualisation.
186
-
187
- Forward springs pull hidden nodes toward their feedforward rest positions.
188
- Backward tension (BACK_A) lets anchored-C's position propagate upstream β€”
189
- the mesh physically feels the error as strain before any K update.
190
- """
191
  xa, xb = self.nodes['A']['x'], self.nodes['B']['x']
192
-
193
  for _ in range(n_steps):
194
  for i in range(1, self.n_upper + 1):
195
- uid = f'U{i}'
196
- n = self.nodes[uid]
197
- rest = self.K[('A', uid)] @ xa
198
  f = FWD_K * (rest - n['x'])
199
  f += BACK_A * (self.K[(uid, 'C')].T @
200
  (self.nodes['C']['x'] - self.K[(uid, 'C')] @ n['x']))
@@ -202,9 +182,9 @@ class MeshEngine:
202
  n['x'] += n['vel'] * DT
203
 
204
  for i in range(1, self.n_lower + 1):
205
- lid = f'L{i}'
206
- n = self.nodes[lid]
207
- rest = self.K[('B', lid)] @ xb
208
  f = FWD_K * (rest - n['x'])
209
  f += BACK_A * (self.K[(lid, 'C')].T @
210
  (self.nodes['C']['x'] - self.K[(lid, 'C')] @ n['x']))
@@ -218,130 +198,101 @@ class MeshEngine:
218
  rest += self.K[(f'U{i}', 'C')] @ self.nodes[f'U{i}']['x']
219
  for i in range(1, self.n_lower + 1):
220
  rest += self.K[(f'L{i}', 'C')] @ self.nodes[f'L{i}']['x']
221
- f = FWD_K * (rest - c['x'])
 
222
  c['vel'] = c['vel'] * DAMPING + f * DT
223
  c['x'] += c['vel'] * DT
224
 
225
  # ── MATRIX LMS UPDATE ─────────────────────────────────────────────────────
226
 
227
- def _lms_update(self, error: np.ndarray, hid: dict, ewc: bool = False):
228
  """
229
- Matrix LMS with joint optimal step.
230
-
231
- For the output layer (X β†’ C):
232
- grad_K = outer(error, h_X) ∈ ℝ^(dΓ—d)
233
- joint_denom = Ξ£_edges β€–h_Xβ€–Β² (one normaliser for all output-layer edges)
234
- K(X,C) -= grad_K / joint_denom
235
-
236
- This drives β€–errorβ€– β†’ 0 in one step for linear systems (provable).
237
-
238
- For the hidden layer (A/B β†’ U/L):
239
- delta propagates back through K(X,C):
240
- Ξ΄_U = K(U,C)α΅€ @ error
241
- grad_K = outer(Ξ΄_U, x_A)
242
- K(A,U) -= grad_K / β€–x_Aβ€–Β²
243
 
244
- EWC mode: step size reduced by (1 + λ·F) per element, protecting
245
- dimensions with high Fisher importance from past training.
 
 
246
  """
247
  eps = 1e-8
248
  xa = self.nodes['A']['x']
249
  xb = self.nodes['B']['x']
250
 
251
- # ── Output-layer joint update ──────────────────────────────────────
252
  joint_denom = eps
253
  for i in range(1, self.n_upper + 1):
254
  joint_denom += float(np.dot(hid[f'U{i}'], hid[f'U{i}']))
255
  for i in range(1, self.n_lower + 1):
256
  joint_denom += float(np.dot(hid[f'L{i}'], hid[f'L{i}']))
257
 
 
258
  for i in range(1, self.n_upper + 1):
259
  uid = f'U{i}'
260
  key = (uid, 'C')
261
- grad = np.outer(error, hid[uid])
262
- if ewc:
263
- denom = joint_denom * (1.0 + EWC_LAMBDA * self.fisher[key])
264
- else:
265
- denom = joint_denom
266
- self.K[key] -= grad / denom
267
- np.clip(self.K[key], -8.0, 8.0, out=self.K[key])
268
 
269
  for i in range(1, self.n_lower + 1):
270
  lid = f'L{i}'
271
  key = (lid, 'C')
272
- grad = np.outer(error, hid[lid])
273
- if ewc:
274
- denom = joint_denom * (1.0 + EWC_LAMBDA * self.fisher[key])
275
- else:
276
- denom = joint_denom
277
- self.K[key] -= grad / denom
278
- np.clip(self.K[key], -8.0, 8.0, out=self.K[key])
279
 
280
- # ── Hidden-layer update (backprop) ────────────────────────────────
281
- xa_denom = float(np.dot(xa, xa)) + eps
282
- xb_denom = float(np.dot(xb, xb)) + eps
 
 
283
 
284
  for i in range(1, self.n_upper + 1):
285
- uid = f'U{i}'
286
- key = ('A', uid)
287
- delta = self.K[(uid, 'C')].T @ error # back-propagated error ∈ ℝ^d
288
- grad = np.outer(delta, xa)
289
- if ewc:
290
- denom = xa_denom * (1.0 + EWC_LAMBDA * self.fisher[key])
291
- else:
292
- denom = xa_denom
293
- self.K[key] -= grad / denom
294
- np.clip(self.K[key], -8.0, 8.0, out=self.K[key])
295
 
296
  for i in range(1, self.n_lower + 1):
297
- lid = f'L{i}'
298
- key = ('B', lid)
299
  delta = self.K[(lid, 'C')].T @ error
300
- grad = np.outer(delta, xb)
301
- if ewc:
302
- denom = xb_denom * (1.0 + EWC_LAMBDA * self.fisher[key])
303
- else:
304
- denom = xb_denom
305
- self.K[key] -= grad / denom
306
- np.clip(self.K[key], -8.0, 8.0, out=self.K[key])
307
 
308
- # ── FISHER ACCUMULATION (EWC) ─────────────────────────────────────────────
309
 
310
- def _update_fisher(self, error: np.ndarray, hid: dict):
311
- """
312
- Accumulate Fisher diagonal via EMA of squared gradient elements.
313
- High Fisher β†’ this weight dimension was important for past problems.
314
- """
315
  xa = self.nodes['A']['x']
316
  xb = self.nodes['B']['x']
317
-
318
  for i in range(1, self.n_upper + 1):
319
  uid = f'U{i}'
320
- g_uc = np.outer(error, hid[uid]) ** 2
321
- g_au = np.outer(self.K[(uid, 'C')].T @ error, xa) ** 2
322
- self.fisher[(uid, 'C')] = (FISHER_DECAY * self.fisher[(uid, 'C')] +
323
- (1 - FISHER_DECAY) * g_uc)
324
- self.fisher[('A', uid)] = (FISHER_DECAY * self.fisher[('A', uid)] +
325
- (1 - FISHER_DECAY) * g_au)
326
-
327
  for i in range(1, self.n_lower + 1):
328
  lid = f'L{i}'
329
- g_lc = np.outer(error, hid[lid]) ** 2
330
- g_bl = np.outer(self.K[(lid, 'C')].T @ error, xb) ** 2
331
- self.fisher[(lid, 'C')] = (FISHER_DECAY * self.fisher[(lid, 'C')] +
332
- (1 - FISHER_DECAY) * g_lc)
333
- self.fisher[('B', lid)] = (FISHER_DECAY * self.fisher[('B', lid)] +
334
- (1 - FISHER_DECAY) * g_bl)
335
 
336
  # ── PHYSICS STEP ──────────────────────────────────────────────────────────
337
 
338
- def physics_step(self) -> bool:
339
- """One server tick: elastic display + LMS update."""
340
  self._elastic_step(MICRO)
341
-
342
- pred, hid = self._forward()
343
- self.pred_norm = float(np.linalg.norm(pred))
344
- self.step_count += 1
345
 
346
  c = self.nodes['C']
347
  if c['anchored']:
@@ -350,8 +301,7 @@ class MeshEngine:
350
  else:
351
  c['x'] = pred.copy()
352
  error = (pred - self.c_target
353
- if self.c_target is not None
354
- else np.zeros(self.dim))
355
  self.error_norm = float(np.linalg.norm(error))
356
 
357
  self.history.append(round(self.error_norm, 5))
@@ -362,60 +312,57 @@ class MeshEngine:
362
  timeout = self.step_count >= MAX_STEPS
363
 
364
  if converged or timeout:
365
- tag = 'βœ“' if converged else '⚠'
366
- self.add_log(f"{tag} [{self.current_type}] "
367
- f"err={self.error_norm:.4f} it={self.step_count}")
 
 
368
  if self.mode == 'inference' and self.c_target is not None:
369
  ct_norm = float(np.linalg.norm(self.c_target)) + 1e-8
370
  self.test_errors.append({
371
- 'type': self.current_type,
372
- 'abs': round(self.error_norm, 5),
373
- 'rel': round(self.error_norm / ct_norm, 5),
374
- 'ok': converged,
 
 
375
  })
376
  self._update_fisher(error, hid)
377
  return self._next_or_stop()
378
 
379
  if c['anchored']:
380
- # Training: update K to reduce error
381
  self._lms_update(error, hid, ewc=False)
382
  elif self.mode == 'inference':
383
- # Inference: EWC-regularised online adaptation
384
  self._lms_update(error, hid, ewc=True)
385
 
386
  self.iteration += 1
387
  return True
388
 
389
- def _next_or_stop(self) -> bool:
390
  if self.batch_queue:
391
  p = self.batch_queue.popleft()
392
- self.set_problem(p['A'], p['B'], p.get('C'), p.get('type', 'unknown'))
393
  return True
394
  self.running = False
395
  self.add_log("β—Ό Queue empty.")
396
  return False
397
 
398
- # ── FAST OFFLINE TRAINING ─────────────────────────────────────────────────
399
 
400
- def train_offline(self, epochs: int = 5):
401
- """
402
- Run full training at CPU speed (no sleep, no display physics).
403
- Called in a background thread from /train_offline endpoint.
404
- """
405
  self.running = False
406
  self.mode = 'training'
407
- self.add_log(f"⚑ Offline training: {epochs} epoch(s)…")
408
 
409
  for ep in range(1, epochs + 1):
410
  random.shuffle(self.train_data)
411
- total_err = 0.0
412
- converged = 0
413
 
414
  for sample in self.train_data:
415
- d = self.dim
416
- xa = np.asarray(sample['A'], dtype=float)[:d]
417
- xb = np.asarray(sample['B'], dtype=float)[:d]
418
- ct = np.asarray(sample['C'], dtype=float)[:d]
419
  self.nodes['A']['x'] = xa
420
  self.nodes['B']['x'] = xb
421
  self.nodes['C']['x'] = ct
@@ -437,9 +384,8 @@ class MeshEngine:
437
  self.add_log(f" Ep {ep}/{epochs}: avgβ€–eβ€–={avg:.4f} conv={pct:.1f}%")
438
  print(f" Ep {ep}/{epochs}: avgβ€–eβ€–={avg:.4f} converged={pct:.1f}%")
439
 
440
- # Save anchor weights for EWC
441
  self.K_anchor = {k: v.copy() for k, v in self.K.items()}
442
- self.add_log("βœ“ Offline training complete. EWC anchors saved.")
443
  self.mode = 'idle'
444
 
445
  # ── DATA LOADING ──────────────────────────────────────────────────────────
@@ -447,8 +393,11 @@ class MeshEngine:
447
  def load_data(self, train='data/train.json', test='data/test.json'):
448
  with open(train) as f: self.train_data = json.load(f)
449
  with open(test) as f: self.test_data = json.load(f)
450
- self.add_log(f"Data loaded: {len(self.train_data)} train / "
451
- f"{len(self.test_data)} test")
 
 
 
452
 
453
  # ── QUEUE HELPERS ──────────────────────────��──────────────────────────────
454
 
@@ -456,17 +405,16 @@ class MeshEngine:
456
  data = random.sample(self.train_data,
457
  min(n or len(self.train_data), len(self.train_data)))
458
  self._fill_queue(data, anchor_c=True)
459
- self.mode = 'training'
460
- self.running = True
461
  self.add_log(f"β–Ά Visual training: {len(data)} samples")
462
 
463
  def start_inference(self, n=None):
464
  data = self.test_data[:n] if n else self.test_data
465
  self.test_errors = []
466
  self._fill_queue(data, anchor_c=False)
467
- self.mode = 'inference'
468
- self.running = True
469
- self.add_log(f"β–Ά Inference: {len(data)} samples")
470
 
471
  def _fill_queue(self, data, anchor_c):
472
  self.batch_queue.clear()
@@ -479,16 +427,15 @@ class MeshEngine:
479
  if anchor_c:
480
  self.set_problem(p['A'], p['B'], p['C'], p['type'])
481
  else:
482
- # Inference: don't anchor but store target
483
  d = self.dim
484
- self.nodes['A']['x'] = np.asarray(p['A'])[:d]
485
- self.nodes['B']['x'] = np.asarray(p['B'])[:d]
486
- self.nodes['C']['x'] = np.zeros(d)
487
- self.nodes['C']['vel'] = np.zeros(d)
488
  self.nodes['C']['anchored'] = False
489
- self.c_target = np.asarray(p['C'])[:d]
490
- self.current_type = p['type']
491
- self.step_count = 0
492
  for layer in self.layers[1:4]:
493
  for nid in layer:
494
  if nid != 'C':
@@ -502,7 +449,7 @@ class MeshEngine:
502
  if len(self.logs) > 60:
503
  self.logs.pop()
504
 
505
- # ── STATE SERIALISATION ───────────────────────────────────────────────────
506
 
507
  def state_dict(self):
508
  nodes_out = {}
@@ -516,31 +463,34 @@ class MeshEngine:
516
 
517
  springs_out = {}
518
  for (u, v), km in self.K.items():
519
- label = f"{u}β†’{v}"
520
- springs_out[label] = {
521
- 'frob': round(float(np.linalg.norm(km)), 4),
522
  'mean': round(float(np.mean(km)), 4),
523
  'std': round(float(np.std(km)), 4),
524
- 'fish': round(float(np.mean(self.fisher[(u, v)])), 5),
525
  }
526
 
527
- # Per-type inference accuracy
528
  type_acc = {}
529
  for te in self.test_errors:
530
  t = te['type']
531
  if t not in type_acc:
532
- type_acc[t] = {'n': 0, 'n_ok': 0, 'sum_abs': 0.0}
533
- type_acc[t]['n'] += 1
534
- type_acc[t]['n_ok'] += int(te['ok'])
535
- type_acc[t]['sum_abs'] += te['abs']
536
- acc_summary = {
537
- t: {
538
- 'n': v['n'],
539
- 'acc': round(100 * v['n_ok'] / max(v['n'], 1), 1),
540
- 'avg_err': round(v['sum_abs'] / max(v['n'], 1), 4),
 
 
 
 
 
 
541
  }
542
- for t, v in type_acc.items()
543
- }
544
 
545
  return {
546
  'nodes': nodes_out,
@@ -563,6 +513,7 @@ class MeshEngine:
563
  'n_test_done': len(self.test_errors),
564
  'current_type': self.current_type,
565
  'dim': self.dim,
 
566
  }
567
 
568
 
@@ -573,7 +524,7 @@ engine = MeshEngine(dim=DIM, n_upper=3, n_lower=3)
573
  try:
574
  engine.load_data()
575
  except Exception as e:
576
- engine.add_log(f"No data found β€” run: python data_gen.py ({e})")
577
 
578
 
579
  def run_loop():
@@ -585,70 +536,49 @@ def run_loop():
585
  threading.Thread(target=run_loop, daemon=True).start()
586
 
587
 
588
- @app.get("/", response_class=HTMLResponse)
589
- async def get_ui():
590
- return FileResponse("index.html")
591
 
592
  @app.get("/state")
593
- async def get_state():
594
- return engine.state_dict()
595
-
596
- # ── Training controls ─────────────────────────────────────────────────────────
597
 
598
  @app.post("/train_visual")
599
  async def train_visual(data: dict = {}):
600
- """Start visual (slow) training β€” shows elastic dynamics in UI."""
601
  engine.start_training(n=data.get('n'))
602
  return {"ok": True}
603
 
604
  @app.post("/train_offline")
605
  async def train_offline(data: dict = {}):
606
- """Fast offline training in background thread β€” no display."""
607
  epochs = int(data.get('epochs', 5))
608
  threading.Thread(target=engine.train_offline, args=(epochs,), daemon=True).start()
609
  return {"ok": True, "epochs": epochs}
610
 
611
  @app.post("/infer")
612
  async def start_infer(data: dict = {}):
613
- """Run inference on test set, measuring C reconstruction accuracy."""
614
  engine.start_inference(n=data.get('n'))
615
  return {"ok": True}
616
 
617
  @app.post("/reload_data")
618
  async def reload_data():
619
- try:
620
- engine.load_data()
621
- return {"ok": True}
622
- except Exception as e:
623
- return {"ok": False, "error": str(e)}
624
-
625
- # ── Topology controls ────────────────────────────────────────────────────────
626
 
627
  @app.post("/set_layer")
628
  async def set_layer(data: dict):
629
- layer = data.get('layer', '')
630
- delta = int(data.get('delta', 0))
631
  engine.running = False
632
- if layer == 'upper':
633
- engine.n_upper = max(1, min(8, engine.n_upper + delta))
634
- elif layer == 'lower':
635
- engine.n_lower = max(1, min(8, engine.n_lower + delta))
636
  engine._init_mesh()
637
- engine.add_log(f"Topology β†’ U{engine.n_upper} Β· L{engine.n_lower} | springs re-init")
638
  return {"ok": True, "n_upper": engine.n_upper, "n_lower": engine.n_lower}
639
 
640
  @app.post("/halt")
641
- async def halt():
642
- engine.running = False
643
- return {"ok": True}
644
 
645
  @app.post("/reset")
646
- async def reset():
647
- engine.running = False
648
- engine._init_mesh()
649
- engine.add_log("Mesh reset.")
650
- return {"ok": True}
651
-
652
 
653
  if __name__ == "__main__":
654
  import uvicorn
 
1
  """
2
+ main.py β€” Elastic Mesh Engine v3
3
+
4
+ Changes from v2:
5
+ β‘  Layer normalisation after every spring transform β†’ kills weight explosion
6
+ β‘‘ Convergence threshold 0.02 (was 0.08) β†’ genuine precision
7
+ β‘’ DIM = 64 (was 32) β†’ double the space
8
+ β‘£ OOD test: model trained on seen types only,
9
+ test set contains both seen + unseen types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
 
12
  import numpy as np
 
20
  allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
21
 
22
  # ── CONSTANTS ─────────────────────────────────────────────────────────────────
23
+ DIM = 64
24
+ FWD_K = 1.5
25
+ BACK_A = 0.40
26
+ DAMPING = 0.58
27
+ DT = 0.10
28
+ MICRO = 4
29
+ CONV_THRESH = 0.02 # ← tightened from 0.08
30
+ MAX_STEPS = 600 # ← increased to give tighter threshold room
31
+ EWC_LAMBDA = 0.6
32
+ FISHER_DECAY = 0.97
33
+ LN_EPS = 1e-6 # layer norm epsilon
34
+
35
+
36
+ # ── LAYER NORM ────────────────────────────────────────────────────────────────
37
+ def layer_norm(x: np.ndarray) -> np.ndarray:
38
  """
39
+ Zero-mean unit-variance normalisation over the D-vector.
40
+ Applied after every spring transform to prevent the 200Γ— amplification
41
+ seen in v2 (input springs were β€–Kβ€–β‰ˆ200 while output springs were β€–Kβ€–β‰ˆ0.8).
42
+ The mesh can still learn arbitrary directions β€” only the scale is removed.
 
43
  """
44
+ mu = np.mean(x)
45
+ std = np.std(x) + LN_EPS
46
+ return (x - mu) / std
47
 
48
+
49
+ class MeshEngine:
50
+ def __init__(self, dim=DIM, n_upper=3, n_lower=3):
51
  self.dim = dim
52
  self.n_upper = n_upper
53
  self.n_lower = n_lower
54
+ self.mode = 'idle'
55
  self.running = False
56
  self.batch_queue = collections.deque()
57
  self.logs = []
58
  self.iteration = 0
59
+ self.step_count = 0
60
  self.error_norm = 0.0
61
  self.pred_norm = 0.0
62
  self.history = []
63
  self.train_data = []
64
  self.test_data = []
65
+ self.c_target = None
66
  self.current_type = 'unknown'
67
+ self.test_errors = []
68
  self._init_mesh()
69
 
70
  # ── TOPOLOGY ──────────────────────────────────────────────────────────────
 
87
  self.layers = self._layers()
88
  d = self.dim
89
 
 
90
  self.nodes = {
91
+ nid: {'x': np.zeros(d), 'vel': np.zeros(d),
92
+ 'anchored': nid in ('A', 'B')}
 
 
 
93
  for layer in self.layers for nid in layer
94
  }
95
 
96
+ # Xavier init β€” scale normalised so layer norm doesn't start at extreme values
97
+ scale = np.sqrt(1.0 / d)
98
  self.K = {}
99
  for i in range(1, self.n_upper + 1):
100
  uid = f'U{i}'
 
105
  self.K[('B', lid)] = np.random.normal(0, scale, (d, d))
106
  self.K[(lid, 'C')] = np.random.normal(0, scale, (d, d))
107
 
108
+ self.fisher = {k: np.zeros((d, d)) for k in self.K}
 
109
  self.K_anchor = {k: v.copy() for k, v in self.K.items()}
110
 
111
  # ── PROBLEM SETUP ─────────────────────────────────────────────────────────
 
117
  self.current_type = ptype
118
  self.step_count = 0
119
 
 
120
  for layer in self.layers[1:4]:
121
  for nid in layer:
122
  if nid != 'C':
 
125
 
126
  c = self.nodes['C']
127
  c['vel'] = np.zeros(d)
 
128
  if self.mode == 'training' and c_target is not None:
129
  c['x'] = np.asarray(c_target, dtype=float)[:d]
130
  c['anchored'] = True
131
  self.c_target = c['x'].copy()
132
  else:
 
133
  c['anchored'] = False
134
  c['x'] = np.zeros(d)
135
  self.c_target = (np.asarray(c_target, dtype=float)[:d]
136
  if c_target is not None else None)
137
 
138
+ # ── FEEDFORWARD (with layer norm) ─────────────────────────────────────────
139
 
140
  def _forward(self):
141
  """
142
+ Exact feedforward pass.
143
+ layer_norm applied after each K transform β€” prevents scale explosion.
144
+ The normalised activations are what the output springs read.
145
  """
146
  xa, xb = self.nodes['A']['x'], self.nodes['B']['x']
147
  hid = {}
148
 
149
  for i in range(1, self.n_upper + 1):
150
+ uid = f'U{i}'
151
+ raw = self.K[('A', uid)] @ xa
152
+ hid[uid] = layer_norm(raw) # ← norm here
153
 
154
  for i in range(1, self.n_lower + 1):
155
+ lid = f'L{i}'
156
+ raw = self.K[('B', lid)] @ xb
157
+ hid[lid] = layer_norm(raw) # ← norm here
158
 
159
  pred = np.zeros(self.dim)
160
  for i in range(1, self.n_upper + 1):
 
162
  for i in range(1, self.n_lower + 1):
163
  pred += self.K[(f'L{i}', 'C')] @ hid[f'L{i}']
164
 
165
+ # Final layer norm on prediction keeps C in a consistent scale range
166
+ pred = layer_norm(pred)
167
  return pred, hid
168
 
169
  # ── ELASTIC DISPLAY PHYSICS ───────────────────────────────────────────────
170
 
171
+ def _elastic_step(self, n_steps=MICRO):
 
 
 
 
 
 
 
172
  xa, xb = self.nodes['A']['x'], self.nodes['B']['x']
 
173
  for _ in range(n_steps):
174
  for i in range(1, self.n_upper + 1):
175
+ uid = f'U{i}'
176
+ n = self.nodes[uid]
177
+ rest = layer_norm(self.K[('A', uid)] @ xa)
178
  f = FWD_K * (rest - n['x'])
179
  f += BACK_A * (self.K[(uid, 'C')].T @
180
  (self.nodes['C']['x'] - self.K[(uid, 'C')] @ n['x']))
 
182
  n['x'] += n['vel'] * DT
183
 
184
  for i in range(1, self.n_lower + 1):
185
+ lid = f'L{i}'
186
+ n = self.nodes[lid]
187
+ rest = layer_norm(self.K[('B', lid)] @ xb)
188
  f = FWD_K * (rest - n['x'])
189
  f += BACK_A * (self.K[(lid, 'C')].T @
190
  (self.nodes['C']['x'] - self.K[(lid, 'C')] @ n['x']))
 
198
  rest += self.K[(f'U{i}', 'C')] @ self.nodes[f'U{i}']['x']
199
  for i in range(1, self.n_lower + 1):
200
  rest += self.K[(f'L{i}', 'C')] @ self.nodes[f'L{i}']['x']
201
+ rest = layer_norm(rest)
202
+ f = FWD_K * (rest - c['x'])
203
  c['vel'] = c['vel'] * DAMPING + f * DT
204
  c['x'] += c['vel'] * DT
205
 
206
  # ── MATRIX LMS UPDATE ─────────────────────────────────────────────────────
207
 
208
+ def _lms_update(self, error, hid, ewc=False):
209
  """
210
+ Matrix LMS with joint optimal step + layer norm jacobian correction.
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ Because we apply layer norm after K@x, the gradient of the normed output
213
+ with respect to K is scaled by the Jacobian of layer norm.
214
+ For LN(Kx): βˆ‚LN(Kx)/βˆ‚K β‰ˆ (I - outer(Ε·,Ε·)) @ outer(Β·, x) / std
215
+ We use a first-order approximation: scale grad by 1/std of pre-norm.
216
  """
217
  eps = 1e-8
218
  xa = self.nodes['A']['x']
219
  xb = self.nodes['B']['x']
220
 
221
+ # Joint denominator across all output-layer edges
222
  joint_denom = eps
223
  for i in range(1, self.n_upper + 1):
224
  joint_denom += float(np.dot(hid[f'U{i}'], hid[f'U{i}']))
225
  for i in range(1, self.n_lower + 1):
226
  joint_denom += float(np.dot(hid[f'L{i}'], hid[f'L{i}']))
227
 
228
+ # Output layer (Xi β†’ C)
229
  for i in range(1, self.n_upper + 1):
230
  uid = f'U{i}'
231
  key = (uid, 'C')
232
+ g = np.outer(error, hid[uid])
233
+ d = joint_denom * (1.0 + EWC_LAMBDA * self.fisher[key]) if ewc else joint_denom
234
+ self.K[key] -= g / d
235
+ np.clip(self.K[key], -10.0, 10.0, out=self.K[key])
 
 
 
236
 
237
  for i in range(1, self.n_lower + 1):
238
  lid = f'L{i}'
239
  key = (lid, 'C')
240
+ g = np.outer(error, hid[lid])
241
+ d = joint_denom * (1.0 + EWC_LAMBDA * self.fisher[key]) if ewc else joint_denom
242
+ self.K[key] -= g / d
243
+ np.clip(self.K[key], -10.0, 10.0, out=self.K[key])
 
 
 
244
 
245
+ # Hidden layer (A/B β†’ U/L) β€” backprop through layer norm approx
246
+ xa_std = float(np.std(xa)) + eps
247
+ xb_std = float(np.std(xb)) + eps
248
+ xa_denom = float(np.dot(xa, xa)) / xa_std + eps
249
+ xb_denom = float(np.dot(xb, xb)) / xb_std + eps
250
 
251
  for i in range(1, self.n_upper + 1):
252
+ uid = f'U{i}'
253
+ key = ('A', uid)
254
+ delta = self.K[(uid, 'C')].T @ error
255
+ g = np.outer(delta, xa) / xa_std
256
+ d = xa_denom * (1.0 + EWC_LAMBDA * self.fisher[key]) if ewc else xa_denom
257
+ self.K[key] -= g / d
258
+ np.clip(self.K[key], -10.0, 10.0, out=self.K[key])
 
 
 
259
 
260
  for i in range(1, self.n_lower + 1):
261
+ lid = f'L{i}'
262
+ key = ('B', lid)
263
  delta = self.K[(lid, 'C')].T @ error
264
+ g = np.outer(delta, xb) / xb_std
265
+ d = xb_denom * (1.0 + EWC_LAMBDA * self.fisher[key]) if ewc else xb_denom
266
+ self.K[key] -= g / d
267
+ np.clip(self.K[key], -10.0, 10.0, out=self.K[key])
 
 
 
268
 
269
+ # ── FISHER ACCUMULATION ───────────────────────────────────────────────────
270
 
271
+ def _update_fisher(self, error, hid):
 
 
 
 
272
  xa = self.nodes['A']['x']
273
  xb = self.nodes['B']['x']
 
274
  for i in range(1, self.n_upper + 1):
275
  uid = f'U{i}'
276
+ self.fisher[(uid, 'C')] = (FISHER_DECAY * self.fisher[(uid, 'C')] +
277
+ (1-FISHER_DECAY) * np.outer(error, hid[uid])**2)
278
+ self.fisher[('A', uid)] = (FISHER_DECAY * self.fisher[('A', uid)] +
279
+ (1-FISHER_DECAY) * np.outer(
280
+ self.K[(uid,'C')].T @ error, xa)**2)
 
 
281
  for i in range(1, self.n_lower + 1):
282
  lid = f'L{i}'
283
+ self.fisher[(lid, 'C')] = (FISHER_DECAY * self.fisher[(lid, 'C')] +
284
+ (1-FISHER_DECAY) * np.outer(error, hid[lid])**2)
285
+ self.fisher[('B', lid)] = (FISHER_DECAY * self.fisher[('B', lid)] +
286
+ (1-FISHER_DECAY) * np.outer(
287
+ self.K[(lid,'C')].T @ error, xb)**2)
 
288
 
289
  # ── PHYSICS STEP ──────────────────────────────────────────────────────────
290
 
291
+ def physics_step(self):
 
292
  self._elastic_step(MICRO)
293
+ pred, hid = self._forward()
294
+ self.pred_norm = float(np.linalg.norm(pred))
295
+ self.step_count += 1
 
296
 
297
  c = self.nodes['C']
298
  if c['anchored']:
 
301
  else:
302
  c['x'] = pred.copy()
303
  error = (pred - self.c_target
304
+ if self.c_target is not None else np.zeros(self.dim))
 
305
  self.error_norm = float(np.linalg.norm(error))
306
 
307
  self.history.append(round(self.error_norm, 5))
 
312
  timeout = self.step_count >= MAX_STEPS
313
 
314
  if converged or timeout:
315
+ tag = 'βœ“' if converged else '⚠ TIMEOUT'
316
+ is_ood = self.current_type in ('sphere', 'simplex')
317
+ ood_tag = ' [OOD]' if is_ood else ' [seen]'
318
+ self.add_log(f"{tag}{ood_tag} [{self.current_type}] "
319
+ f"err={self.error_norm:.4f} steps={self.step_count}")
320
  if self.mode == 'inference' and self.c_target is not None:
321
  ct_norm = float(np.linalg.norm(self.c_target)) + 1e-8
322
  self.test_errors.append({
323
+ 'type': self.current_type,
324
+ 'abs': round(self.error_norm, 5),
325
+ 'rel': round(self.error_norm / ct_norm, 5),
326
+ 'ok': converged,
327
+ 'steps': self.step_count,
328
+ 'ood': is_ood,
329
  })
330
  self._update_fisher(error, hid)
331
  return self._next_or_stop()
332
 
333
  if c['anchored']:
 
334
  self._lms_update(error, hid, ewc=False)
335
  elif self.mode == 'inference':
 
336
  self._lms_update(error, hid, ewc=True)
337
 
338
  self.iteration += 1
339
  return True
340
 
341
+ def _next_or_stop(self):
342
  if self.batch_queue:
343
  p = self.batch_queue.popleft()
344
+ self.set_problem(p['A'], p['B'], p.get('C'), p.get('type', '?'))
345
  return True
346
  self.running = False
347
  self.add_log("β—Ό Queue empty.")
348
  return False
349
 
350
+ # ── OFFLINE TRAINING ──────────────────────────────────────────────────────
351
 
352
+ def train_offline(self, epochs=5):
 
 
 
 
353
  self.running = False
354
  self.mode = 'training'
355
+ self.add_log(f"⚑ Offline training: {epochs} epoch(s) | dim={self.dim} | thresh={CONV_THRESH}")
356
 
357
  for ep in range(1, epochs + 1):
358
  random.shuffle(self.train_data)
359
+ total_err, converged = 0.0, 0
 
360
 
361
  for sample in self.train_data:
362
+ d = self.dim
363
+ xa = np.asarray(sample['A'])[:d]
364
+ xb = np.asarray(sample['B'])[:d]
365
+ ct = np.asarray(sample['C'])[:d]
366
  self.nodes['A']['x'] = xa
367
  self.nodes['B']['x'] = xb
368
  self.nodes['C']['x'] = ct
 
384
  self.add_log(f" Ep {ep}/{epochs}: avgβ€–eβ€–={avg:.4f} conv={pct:.1f}%")
385
  print(f" Ep {ep}/{epochs}: avgβ€–eβ€–={avg:.4f} converged={pct:.1f}%")
386
 
 
387
  self.K_anchor = {k: v.copy() for k, v in self.K.items()}
388
+ self.add_log("βœ“ Training done. EWC anchors saved.")
389
  self.mode = 'idle'
390
 
391
  # ── DATA LOADING ──────────────────────────────────────────────────────────
 
393
  def load_data(self, train='data/train.json', test='data/test.json'):
394
  with open(train) as f: self.train_data = json.load(f)
395
  with open(test) as f: self.test_data = json.load(f)
396
+ # Count OOD types in test
397
+ ood = sum(1 for d in self.test_data if d['type'] in ('sphere','simplex'))
398
+ seen = len(self.test_data) - ood
399
+ self.add_log(f"Data: {len(self.train_data)} train | "
400
+ f"{len(self.test_data)} test ({seen} seen / {ood} OOD)")
401
 
402
  # ── QUEUE HELPERS ──────────────────────────��──────────────────────────────
403
 
 
405
  data = random.sample(self.train_data,
406
  min(n or len(self.train_data), len(self.train_data)))
407
  self._fill_queue(data, anchor_c=True)
408
+ self.mode = 'training'; self.running = True
 
409
  self.add_log(f"β–Ά Visual training: {len(data)} samples")
410
 
411
  def start_inference(self, n=None):
412
  data = self.test_data[:n] if n else self.test_data
413
  self.test_errors = []
414
  self._fill_queue(data, anchor_c=False)
415
+ self.mode = 'inference'; self.running = True
416
+ self.add_log(f"β–Ά Inference: {len(data)} samples "
417
+ f"({sum(1 for d in data if d['type'] in ('sphere','simplex'))} OOD)")
418
 
419
  def _fill_queue(self, data, anchor_c):
420
  self.batch_queue.clear()
 
427
  if anchor_c:
428
  self.set_problem(p['A'], p['B'], p['C'], p['type'])
429
  else:
 
430
  d = self.dim
431
+ self.nodes['A']['x'] = np.asarray(p['A'])[:d]
432
+ self.nodes['B']['x'] = np.asarray(p['B'])[:d]
433
+ self.nodes['C']['x'] = np.zeros(d)
434
+ self.nodes['C']['vel'] = np.zeros(d)
435
  self.nodes['C']['anchored'] = False
436
+ self.c_target = np.asarray(p['C'])[:d]
437
+ self.current_type = p['type']
438
+ self.step_count = 0
439
  for layer in self.layers[1:4]:
440
  for nid in layer:
441
  if nid != 'C':
 
449
  if len(self.logs) > 60:
450
  self.logs.pop()
451
 
452
+ # ── STATE DICT ────────────────────────────────────────────────────────────
453
 
454
  def state_dict(self):
455
  nodes_out = {}
 
463
 
464
  springs_out = {}
465
  for (u, v), km in self.K.items():
466
+ springs_out[f"{u}β†’{v}"] = {
467
+ 'frob': round(float(np.linalg.norm(km)), 3),
 
468
  'mean': round(float(np.mean(km)), 4),
469
  'std': round(float(np.std(km)), 4),
470
+ 'fish': round(float(np.mean(self.fisher[(u,v)])), 5),
471
  }
472
 
473
+ # Per-type accuracy β€” separate SEEN vs OOD
474
  type_acc = {}
475
  for te in self.test_errors:
476
  t = te['type']
477
  if t not in type_acc:
478
+ type_acc[t] = {'n':0,'n_ok':0,'sum_abs':0.0,'sum_steps':0,'ood':te['ood']}
479
+ type_acc[t]['n'] += 1
480
+ type_acc[t]['n_ok'] += int(te['ok'])
481
+ type_acc[t]['sum_abs'] += te['abs']
482
+ type_acc[t]['sum_steps'] += te['steps']
483
+
484
+ acc_summary = {}
485
+ for t, v in type_acc.items():
486
+ n = max(v['n'], 1)
487
+ acc_summary[t] = {
488
+ 'n': v['n'],
489
+ 'acc': round(100 * v['n_ok'] / n, 1),
490
+ 'avg_err': round(v['sum_abs'] / n, 4),
491
+ 'avg_steps': round(v['sum_steps'] / n, 1),
492
+ 'ood': v['ood'],
493
  }
 
 
494
 
495
  return {
496
  'nodes': nodes_out,
 
513
  'n_test_done': len(self.test_errors),
514
  'current_type': self.current_type,
515
  'dim': self.dim,
516
+ 'conv_thresh': CONV_THRESH,
517
  }
518
 
519
 
 
524
  try:
525
  engine.load_data()
526
  except Exception as e:
527
+ engine.add_log(f"⚠ No data β€” run: python data_gen.py ({e})")
528
 
529
 
530
  def run_loop():
 
536
  threading.Thread(target=run_loop, daemon=True).start()
537
 
538
 
539
+ @app.get("/", response_class=HTMLResponse)
540
+ async def get_ui(): return FileResponse("index.html")
 
541
 
542
  @app.get("/state")
543
+ async def get_state(): return engine.state_dict()
 
 
 
544
 
545
  @app.post("/train_visual")
546
  async def train_visual(data: dict = {}):
 
547
  engine.start_training(n=data.get('n'))
548
  return {"ok": True}
549
 
550
  @app.post("/train_offline")
551
  async def train_offline(data: dict = {}):
 
552
  epochs = int(data.get('epochs', 5))
553
  threading.Thread(target=engine.train_offline, args=(epochs,), daemon=True).start()
554
  return {"ok": True, "epochs": epochs}
555
 
556
  @app.post("/infer")
557
  async def start_infer(data: dict = {}):
 
558
  engine.start_inference(n=data.get('n'))
559
  return {"ok": True}
560
 
561
  @app.post("/reload_data")
562
  async def reload_data():
563
+ try: engine.load_data(); return {"ok": True}
564
+ except Exception as e: return {"ok": False, "error": str(e)}
 
 
 
 
 
565
 
566
  @app.post("/set_layer")
567
  async def set_layer(data: dict):
 
 
568
  engine.running = False
569
+ if data.get('layer') == 'upper':
570
+ engine.n_upper = max(1, min(8, engine.n_upper + int(data['delta'])))
571
+ elif data.get('layer') == 'lower':
572
+ engine.n_lower = max(1, min(8, engine.n_lower + int(data['delta'])))
573
  engine._init_mesh()
574
+ engine.add_log(f"Topology β†’ U{engine.n_upper}Β·L{engine.n_lower}")
575
  return {"ok": True, "n_upper": engine.n_upper, "n_lower": engine.n_lower}
576
 
577
  @app.post("/halt")
578
+ async def halt(): engine.running = False; return {"ok": True}
 
 
579
 
580
  @app.post("/reset")
581
+ async def reset(): engine.running = False; engine._init_mesh(); return {"ok": True}
 
 
 
 
 
582
 
583
  if __name__ == "__main__":
584
  import uvicorn