everydaytok commited on
Commit
32634ef
·
verified ·
1 Parent(s): 8b56059

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -18
app.py CHANGED
@@ -67,7 +67,7 @@ class PristineMesh:
67
  for c in range(w):
68
  nid = f"{kind}_r{r}_c{c}"
69
  self.nodes[nid] = {
70
- 'x': 0.0, 'vel': 0.0, 'kind': kind,
71
  'pos': (x_offset + c, y), 'anchored': kind in ['A', 'B']
72
  }
73
 
@@ -151,34 +151,49 @@ class Engine:
151
  self.test_data = []
152
  self.error_hist = []
153
  self.current_err = 0.0
 
 
154
 
155
  def add_log(self, msg):
156
  self.logs.insert(0, f"[{self.iter:05d}] {msg}")
157
- if len(self.logs) > 30: self.logs.pop()
158
 
159
  def run_step(self):
160
  if not self.queue:
161
  self.running = False
 
162
  return
163
 
164
  sample = self.queue.popleft()
 
 
165
  self.mesh.set_inputs(sample['a'], sample['b'])
166
  self.mesh.settle(steps=30)
167
  preds = self.mesh.get_predictions()
168
 
169
- err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
170
- self.current_err = err
171
- self.error_hist.append(err)
172
- if len(self.error_hist) > 100: self.error_hist.pop(0)
 
 
 
 
173
 
174
- self.mesh.lms_update(sample['c'], mode=self.mode)
 
 
 
 
175
 
176
  self.iter += 1
177
- if self.iter % 10 == 0:
178
- self.add_log(f"[{sample['type']}] err: {err:.4f}")
179
 
180
  def train_offline(self, epochs):
181
- self.add_log("⚡ Offline Training Started...")
 
 
182
  for ep in range(epochs):
183
  random.shuffle(self.train_data)
184
  errs = []
@@ -188,9 +203,19 @@ class Engine:
188
  self.mesh.lms_update(sample['c'], mode='train')
189
  preds = self.mesh.get_predictions()
190
  errs.append(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
191
- self.add_log(f"Epoch {ep+1} | Avg Err: {np.mean(errs):.4f}")
192
  self.mesh.save_anchors()
193
  self.add_log("✓ Training Complete. EWC Anchors saved.")
 
 
 
 
 
 
 
 
 
 
194
 
195
  engine = Engine()
196
  try:
@@ -203,7 +228,7 @@ except Exception as e:
203
  def loop():
204
  while True:
205
  if engine.running: engine.run_step()
206
- time.sleep(0.05)
207
  threading.Thread(target=loop, daemon=True).start()
208
 
209
  @app.get("/", response_class=HTMLResponse)
@@ -213,27 +238,50 @@ async def ui(): return FileResponse("index.html")
213
  async def state():
214
  return {
215
  'nodes': engine.mesh.nodes,
216
- # FIXED: Serialize tuple keys safely for JSON by joining with a pipe '|'
217
  'springs': {f"{u}|{v}": k for (u, v), k in engine.mesh.springs.items()},
218
  'error': engine.current_err,
219
  'hist': engine.error_hist,
220
  'mode': engine.mode,
221
  'running': engine.running,
222
- 'logs': engine.logs
 
 
 
 
223
  }
224
 
225
  @app.post("/train")
226
- async def train():
227
- threading.Thread(target=engine.train_offline, args=(5,), daemon=True).start()
 
228
  return {"ok": True}
229
 
230
  @app.post("/infer")
231
- async def infer():
 
232
  engine.mode = 'infer'
233
- engine.queue.extend(engine.test_data)
 
 
234
  engine.running = True
235
  return {"ok": True}
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  @app.post("/halt")
238
  async def halt(): engine.running = False; return {"ok": True}
239
 
 
67
  for c in range(w):
68
  nid = f"{kind}_r{r}_c{c}"
69
  self.nodes[nid] = {
70
+ 'x': 0.0, 'vel': 0.0, 'kind': kind, 'row': r, 'col': c,
71
  'pos': (x_offset + c, y), 'anchored': kind in ['A', 'B']
72
  }
73
 
 
151
  self.test_data = []
152
  self.error_hist = []
153
  self.current_err = 0.0
154
+ self.current_type = '—'
155
+ self.test_results = []
156
 
157
  def add_log(self, msg):
158
  self.logs.insert(0, f"[{self.iter:05d}] {msg}")
159
+ if len(self.logs) > 40: self.logs.pop()
160
 
161
  def run_step(self):
162
  if not self.queue:
163
  self.running = False
164
+ self.add_log("Queue empty. Standing by.")
165
  return
166
 
167
  sample = self.queue.popleft()
168
+ self.current_type = sample['type']
169
+
170
  self.mesh.set_inputs(sample['a'], sample['b'])
171
  self.mesh.settle(steps=30)
172
  preds = self.mesh.get_predictions()
173
 
174
+ if sample['type'] != 'manual':
175
+ err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
176
+ self.current_err = err
177
+ self.error_hist.append(err)
178
+ if len(self.error_hist) > 100: self.error_hist.pop(0)
179
+
180
+ if self.mode == 'infer':
181
+ self.test_results.append({'type': self.current_type, 'err': err})
182
 
183
+ # Apply LMS + EWC
184
+ self.mesh.lms_update(sample['c'], mode=self.mode)
185
+ else:
186
+ # For manual mode, just show prediction without training
187
+ self.current_err = 0.0
188
 
189
  self.iter += 1
190
+ if self.iter % 5 == 0 or sample['type'] == 'manual':
191
+ self.add_log(f"[{self.current_type}] err: {self.current_err:.4f}")
192
 
193
  def train_offline(self, epochs):
194
+ self.running = False
195
+ self.mode = 'train'
196
+ self.add_log(f"⚡ Offline Training: {epochs} epochs...")
197
  for ep in range(epochs):
198
  random.shuffle(self.train_data)
199
  errs = []
 
203
  self.mesh.lms_update(sample['c'], mode='train')
204
  preds = self.mesh.get_predictions()
205
  errs.append(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
206
+ self.add_log(f"Ep {ep+1} | Avg Err: {np.mean(errs):.4f}")
207
  self.mesh.save_anchors()
208
  self.add_log("✓ Training Complete. EWC Anchors saved.")
209
+ self.mode = 'idle'
210
+
211
+ def get_accuracy_summary(self):
212
+ acc = {}
213
+ for r in self.test_results:
214
+ t = r['type']
215
+ if t not in acc: acc[t] = {'n': 0, 'sum_e': 0.0}
216
+ acc[t]['n'] += 1
217
+ acc[t]['sum_e'] += r['err']
218
+ return {t: {'n': v['n'], 'avg_err': round(v['sum_e']/v['n'], 4)} for t, v in acc.items()}
219
 
220
  engine = Engine()
221
  try:
 
228
  def loop():
229
  while True:
230
  if engine.running: engine.run_step()
231
+ time.sleep(0.08) # Slowed down slightly so UI visualizes it nicely
232
  threading.Thread(target=loop, daemon=True).start()
233
 
234
  @app.get("/", response_class=HTMLResponse)
 
238
  async def state():
239
  return {
240
  'nodes': engine.mesh.nodes,
 
241
  'springs': {f"{u}|{v}": k for (u, v), k in engine.mesh.springs.items()},
242
  'error': engine.current_err,
243
  'hist': engine.error_hist,
244
  'mode': engine.mode,
245
  'running': engine.running,
246
+ 'logs': engine.logs,
247
+ 'current_type': engine.current_type,
248
+ 'queue_size': len(engine.queue),
249
+ 'type_acc': engine.get_accuracy_summary(),
250
+ 'dim': DIM
251
  }
252
 
253
  @app.post("/train")
254
+ async def train(data: dict):
255
+ ep = int(data.get('epochs', 5))
256
+ threading.Thread(target=engine.train_offline, args=(ep,), daemon=True).start()
257
  return {"ok": True}
258
 
259
  @app.post("/infer")
260
+ async def infer(data: dict):
261
+ n = int(data.get('n', 200))
262
  engine.mode = 'infer'
263
+ engine.test_results = []
264
+ engine.queue.clear()
265
+ engine.queue.extend(engine.test_data[:n])
266
  engine.running = True
267
  return {"ok": True}
268
 
269
+ @app.post("/manual")
270
+ async def manual(data: dict):
271
+ try:
272
+ a_vec = [float(x.strip()) for x in data.get('a', '').split(',')]
273
+ b_vec = [float(x.strip()) for x in data.get('b', '').split(',')]
274
+ if len(a_vec) != DIM or len(b_vec) != DIM:
275
+ return {"ok": False, "error": f"Vectors must be exactly length {DIM}"}
276
+
277
+ engine.mode = 'manual'
278
+ engine.queue.clear()
279
+ engine.queue.append({'a': a_vec, 'b': b_vec, 'c': [0]*DIM, 'type': 'manual'})
280
+ engine.running = True
281
+ return {"ok": True}
282
+ except Exception as e:
283
+ return {"ok": False, "error": str(e)}
284
+
285
  @app.post("/halt")
286
  async def halt(): engine.running = False; return {"ok": True}
287