Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -67,7 +67,7 @@ class PristineMesh:
|
|
| 67 |
for c in range(w):
|
| 68 |
nid = f"{kind}_r{r}_c{c}"
|
| 69 |
self.nodes[nid] = {
|
| 70 |
-
'x': 0.0, 'vel': 0.0, 'kind': kind,
|
| 71 |
'pos': (x_offset + c, y), 'anchored': kind in ['A', 'B']
|
| 72 |
}
|
| 73 |
|
|
@@ -151,34 +151,49 @@ class Engine:
|
|
| 151 |
self.test_data = []
|
| 152 |
self.error_hist = []
|
| 153 |
self.current_err = 0.0
|
|
|
|
|
|
|
| 154 |
|
| 155 |
def add_log(self, msg):
|
| 156 |
self.logs.insert(0, f"[{self.iter:05d}] {msg}")
|
| 157 |
-
if len(self.logs) >
|
| 158 |
|
| 159 |
def run_step(self):
|
| 160 |
if not self.queue:
|
| 161 |
self.running = False
|
|
|
|
| 162 |
return
|
| 163 |
|
| 164 |
sample = self.queue.popleft()
|
|
|
|
|
|
|
| 165 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 166 |
self.mesh.settle(steps=30)
|
| 167 |
preds = self.mesh.get_predictions()
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
self.iter += 1
|
| 177 |
-
if self.iter %
|
| 178 |
-
self.add_log(f"[{
|
| 179 |
|
| 180 |
def train_offline(self, epochs):
|
| 181 |
-
self.
|
|
|
|
|
|
|
| 182 |
for ep in range(epochs):
|
| 183 |
random.shuffle(self.train_data)
|
| 184 |
errs = []
|
|
@@ -188,9 +203,19 @@ class Engine:
|
|
| 188 |
self.mesh.lms_update(sample['c'], mode='train')
|
| 189 |
preds = self.mesh.get_predictions()
|
| 190 |
errs.append(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
|
| 191 |
-
self.add_log(f"
|
| 192 |
self.mesh.save_anchors()
|
| 193 |
self.add_log("✓ Training Complete. EWC Anchors saved.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
engine = Engine()
|
| 196 |
try:
|
|
@@ -203,7 +228,7 @@ except Exception as e:
|
|
| 203 |
def loop():
|
| 204 |
while True:
|
| 205 |
if engine.running: engine.run_step()
|
| 206 |
-
time.sleep(0.
|
| 207 |
threading.Thread(target=loop, daemon=True).start()
|
| 208 |
|
| 209 |
@app.get("/", response_class=HTMLResponse)
|
|
@@ -213,27 +238,50 @@ async def ui(): return FileResponse("index.html")
|
|
| 213 |
async def state():
|
| 214 |
return {
|
| 215 |
'nodes': engine.mesh.nodes,
|
| 216 |
-
# FIXED: Serialize tuple keys safely for JSON by joining with a pipe '|'
|
| 217 |
'springs': {f"{u}|{v}": k for (u, v), k in engine.mesh.springs.items()},
|
| 218 |
'error': engine.current_err,
|
| 219 |
'hist': engine.error_hist,
|
| 220 |
'mode': engine.mode,
|
| 221 |
'running': engine.running,
|
| 222 |
-
'logs': engine.logs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
}
|
| 224 |
|
| 225 |
@app.post("/train")
|
| 226 |
-
async def train():
|
| 227 |
-
|
|
|
|
| 228 |
return {"ok": True}
|
| 229 |
|
| 230 |
@app.post("/infer")
|
| 231 |
-
async def infer():
|
|
|
|
| 232 |
engine.mode = 'infer'
|
| 233 |
-
engine.
|
|
|
|
|
|
|
| 234 |
engine.running = True
|
| 235 |
return {"ok": True}
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
@app.post("/halt")
|
| 238 |
async def halt(): engine.running = False; return {"ok": True}
|
| 239 |
|
|
|
|
| 67 |
for c in range(w):
|
| 68 |
nid = f"{kind}_r{r}_c{c}"
|
| 69 |
self.nodes[nid] = {
|
| 70 |
+
'x': 0.0, 'vel': 0.0, 'kind': kind, 'row': r, 'col': c,
|
| 71 |
'pos': (x_offset + c, y), 'anchored': kind in ['A', 'B']
|
| 72 |
}
|
| 73 |
|
|
|
|
| 151 |
self.test_data = []
|
| 152 |
self.error_hist = []
|
| 153 |
self.current_err = 0.0
|
| 154 |
+
self.current_type = '—'
|
| 155 |
+
self.test_results = []
|
| 156 |
|
| 157 |
def add_log(self, msg):
|
| 158 |
self.logs.insert(0, f"[{self.iter:05d}] {msg}")
|
| 159 |
+
if len(self.logs) > 40: self.logs.pop()
|
| 160 |
|
| 161 |
def run_step(self):
|
| 162 |
if not self.queue:
|
| 163 |
self.running = False
|
| 164 |
+
self.add_log("Queue empty. Standing by.")
|
| 165 |
return
|
| 166 |
|
| 167 |
sample = self.queue.popleft()
|
| 168 |
+
self.current_type = sample['type']
|
| 169 |
+
|
| 170 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 171 |
self.mesh.settle(steps=30)
|
| 172 |
preds = self.mesh.get_predictions()
|
| 173 |
|
| 174 |
+
if sample['type'] != 'manual':
|
| 175 |
+
err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
|
| 176 |
+
self.current_err = err
|
| 177 |
+
self.error_hist.append(err)
|
| 178 |
+
if len(self.error_hist) > 100: self.error_hist.pop(0)
|
| 179 |
+
|
| 180 |
+
if self.mode == 'infer':
|
| 181 |
+
self.test_results.append({'type': self.current_type, 'err': err})
|
| 182 |
|
| 183 |
+
# Apply LMS + EWC
|
| 184 |
+
self.mesh.lms_update(sample['c'], mode=self.mode)
|
| 185 |
+
else:
|
| 186 |
+
# For manual mode, just show prediction without training
|
| 187 |
+
self.current_err = 0.0
|
| 188 |
|
| 189 |
self.iter += 1
|
| 190 |
+
if self.iter % 5 == 0 or sample['type'] == 'manual':
|
| 191 |
+
self.add_log(f"[{self.current_type}] err: {self.current_err:.4f}")
|
| 192 |
|
| 193 |
def train_offline(self, epochs):
|
| 194 |
+
self.running = False
|
| 195 |
+
self.mode = 'train'
|
| 196 |
+
self.add_log(f"⚡ Offline Training: {epochs} epochs...")
|
| 197 |
for ep in range(epochs):
|
| 198 |
random.shuffle(self.train_data)
|
| 199 |
errs = []
|
|
|
|
| 203 |
self.mesh.lms_update(sample['c'], mode='train')
|
| 204 |
preds = self.mesh.get_predictions()
|
| 205 |
errs.append(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
|
| 206 |
+
self.add_log(f"Ep {ep+1} | Avg Err: {np.mean(errs):.4f}")
|
| 207 |
self.mesh.save_anchors()
|
| 208 |
self.add_log("✓ Training Complete. EWC Anchors saved.")
|
| 209 |
+
self.mode = 'idle'
|
| 210 |
+
|
| 211 |
+
def get_accuracy_summary(self):
|
| 212 |
+
acc = {}
|
| 213 |
+
for r in self.test_results:
|
| 214 |
+
t = r['type']
|
| 215 |
+
if t not in acc: acc[t] = {'n': 0, 'sum_e': 0.0}
|
| 216 |
+
acc[t]['n'] += 1
|
| 217 |
+
acc[t]['sum_e'] += r['err']
|
| 218 |
+
return {t: {'n': v['n'], 'avg_err': round(v['sum_e']/v['n'], 4)} for t, v in acc.items()}
|
| 219 |
|
| 220 |
engine = Engine()
|
| 221 |
try:
|
|
|
|
| 228 |
def loop():
|
| 229 |
while True:
|
| 230 |
if engine.running: engine.run_step()
|
| 231 |
+
time.sleep(0.08) # Slowed down slightly so UI visualizes it nicely
|
| 232 |
threading.Thread(target=loop, daemon=True).start()
|
| 233 |
|
| 234 |
@app.get("/", response_class=HTMLResponse)
|
|
|
|
| 238 |
async def state():
|
| 239 |
return {
|
| 240 |
'nodes': engine.mesh.nodes,
|
|
|
|
| 241 |
'springs': {f"{u}|{v}": k for (u, v), k in engine.mesh.springs.items()},
|
| 242 |
'error': engine.current_err,
|
| 243 |
'hist': engine.error_hist,
|
| 244 |
'mode': engine.mode,
|
| 245 |
'running': engine.running,
|
| 246 |
+
'logs': engine.logs,
|
| 247 |
+
'current_type': engine.current_type,
|
| 248 |
+
'queue_size': len(engine.queue),
|
| 249 |
+
'type_acc': engine.get_accuracy_summary(),
|
| 250 |
+
'dim': DIM
|
| 251 |
}
|
| 252 |
|
| 253 |
@app.post("/train")
|
| 254 |
+
async def train(data: dict):
|
| 255 |
+
ep = int(data.get('epochs', 5))
|
| 256 |
+
threading.Thread(target=engine.train_offline, args=(ep,), daemon=True).start()
|
| 257 |
return {"ok": True}
|
| 258 |
|
| 259 |
@app.post("/infer")
|
| 260 |
+
async def infer(data: dict):
|
| 261 |
+
n = int(data.get('n', 200))
|
| 262 |
engine.mode = 'infer'
|
| 263 |
+
engine.test_results = []
|
| 264 |
+
engine.queue.clear()
|
| 265 |
+
engine.queue.extend(engine.test_data[:n])
|
| 266 |
engine.running = True
|
| 267 |
return {"ok": True}
|
| 268 |
|
| 269 |
+
@app.post("/manual")
|
| 270 |
+
async def manual(data: dict):
|
| 271 |
+
try:
|
| 272 |
+
a_vec = [float(x.strip()) for x in data.get('a', '').split(',')]
|
| 273 |
+
b_vec = [float(x.strip()) for x in data.get('b', '').split(',')]
|
| 274 |
+
if len(a_vec) != DIM or len(b_vec) != DIM:
|
| 275 |
+
return {"ok": False, "error": f"Vectors must be exactly length {DIM}"}
|
| 276 |
+
|
| 277 |
+
engine.mode = 'manual'
|
| 278 |
+
engine.queue.clear()
|
| 279 |
+
engine.queue.append({'a': a_vec, 'b': b_vec, 'c': [0]*DIM, 'type': 'manual'})
|
| 280 |
+
engine.running = True
|
| 281 |
+
return {"ok": True}
|
| 282 |
+
except Exception as e:
|
| 283 |
+
return {"ok": False, "error": str(e)}
|
| 284 |
+
|
| 285 |
@app.post("/halt")
|
| 286 |
async def halt(): engine.running = False; return {"ok": True}
|
| 287 |
|