everydaytok commited on
Commit
56e8396
Β·
verified Β·
1 Parent(s): 8869a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -159
app.py CHANGED
@@ -23,7 +23,7 @@ class SimEngine:
23
  self.n_upper = 3
24
  self.n_lower = 3
25
  self.back_alpha = 0.45
26
- self.cross_connect = False # ← new
27
  self.running = False
28
  self.batch_queue = collections.deque()
29
  self.logs = []
@@ -31,6 +31,7 @@ class SimEngine:
31
  self.current_error = 0.0
32
  self.current_prediction = 0.0
33
  self.history = []
 
34
  self._init_mesh()
35
 
36
  # ── TOPOLOGY ──────────────────────────────────────────────────────────────
@@ -63,7 +64,6 @@ class SimEngine:
63
  self.nodes[f'B{d}']['x'] = 3.0
64
 
65
  self.springs = {}
66
- # Vertical springs (always)
67
  for d in range(1, n+1):
68
  for j in range(1, self.n_upper+1):
69
  uid = f'U{d}_{j}'
@@ -74,54 +74,89 @@ class SimEngine:
74
  self.springs[(f'B{d}', lid)] = round(random.uniform(0.85, 1.15), 4)
75
  self.springs[(lid, f'C{d}')] = round(random.uniform(0.85, 1.15), 4)
76
 
77
- # Lateral springs (only when cross_connect=True)
78
- if self.cross_connect and n > 1:
79
- self._add_lateral_springs()
 
80
 
81
- # ── LATERAL SPRING HELPERS ────────────────────────────────────────────────
82
 
83
- def _lateral_keys(self):
84
- """All lateral spring keys for current topology."""
85
- keys = []
86
- n = self.n_inputs
 
 
 
 
 
 
87
  if n < 2:
88
- return keys
89
  for d in range(1, n):
90
- for j in range(1, self.n_upper+1):
91
- keys.append((f'U{d}_{j}', f'U{d+1}_{j}'))
92
- for j in range(1, self.n_lower+1):
93
- keys.append((f'L{d}_{j}', f'L{d+1}_{j}'))
94
- return keys
95
 
96
- def _add_lateral_springs(self):
97
- """Add lateral springs without disturbing existing vertical springs."""
98
- for key in self._lateral_keys():
99
- if key not in self.springs:
100
- self.springs[key] = round(random.uniform(0.85, 1.15), 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- def _remove_lateral_springs(self):
103
- """Remove lateral springs, preserving vertical springs."""
104
- lateral = set(self._lateral_keys())
105
- for key in lateral:
106
- self.springs.pop(key, None)
 
 
 
 
107
 
108
  def toggle_cross_connect(self):
109
  """
110
- Toggle cross_connect ON/OFF.
111
- ON β†’ adds lateral springs between same-row hidden nodes of adjacent dims.
112
- OFF β†’ removes lateral springs.
113
- Existing vertical spring values are never touched.
 
114
  """
115
  self.cross_connect = not self.cross_connect
 
 
 
 
116
  if self.cross_connect:
117
- self._add_lateral_springs()
118
  self.add_log(
119
- f"Cross-connect ON β€” {len(self._lateral_keys())} lateral springs added"
 
 
120
  )
121
  else:
122
- n_removed = len([k for k in self._lateral_keys() if k in self.springs])
123
- self._remove_lateral_springs()
124
- self.add_log(f"Cross-connect OFF β€” {n_removed} lateral springs removed")
125
 
126
  # ── LOGGING ───────────────────────────────────────────────────────────────
127
 
@@ -141,8 +176,7 @@ class SimEngine:
141
  def _to_vec(self, val, n):
142
  if isinstance(val, (list, tuple)):
143
  v = [float(x) for x in val]
144
- if len(v) >= n:
145
- return v[:n]
146
  return v + [v[-1]] * (n - len(v))
147
  return [float(val)] * n
148
 
@@ -186,98 +220,76 @@ class SimEngine:
186
  c['anchored'] = False
187
  c['x'] = 0.0
188
 
189
- # ── ELASTIC DISPLAY PHYSICS ───────────────────────────────────────────────
190
- # Refactored to accumulate all forces first then integrate.
191
- # This lets lateral forces slot in cleanly alongside vertical forces.
 
 
192
 
193
  def _elastic_step(self, n_steps):
194
  alpha = self.back_alpha
195
  n = self.n_inputs
196
 
197
  for _ in range(n_steps):
198
- forces = {}
199
- for nid, nd in self.nodes.items():
200
- if not nd['anchored']:
201
- forces[nid] = 0.0
202
 
203
- # ── Vertical forces (per dimension) ───────────────────────────
204
  for d in range(1, n+1):
205
  A_val = self.nodes[f'A{d}']['x']
206
  B_val = self.nodes[f'B{d}']['x']
207
  C_val = self.nodes[f'C{d}']['x']
208
 
209
  for j in range(1, self.n_upper+1):
210
- uid = f'U{d}_{j}'
211
- rest = self.springs[(f'A{d}', uid)] * A_val
212
- f = FWD_K * (rest - self.nodes[uid]['x'])
 
213
  if alpha > 0:
214
- kuc = self.springs[(uid, f'C{d}')]
215
  f += alpha * kuc * (C_val - self.nodes[uid]['x'])
216
- forces[uid] = forces.get(uid, 0.0) + f
 
217
 
218
  for j in range(1, self.n_lower+1):
219
- lid = f'L{d}_{j}'
220
- rest = self.springs[(f'B{d}', lid)] * B_val
221
- f = FWD_K * (rest - self.nodes[lid]['x'])
 
222
  if alpha > 0:
223
- klc = self.springs[(lid, f'C{d}')]
224
  f += alpha * klc * (C_val - self.nodes[lid]['x'])
225
- forces[lid] = forces.get(lid, 0.0) + f
 
226
 
227
  c = self.nodes[f'C{d}']
228
  if not c['anchored']:
229
  rest_c = (
230
- sum(self.springs[(f'U{d}_{j}', f'C{d}')] * self.nodes[f'U{d}_{j}']['x']
 
231
  for j in range(1, self.n_upper+1)) +
232
- sum(self.springs[(f'L{d}_{j}', f'C{d}')] * self.nodes[f'L{d}_{j}']['x']
 
233
  for j in range(1, self.n_lower+1))
234
  )
235
- forces[f'C{d}'] = forces.get(f'C{d}', 0.0) + FWD_K * (rest_c - c['x'])
236
-
237
- # ── Lateral forces (cross_connect only) ───────────────────────
238
- # Each lateral spring K(u1, u2) pulls u1 toward u2 and vice-versa.
239
- # This is standard Hooke: F = K*(x_other - x_self).
240
- # Effect: adjacent-dimension hidden nodes share mechanical state β€”
241
- # tension in one dimension's hidden layer propagates sideways.
242
- if self.cross_connect and n > 1:
243
- for d in range(1, n):
244
- for j in range(1, self.n_upper+1):
245
- u1, u2 = f'U{d}_{j}', f'U{d+1}_{j}'
246
- key = (u1, u2)
247
- if key in self.springs:
248
- k = self.springs[key]
249
- dx = self.nodes[u2]['x'] - self.nodes[u1]['x']
250
- forces[u1] = forces.get(u1, 0.0) + k * dx
251
- forces[u2] = forces.get(u2, 0.0) - k * dx
252
- for j in range(1, self.n_lower+1):
253
- l1, l2 = f'L{d}_{j}', f'L{d+1}_{j}'
254
- key = (l1, l2)
255
- if key in self.springs:
256
- k = self.springs[key]
257
- dx = self.nodes[l2]['x'] - self.nodes[l1]['x']
258
- forces[l1] = forces.get(l1, 0.0) + k * dx
259
- forces[l2] = forces.get(l2, 0.0) - k * dx
260
-
261
- # ── Integrate ─────────────────────────────────────────────────
262
  max_v = 0.0
263
  for nid, f in forces.items():
264
  nd = self.nodes[nid]
265
  nd['vel'] = nd['vel'] * DAMPING + f * DT
266
  nd['x'] += nd['vel'] * DT
267
  max_v = max(max_v, abs(nd['vel']))
268
-
269
  if max_v < SETTLE:
270
  break
271
 
272
  # ── FEEDFORWARD ───────────────────────────────────────────────────────────
 
 
 
 
273
 
274
  def _feedforward(self):
275
- """
276
- When cross_connect=False: exact K*input (analytic, same as before).
277
- When cross_connect=True: hidden values read from elastic-settled
278
- node positions β€” they already encode lateral
279
- cross-dimensional mixing.
280
- """
281
  n = self.n_inputs
282
  preds = []
283
  ff = {}
@@ -287,29 +299,31 @@ class SimEngine:
287
  B_val = self.nodes[f'B{d}']['x']
288
 
289
  for j in range(1, self.n_upper+1):
290
- uid = f'U{d}_{j}'
291
- ff[uid] = (self.nodes[uid]['x'] if self.cross_connect
292
- else self.springs[(f'A{d}', uid)] * A_val)
 
293
 
294
  for j in range(1, self.n_lower+1):
295
- lid = f'L{d}_{j}'
296
- ff[lid] = (self.nodes[lid]['x'] if self.cross_connect
297
- else self.springs[(f'B{d}', lid)] * B_val)
 
298
 
299
  if self.architecture == 'multiplicative':
300
  nm = max(self.n_upper, self.n_lower)
301
  pred = 0.0
302
  for i in range(nm):
303
- uid = f'U{d}_{(i % self.n_upper)+1}'
304
- lid = f'L{d}_{(i % self.n_lower)+1}'
305
- ku = self.springs[(uid, f'C{d}')]
306
- kl = self.springs[(lid, f'C{d}')]
307
- pred += ku * ff[uid] * kl * ff[lid]
308
  else:
309
  pred = (
310
- sum(self.springs[(f'U{d}_{j}', f'C{d}')] * ff[f'U{d}_{j}']
311
  for j in range(1, self.n_upper+1)) +
312
- sum(self.springs[(f'L{d}_{j}', f'C{d}')] * ff[f'L{d}_{j}']
313
  for j in range(1, self.n_lower+1))
314
  )
315
  preds.append(pred)
@@ -317,11 +331,14 @@ class SimEngine:
317
  return preds, ff
318
 
319
  # ── LMS UPDATE ────────────────────────────────────────────────────────────
 
 
 
 
320
 
321
  def _lms_update(self, errors, ff):
322
  n = self.n_inputs
323
 
324
- # ── Vertical springs (same as before) ─────────────────────────────
325
  for d in range(1, n+1):
326
  err = errors[d-1]
327
  A_val = self.nodes[f'A{d}']['x']
@@ -330,57 +347,40 @@ class SimEngine:
330
 
331
  if self.architecture == 'additive':
332
  for j in range(1, self.n_upper+1):
333
- uid = f'U{d}_{j}'
334
- grads[(f'A{d}', uid)] = self.springs[(uid, f'C{d}')] * A_val
335
- grads[(uid, f'C{d}')] = self.springs[(f'A{d}', uid)] * A_val
 
 
 
336
  for j in range(1, self.n_lower+1):
337
- lid = f'L{d}_{j}'
338
- grads[(f'B{d}', lid)] = self.springs[(lid, f'C{d}')] * B_val
339
- grads[(lid, f'C{d}')] = self.springs[(f'B{d}', lid)] * B_val
 
 
 
340
  else:
341
  nm = max(self.n_upper, self.n_lower)
342
  for i in range(nm):
343
- uid = f'U{d}_{(i % self.n_upper)+1}'
344
- lid = f'L{d}_{(i % self.n_lower)+1}'
345
- ku = self.springs[(uid, f'C{d}')]
346
- kl = self.springs[(lid, f'C{d}')]
347
- Uv = ff[uid]; Lv = ff[lid]
348
- grads[(f'A{d}', uid)] = ku * A_val * kl * Lv
349
- grads[(f'B{d}', lid)] = kl * B_val * ku * Uv
350
- grads[(uid, f'C{d}')] = Uv * kl * Lv
351
- grads[(lid, f'C{d}')] = Lv * ku * Uv
 
 
352
 
353
  norm_sq = sum(g*g for g in grads.values()) + 1e-10
354
  mu = err / norm_sq
355
  for key, g in grads.items():
356
- self.springs[key] -= mu * g
357
- self.springs[key] = max(-30.0, min(30.0, self.springs[key]))
358
-
359
- # ── Lateral spring update (cross_connect only) ─────────────────────
360
- # Uses a cross-error Hebbian rule:
361
- # Ξ”K(u1,u2) ∝ -(e_d * x_u2 + e_{d+1} * x_u1)
362
- # Interpretation: tighten lateral coupling when both neighbouring dims
363
- # would benefit from sharing information; loosen when they conflict.
364
- if self.cross_connect and n > 1:
365
- lr_lat = 0.005
366
- for d in range(1, n):
367
- e1 = errors[d-1]; e2 = errors[d]
368
- for j in range(1, self.n_upper+1):
369
- key = (f'U{d}_{j}', f'U{d+1}_{j}')
370
- if key in self.springs:
371
- x1 = ff.get(f'U{d}_{j}', 0.0)
372
- x2 = ff.get(f'U{d+1}_{j}', 0.0)
373
- grad = e1 * x2 + e2 * x1
374
- self.springs[key] -= lr_lat * grad / (grad**2 + 1e-10)**0.5
375
- self.springs[key] = max(-30.0, min(30.0, self.springs[key]))
376
- for j in range(1, self.n_lower+1):
377
- key = (f'L{d}_{j}', f'L{d+1}_{j}')
378
- if key in self.springs:
379
- x1 = ff.get(f'L{d}_{j}', 0.0)
380
- x2 = ff.get(f'L{d+1}_{j}', 0.0)
381
- grad = e1 * x2 + e2 * x1
382
- self.springs[key] -= lr_lat * grad / (grad**2 + 1e-10)**0.5
383
- self.springs[key] = max(-30.0, min(30.0, self.springs[key]))
384
 
385
  # ── PHYSICS STEP ──────────────────────────────────────────────────────────
386
 
@@ -440,7 +440,8 @@ class SimEngine:
440
 
441
  def generate_batch(self, count=30):
442
  self.batch_queue.clear()
443
- n = self.n_inputs
 
444
  for _ in range(count):
445
  a_vec = [round(random.uniform(1.0, 10.0), 2) for _ in range(n)]
446
  b_vec = [round(random.uniform(1.0, 10.0), 2) for _ in range(n)]
@@ -449,9 +450,10 @@ class SimEngine:
449
  p = self.batch_queue.popleft()
450
  self.set_problem(p['a'], p['b'], p.get('c'))
451
  self.running = True
452
- cx = 'X' if self.cross_connect else 'Β·'
453
  self.add_log(
454
- f"β–Ά {count} | {self.dataset_type} | D={n} U{self.n_upper}Β·L{self.n_lower} [{cx}]"
 
455
  )
456
 
457
 
@@ -475,9 +477,8 @@ async def get_ui():
475
  @app.get("/state")
476
  async def get_state():
477
  springs_out = {f"{u}β†’{v}": round(k, 5) for (u, v), k in engine.springs.items()}
478
- n = engine.n_inputs
479
- n_lat = len([k for k in engine.springs if
480
- k[0][0] in ('U','L') and k[1][0] in ('U','L') and k[0][0] == k[1][0]])
481
  return {
482
  'nodes': engine.nodes,
483
  'springs': springs_out,
@@ -497,7 +498,7 @@ async def get_state():
497
  'n_lower': engine.n_lower,
498
  'back_alpha': engine.back_alpha,
499
  'cross_connect': engine.cross_connect,
500
- 'n_lateral': n_lat,
501
  'queue_size': len(engine.batch_queue),
502
  }
503
 
@@ -512,13 +513,12 @@ async def set_mode(data: dict):
512
 
513
  @app.post("/toggle_cross")
514
  async def toggle_cross():
515
- engine.running = False
516
  engine.toggle_cross_connect()
517
  return {
518
  "ok": True,
519
  "cross_connect": engine.cross_connect,
520
  "n_springs": len(engine.springs),
521
- "n_lateral": len(engine._lateral_keys()),
522
  }
523
 
524
 
@@ -547,9 +547,10 @@ async def config(data: dict):
547
  engine.running = False
548
  engine._init_mesh()
549
  engine.logs = []
 
550
  engine.add_log(
551
  f"Mesh rebuilt: D={new_ni} U{new_nu}Β·L{new_nl} "
552
- f"cross={'ON' if engine.cross_connect else 'OFF'}"
553
  )
554
  else:
555
  engine.add_log(
@@ -569,9 +570,10 @@ async def set_layer(data: dict):
569
  elif layer == 'lower': engine.n_lower = max(1, min(16, engine.n_lower + delta))
570
  engine.running = False
571
  engine._init_mesh()
 
572
  engine.add_log(
573
  f"Topology β†’ D={engine.n_inputs} U{engine.n_upper}Β·L{engine.n_lower} "
574
- f"cross={'ON' if engine.cross_connect else 'OFF'}"
575
  )
576
  return {
577
  "ok": True,
 
23
  self.n_upper = 3
24
  self.n_lower = 3
25
  self.back_alpha = 0.45
26
+ self.cross_connect = False
27
  self.running = False
28
  self.batch_queue = collections.deque()
29
  self.logs = []
 
31
  self.current_error = 0.0
32
  self.current_prediction = 0.0
33
  self.history = []
34
+ self.merge_map = {}
35
  self._init_mesh()
36
 
37
  # ── TOPOLOGY ──────────────────────────────────────────────────────────────
 
64
  self.nodes[f'B{d}']['x'] = 3.0
65
 
66
  self.springs = {}
 
67
  for d in range(1, n+1):
68
  for j in range(1, self.n_upper+1):
69
  uid = f'U{d}_{j}'
 
74
  self.springs[(f'B{d}', lid)] = round(random.uniform(0.85, 1.15), 4)
75
  self.springs[(lid, f'C{d}')] = round(random.uniform(0.85, 1.15), 4)
76
 
77
+ # Structural merge: co-located nodes become shared vertices
78
+ self.merge_map = self._compute_merge_map() if self.cross_connect else {}
79
+ if self.merge_map:
80
+ self._apply_merge()
81
 
82
+ # ── NODE MERGE ────────────────────────────────────────────────────────────
83
 
84
+ def _compute_merge_map(self):
85
+ """
86
+ The rightmost upper hidden node of dim d and the leftmost upper
87
+ hidden node of dim d+1 are visually co-located β†’ one shared vertex.
88
+ Same for lower. Only applies when β‰₯2 hidden nodes per side so the
89
+ boundary node is distinct from the centre node.
90
+ Returns {removed_id: canonical_id}.
91
+ """
92
+ mm = {}
93
+ n = self.n_inputs
94
  if n < 2:
95
+ return mm
96
  for d in range(1, n):
97
+ if self.n_upper >= 2:
98
+ mm[f'U{d+1}_1'] = f'U{d}_{self.n_upper}'
99
+ if self.n_lower >= 2:
100
+ mm[f'L{d+1}_1'] = f'L{d}_{self.n_lower}'
101
+ return mm
102
 
103
+ def _apply_merge(self):
104
+ """
105
+ Retarget all spring keys through merge_map and remove duplicate nodes.
106
+ e.g. (A2, U2_1) β†’ (A2, U1_3). If two remapped keys collide
107
+ (should not happen with this rule) their constants are averaged.
108
+ """
109
+ mm = self.merge_map
110
+ new_springs = {}
111
+ for (u, v), k in self.springs.items():
112
+ key = (mm.get(u, u), mm.get(v, v))
113
+ if key in new_springs:
114
+ new_springs[key] = (new_springs[key] + k) / 2.0
115
+ else:
116
+ new_springs[key] = k
117
+ self.springs = new_springs
118
+
119
+ removed = set(mm.keys())
120
+ for rid in removed:
121
+ self.nodes.pop(rid, None)
122
+ self.layers = [
123
+ [nid for nid in layer if nid not in removed]
124
+ for layer in self.layers
125
+ ]
126
+
127
+ # ── MERGE HELPERS ─────────────────────────────────────────────────────────
128
 
129
+ def _resolve(self, nid):
130
+ """Resolve a node ID to its canonical (possibly merged) ID."""
131
+ return self.merge_map.get(nid, nid)
132
+
133
+ def _spring(self, u, v):
134
+ """Spring constant lookup with automatic merge-map resolution."""
135
+ return self.springs[(self._resolve(u), self._resolve(v))]
136
+
137
+ # ── CROSS CONNECT TOGGLE ────────────────────────────────���─────────────────
138
 
139
  def toggle_cross_connect(self):
140
  """
141
+ Toggle structural node merging ON/OFF.
142
+ ON β†’ overlapping boundary hidden nodes become one shared vertex
143
+ with springs to both neighbouring inputs/outputs.
144
+ OFF β†’ fully independent parallel hourglasses (original behaviour).
145
+ Rebuilds the mesh (topology change); spring values reset.
146
  """
147
  self.cross_connect = not self.cross_connect
148
+ self.running = False
149
+ self._init_mesh()
150
+ self.logs = []
151
+ ns = len(self.merge_map)
152
  if self.cross_connect:
 
153
  self.add_log(
154
+ f"Cross-connect ON β€” {ns} shared "
155
+ f"{'vertex' if ns == 1 else 'vertices'} "
156
+ f"(structural merge, no extra springs)"
157
  )
158
  else:
159
+ self.add_log("Cross-connect OFF β€” independent parallel hourglasses")
 
 
160
 
161
  # ── LOGGING ───────────────────────────────────────────────────────────────
162
 
 
176
  def _to_vec(self, val, n):
177
  if isinstance(val, (list, tuple)):
178
  v = [float(x) for x in val]
179
+ if len(v) >= n: return v[:n]
 
180
  return v + [v[-1]] * (n - len(v))
181
  return [float(val)] * n
182
 
 
220
  c['anchored'] = False
221
  c['x'] = 0.0
222
 
223
+ # ── ELASTIC STEP ──────────────────────────────────────────────────────────
224
+ # Forces are accumulated first, then integrated β€” clean slot for any
225
+ # future force contributions. Merge-aware: all node lookups resolve
226
+ # through merge_map so shared vertices accumulate forces from every
227
+ # dimension that owns them.
228
 
229
  def _elastic_step(self, n_steps):
230
  alpha = self.back_alpha
231
  n = self.n_inputs
232
 
233
  for _ in range(n_steps):
234
+ forces = {nid: 0.0 for nid, nd in self.nodes.items()
235
+ if not nd['anchored']}
 
 
236
 
 
237
  for d in range(1, n+1):
238
  A_val = self.nodes[f'A{d}']['x']
239
  B_val = self.nodes[f'B{d}']['x']
240
  C_val = self.nodes[f'C{d}']['x']
241
 
242
  for j in range(1, self.n_upper+1):
243
+ uid_raw = f'U{d}_{j}'
244
+ uid = self._resolve(uid_raw)
245
+ ak = self._spring(f'A{d}', uid_raw)
246
+ f = FWD_K * (ak * A_val - self.nodes[uid]['x'])
247
  if alpha > 0:
248
+ kuc = self._spring(uid_raw, f'C{d}')
249
  f += alpha * kuc * (C_val - self.nodes[uid]['x'])
250
+ if uid in forces:
251
+ forces[uid] += f
252
 
253
  for j in range(1, self.n_lower+1):
254
+ lid_raw = f'L{d}_{j}'
255
+ lid = self._resolve(lid_raw)
256
+ bk = self._spring(f'B{d}', lid_raw)
257
+ f = FWD_K * (bk * B_val - self.nodes[lid]['x'])
258
  if alpha > 0:
259
+ klc = self._spring(lid_raw, f'C{d}')
260
  f += alpha * klc * (C_val - self.nodes[lid]['x'])
261
+ if lid in forces:
262
+ forces[lid] += f
263
 
264
  c = self.nodes[f'C{d}']
265
  if not c['anchored']:
266
  rest_c = (
267
+ sum(self._spring(f'U{d}_{j}', f'C{d}') *
268
+ self.nodes[self._resolve(f'U{d}_{j}')]['x']
269
  for j in range(1, self.n_upper+1)) +
270
+ sum(self._spring(f'L{d}_{j}', f'C{d}') *
271
+ self.nodes[self._resolve(f'L{d}_{j}')]['x']
272
  for j in range(1, self.n_lower+1))
273
  )
274
+ forces[f'C{d}'] = forces.get(f'C{d}', 0.0) + \
275
+ FWD_K * (rest_c - c['x'])
276
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  max_v = 0.0
278
  for nid, f in forces.items():
279
  nd = self.nodes[nid]
280
  nd['vel'] = nd['vel'] * DAMPING + f * DT
281
  nd['x'] += nd['vel'] * DT
282
  max_v = max(max_v, abs(nd['vel']))
 
283
  if max_v < SETTLE:
284
  break
285
 
286
  # ── FEEDFORWARD ───────────────────────────────────────────────────────────
287
+ # cross_connect=False β†’ analytic K*input (original behaviour)
288
+ # cross_connect=True β†’ settled node positions used; shared vertices
289
+ # already encode cross-dimensional mixing from
290
+ # receiving forces from both neighbouring inputs.
291
 
292
  def _feedforward(self):
 
 
 
 
 
 
293
  n = self.n_inputs
294
  preds = []
295
  ff = {}
 
299
  B_val = self.nodes[f'B{d}']['x']
300
 
301
  for j in range(1, self.n_upper+1):
302
+ uid_raw = f'U{d}_{j}'
303
+ uid = self._resolve(uid_raw)
304
+ ff[uid_raw] = (self.nodes[uid]['x'] if self.cross_connect
305
+ else self._spring(f'A{d}', uid_raw) * A_val)
306
 
307
  for j in range(1, self.n_lower+1):
308
+ lid_raw = f'L{d}_{j}'
309
+ lid = self._resolve(lid_raw)
310
+ ff[lid_raw] = (self.nodes[lid]['x'] if self.cross_connect
311
+ else self._spring(f'B{d}', lid_raw) * B_val)
312
 
313
  if self.architecture == 'multiplicative':
314
  nm = max(self.n_upper, self.n_lower)
315
  pred = 0.0
316
  for i in range(nm):
317
+ uid_raw = f'U{d}_{(i % self.n_upper)+1}'
318
+ lid_raw = f'L{d}_{(i % self.n_lower)+1}'
319
+ ku = self._spring(uid_raw, f'C{d}')
320
+ kl = self._spring(lid_raw, f'C{d}')
321
+ pred += ku * ff[uid_raw] * kl * ff[lid_raw]
322
  else:
323
  pred = (
324
+ sum(self._spring(f'U{d}_{j}', f'C{d}') * ff[f'U{d}_{j}']
325
  for j in range(1, self.n_upper+1)) +
326
+ sum(self._spring(f'L{d}_{j}', f'C{d}') * ff[f'L{d}_{j}']
327
  for j in range(1, self.n_lower+1))
328
  )
329
  preds.append(pred)
 
331
  return preds, ff
332
 
333
  # ── LMS UPDATE ────────────────────────────────────────────────────────────
334
+ # Each dimension's error drives its own spring gradients independently.
335
+ # Shared vertices have springs from both dimensions updated separately
336
+ # β€” the merge means those canonical spring keys already exist in
337
+ # self.springs, so the updates land correctly with no special casing.
338
 
339
  def _lms_update(self, errors, ff):
340
  n = self.n_inputs
341
 
 
342
  for d in range(1, n+1):
343
  err = errors[d-1]
344
  A_val = self.nodes[f'A{d}']['x']
 
347
 
348
  if self.architecture == 'additive':
349
  for j in range(1, self.n_upper+1):
350
+ uid_raw = f'U{d}_{j}'
351
+ uid = self._resolve(uid_raw)
352
+ ak_key = (f'A{d}', uid)
353
+ uc_key = (uid, f'C{d}')
354
+ grads[ak_key] = self._spring(uid_raw, f'C{d}') * A_val
355
+ grads[uc_key] = self._spring(f'A{d}', uid_raw) * A_val
356
  for j in range(1, self.n_lower+1):
357
+ lid_raw = f'L{d}_{j}'
358
+ lid = self._resolve(lid_raw)
359
+ bk_key = (f'B{d}', lid)
360
+ lc_key = (lid, f'C{d}')
361
+ grads[bk_key] = self._spring(lid_raw, f'C{d}') * B_val
362
+ grads[lc_key] = self._spring(f'B{d}', lid_raw) * B_val
363
  else:
364
  nm = max(self.n_upper, self.n_lower)
365
  for i in range(nm):
366
+ uid_raw = f'U{d}_{(i % self.n_upper)+1}'
367
+ lid_raw = f'L{d}_{(i % self.n_lower)+1}'
368
+ uid = self._resolve(uid_raw)
369
+ lid = self._resolve(lid_raw)
370
+ ku = self._spring(uid_raw, f'C{d}')
371
+ kl = self._spring(lid_raw, f'C{d}')
372
+ Uv = ff[uid_raw]; Lv = ff[lid_raw]
373
+ grads[(f'A{d}', uid)] = ku * A_val * kl * Lv
374
+ grads[(f'B{d}', lid)] = kl * B_val * ku * Uv
375
+ grads[(uid, f'C{d}')] = Uv * kl * Lv
376
+ grads[(lid, f'C{d}')] = Lv * ku * Uv
377
 
378
  norm_sq = sum(g*g for g in grads.values()) + 1e-10
379
  mu = err / norm_sq
380
  for key, g in grads.items():
381
+ if key in self.springs:
382
+ self.springs[key] -= mu * g
383
+ self.springs[key] = max(-30.0, min(30.0, self.springs[key]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
  # ── PHYSICS STEP ──────────────────────────────────────────────────────────
386
 
 
440
 
441
  def generate_batch(self, count=30):
442
  self.batch_queue.clear()
443
+ n = self.n_inputs
444
+ ns = len(self.merge_map)
445
  for _ in range(count):
446
  a_vec = [round(random.uniform(1.0, 10.0), 2) for _ in range(n)]
447
  b_vec = [round(random.uniform(1.0, 10.0), 2) for _ in range(n)]
 
450
  p = self.batch_queue.popleft()
451
  self.set_problem(p['a'], p['b'], p.get('c'))
452
  self.running = True
453
+ tag = f'M{ns}' if self.cross_connect and ns else 'Β·'
454
  self.add_log(
455
+ f"β–Ά {count} | {self.dataset_type} | "
456
+ f"D={n} U{self.n_upper}Β·L{self.n_lower} [{tag}]"
457
  )
458
 
459
 
 
477
  @app.get("/state")
478
  async def get_state():
479
  springs_out = {f"{u}β†’{v}": round(k, 5) for (u, v), k in engine.springs.items()}
480
+ n = engine.n_inputs
481
+ ns = len(engine.merge_map)
 
482
  return {
483
  'nodes': engine.nodes,
484
  'springs': springs_out,
 
498
  'n_lower': engine.n_lower,
499
  'back_alpha': engine.back_alpha,
500
  'cross_connect': engine.cross_connect,
501
+ 'n_shared': ns,
502
  'queue_size': len(engine.batch_queue),
503
  }
504
 
 
513
 
514
  @app.post("/toggle_cross")
515
  async def toggle_cross():
 
516
  engine.toggle_cross_connect()
517
  return {
518
  "ok": True,
519
  "cross_connect": engine.cross_connect,
520
  "n_springs": len(engine.springs),
521
+ "n_shared": len(engine.merge_map),
522
  }
523
 
524
 
 
547
  engine.running = False
548
  engine._init_mesh()
549
  engine.logs = []
550
+ ns = len(engine.merge_map)
551
  engine.add_log(
552
  f"Mesh rebuilt: D={new_ni} U{new_nu}Β·L{new_nl} "
553
+ f"cross={'ON ('+str(ns)+' shared)' if engine.cross_connect else 'OFF'}"
554
  )
555
  else:
556
  engine.add_log(
 
570
  elif layer == 'lower': engine.n_lower = max(1, min(16, engine.n_lower + delta))
571
  engine.running = False
572
  engine._init_mesh()
573
+ ns = len(engine.merge_map)
574
  engine.add_log(
575
  f"Topology β†’ D={engine.n_inputs} U{engine.n_upper}Β·L{engine.n_lower} "
576
+ f"cross={'ON ('+str(ns)+' shared)' if engine.cross_connect else 'OFF'}"
577
  )
578
  return {
579
  "ok": True,