Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,10 +8,10 @@ app = FastAPI()
|
|
| 8 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 9 |
|
| 10 |
DIM = 8
|
| 11 |
-
LR = 0.
|
| 12 |
EWC_LAMBDA = 0.8
|
| 13 |
FISHER_DECAY = 0.95
|
| 14 |
-
DAMPING = 0.
|
| 15 |
DT = 0.2
|
| 16 |
|
| 17 |
# --- AUTO DATA GENERATOR ---
|
|
@@ -78,13 +78,13 @@ class PristineMesh:
|
|
| 78 |
p1, p2 = self.nodes[n1]['pos'], self.nodes[n2]['pos']
|
| 79 |
if math.hypot(p1[0]-p2[0], p1[1]-p2[1]) < 1.05:
|
| 80 |
key = tuple(sorted([n1, n2]))
|
| 81 |
-
self.springs[key] = random.uniform(0.
|
| 82 |
self.fisher[key] = 0.0
|
| 83 |
self.anchor_k[key] = self.springs[key]
|
| 84 |
|
| 85 |
-
self.c_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'C']
|
| 86 |
-
self.a_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'A']
|
| 87 |
-
self.b_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'B']
|
| 88 |
|
| 89 |
def set_inputs(self, a_vec, b_vec):
|
| 90 |
for i, nid in enumerate(self.a_nodes): self.nodes[nid]['x'] = a_vec[i]
|
|
@@ -94,7 +94,7 @@ class PristineMesh:
|
|
| 94 |
data['x'] = 0.0
|
| 95 |
data['vel'] = 0.0
|
| 96 |
|
| 97 |
-
def settle(self, steps=
|
| 98 |
for _ in range(steps):
|
| 99 |
forces = {n: 0.0 for n in self.nodes}
|
| 100 |
for (u, v), K in self.springs.items():
|
|
@@ -104,9 +104,12 @@ class PristineMesh:
|
|
| 104 |
|
| 105 |
for nid, data in self.nodes.items():
|
| 106 |
if not data['anchored']:
|
| 107 |
-
f = forces[nid] - (0.05 * data['x'])
|
| 108 |
data['vel'] = data['vel'] * DAMPING + f * DT
|
| 109 |
data['x'] += data['vel'] * DT
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
def get_predictions(self):
|
| 112 |
return [self.nodes[n]['x'] for n in self.c_nodes]
|
|
@@ -116,25 +119,31 @@ class PristineMesh:
|
|
| 116 |
for i, nid in enumerate(self.c_nodes):
|
| 117 |
errors[nid] = self.nodes[nid]['x'] - target_vec[i]
|
| 118 |
|
| 119 |
-
|
|
|
|
| 120 |
next_err = dict(errors)
|
| 121 |
for (u, v), K in self.springs.items():
|
| 122 |
-
|
| 123 |
-
next_err[
|
|
|
|
|
|
|
| 124 |
errors = next_err
|
| 125 |
|
| 126 |
for key in self.springs:
|
| 127 |
u, v = key
|
|
|
|
| 128 |
err_gradient = (errors[u] - errors[v]) * (self.nodes[u]['x'] - self.nodes[v]['x'])
|
| 129 |
|
| 130 |
if mode == 'train':
|
| 131 |
-
|
|
|
|
| 132 |
self.fisher[key] = FISHER_DECAY * self.fisher[key] + (1 - FISHER_DECAY) * (err_gradient ** 2)
|
| 133 |
elif mode == 'infer':
|
| 134 |
penalty = EWC_LAMBDA * self.fisher[key] * (self.springs[key] - self.anchor_k[key])
|
| 135 |
-
self.springs[key]
|
| 136 |
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
def save_anchors(self):
|
| 140 |
self.anchor_k = dict(self.springs)
|
|
@@ -168,11 +177,14 @@ class Engine:
|
|
| 168 |
self.current_type = sample['type']
|
| 169 |
|
| 170 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 171 |
-
self.mesh.settle(steps=
|
| 172 |
preds = self.mesh.get_predictions()
|
| 173 |
|
| 174 |
if sample['type'] != 'manual':
|
|
|
|
| 175 |
err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
|
|
|
|
|
|
|
| 176 |
self.current_err = err
|
| 177 |
self.error_hist.append(err)
|
| 178 |
if len(self.error_hist) > 100: self.error_hist.pop(0)
|
|
@@ -180,10 +192,9 @@ class Engine:
|
|
| 180 |
if self.mode == 'infer':
|
| 181 |
self.test_results.append({'type': self.current_type, 'err': err})
|
| 182 |
|
| 183 |
-
# Apply LMS + EWC
|
| 184 |
self.mesh.lms_update(sample['c'], mode=self.mode)
|
| 185 |
else:
|
| 186 |
-
# For manual mode, just show prediction without training
|
| 187 |
self.current_err = 0.0
|
| 188 |
|
| 189 |
self.iter += 1
|
|
@@ -199,11 +210,16 @@ class Engine:
|
|
| 199 |
errs = []
|
| 200 |
for sample in self.train_data:
|
| 201 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 202 |
-
self.mesh.settle(
|
| 203 |
self.mesh.lms_update(sample['c'], mode='train')
|
| 204 |
preds = self.mesh.get_predictions()
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
self.mesh.save_anchors()
|
| 208 |
self.add_log("✓ Training Complete. EWC Anchors saved.")
|
| 209 |
self.mode = 'idle'
|
|
@@ -228,7 +244,7 @@ except Exception as e:
|
|
| 228 |
def loop():
|
| 229 |
while True:
|
| 230 |
if engine.running: engine.run_step()
|
| 231 |
-
time.sleep(0.
|
| 232 |
threading.Thread(target=loop, daemon=True).start()
|
| 233 |
|
| 234 |
@app.get("/", response_class=HTMLResponse)
|
|
|
|
| 8 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
| 9 |
|
| 10 |
DIM = 8
|
| 11 |
+
LR = 0.02 # Slower, highly stable learning rate
|
| 12 |
EWC_LAMBDA = 0.8
|
| 13 |
FISHER_DECAY = 0.95
|
| 14 |
+
DAMPING = 0.6 # Increased damping for visual stability
|
| 15 |
DT = 0.2
|
| 16 |
|
| 17 |
# --- AUTO DATA GENERATOR ---
|
|
|
|
| 78 |
p1, p2 = self.nodes[n1]['pos'], self.nodes[n2]['pos']
|
| 79 |
if math.hypot(p1[0]-p2[0], p1[1]-p2[1]) < 1.05:
|
| 80 |
key = tuple(sorted([n1, n2]))
|
| 81 |
+
self.springs[key] = random.uniform(0.1, 0.4)
|
| 82 |
self.fisher[key] = 0.0
|
| 83 |
self.anchor_k[key] = self.springs[key]
|
| 84 |
|
| 85 |
+
self.c_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'C'], key=lambda k: self.nodes[k]['col'])
|
| 86 |
+
self.a_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'A'], key=lambda k: self.nodes[k]['col'])
|
| 87 |
+
self.b_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'B'], key=lambda k: self.nodes[k]['col'])
|
| 88 |
|
| 89 |
def set_inputs(self, a_vec, b_vec):
|
| 90 |
for i, nid in enumerate(self.a_nodes): self.nodes[nid]['x'] = a_vec[i]
|
|
|
|
| 94 |
data['x'] = 0.0
|
| 95 |
data['vel'] = 0.0
|
| 96 |
|
| 97 |
+
def settle(self, steps=30):
|
| 98 |
for _ in range(steps):
|
| 99 |
forces = {n: 0.0 for n in self.nodes}
|
| 100 |
for (u, v), K in self.springs.items():
|
|
|
|
| 104 |
|
| 105 |
for nid, data in self.nodes.items():
|
| 106 |
if not data['anchored']:
|
| 107 |
+
f = forces[nid] - (0.05 * data['x']) # Weak grounding
|
| 108 |
data['vel'] = data['vel'] * DAMPING + f * DT
|
| 109 |
data['x'] += data['vel'] * DT
|
| 110 |
+
|
| 111 |
+
# MAGICAL FIX: Bounding prevents mathematical explosion
|
| 112 |
+
data['x'] = max(-2.0, min(2.0, data['x']))
|
| 113 |
|
| 114 |
def get_predictions(self):
|
| 115 |
return [self.nodes[n]['x'] for n in self.c_nodes]
|
|
|
|
| 119 |
for i, nid in enumerate(self.c_nodes):
|
| 120 |
errors[nid] = self.nodes[nid]['x'] - target_vec[i]
|
| 121 |
|
| 122 |
+
# Deep Error Diffusion
|
| 123 |
+
for _ in range(5):
|
| 124 |
next_err = dict(errors)
|
| 125 |
for (u, v), K in self.springs.items():
|
| 126 |
+
weight = min(abs(K) * 0.1, 0.4)
|
| 127 |
+
next_err[u] += weight * errors[v]
|
| 128 |
+
next_err[v] += weight * errors[u]
|
| 129 |
+
for n in errors: next_err[n] *= 0.85 # Decay to prevent error explosion
|
| 130 |
errors = next_err
|
| 131 |
|
| 132 |
for key in self.springs:
|
| 133 |
u, v = key
|
| 134 |
+
# True gradient formulation
|
| 135 |
err_gradient = (errors[u] - errors[v]) * (self.nodes[u]['x'] - self.nodes[v]['x'])
|
| 136 |
|
| 137 |
if mode == 'train':
|
| 138 |
+
# FIXED: MUST BE += TO MINIMIZE ERROR!
|
| 139 |
+
self.springs[key] += LR * err_gradient
|
| 140 |
self.fisher[key] = FISHER_DECAY * self.fisher[key] + (1 - FISHER_DECAY) * (err_gradient ** 2)
|
| 141 |
elif mode == 'infer':
|
| 142 |
penalty = EWC_LAMBDA * self.fisher[key] * (self.springs[key] - self.anchor_k[key])
|
| 143 |
+
self.springs[key] += (LR * 0.5 * err_gradient) - penalty
|
| 144 |
|
| 145 |
+
# Stable K bounds
|
| 146 |
+
self.springs[key] = max(-1.0, min(3.0, self.springs[key]))
|
| 147 |
|
| 148 |
def save_anchors(self):
|
| 149 |
self.anchor_k = dict(self.springs)
|
|
|
|
| 177 |
self.current_type = sample['type']
|
| 178 |
|
| 179 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 180 |
+
self.mesh.settle(steps=25)
|
| 181 |
preds = self.mesh.get_predictions()
|
| 182 |
|
| 183 |
if sample['type'] != 'manual':
|
| 184 |
+
# Clean Mean Absolute Error
|
| 185 |
err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
|
| 186 |
+
if math.isnan(err): err = 1.0 # Safety catch
|
| 187 |
+
|
| 188 |
self.current_err = err
|
| 189 |
self.error_hist.append(err)
|
| 190 |
if len(self.error_hist) > 100: self.error_hist.pop(0)
|
|
|
|
| 192 |
if self.mode == 'infer':
|
| 193 |
self.test_results.append({'type': self.current_type, 'err': err})
|
| 194 |
|
| 195 |
+
# Apply correct LMS + EWC
|
| 196 |
self.mesh.lms_update(sample['c'], mode=self.mode)
|
| 197 |
else:
|
|
|
|
| 198 |
self.current_err = 0.0
|
| 199 |
|
| 200 |
self.iter += 1
|
|
|
|
| 210 |
errs = []
|
| 211 |
for sample in self.train_data:
|
| 212 |
self.mesh.set_inputs(sample['a'], sample['b'])
|
| 213 |
+
self.mesh.settle(20)
|
| 214 |
self.mesh.lms_update(sample['c'], mode='train')
|
| 215 |
preds = self.mesh.get_predictions()
|
| 216 |
+
e = np.mean(np.abs(np.array(preds) - np.array(sample['c'])))
|
| 217 |
+
if not math.isnan(e): errs.append(e)
|
| 218 |
+
|
| 219 |
+
avg_e = np.mean(errs) if errs else 0.0
|
| 220 |
+
self.add_log(f"Ep {ep+1} | Avg Err: {avg_e:.4f}")
|
| 221 |
+
print(f"Epoch {ep+1} | Avg Err: {avg_e:.4f}")
|
| 222 |
+
|
| 223 |
self.mesh.save_anchors()
|
| 224 |
self.add_log("✓ Training Complete. EWC Anchors saved.")
|
| 225 |
self.mode = 'idle'
|
|
|
|
| 244 |
def loop():
|
| 245 |
while True:
|
| 246 |
if engine.running: engine.run_step()
|
| 247 |
+
time.sleep(0.06)
|
| 248 |
threading.Thread(target=loop, daemon=True).start()
|
| 249 |
|
| 250 |
@app.get("/", response_class=HTMLResponse)
|