everydaytok commited on
Commit
ef6311d
·
verified ·
1 Parent(s): 2803122

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -21
app.py CHANGED
@@ -8,10 +8,10 @@ app = FastAPI()
8
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
9
 
10
  DIM = 8
11
- LR = 0.05
12
  EWC_LAMBDA = 0.8
13
  FISHER_DECAY = 0.95
14
- DAMPING = 0.5
15
  DT = 0.2
16
 
17
  # --- AUTO DATA GENERATOR ---
@@ -78,13 +78,13 @@ class PristineMesh:
78
  p1, p2 = self.nodes[n1]['pos'], self.nodes[n2]['pos']
79
  if math.hypot(p1[0]-p2[0], p1[1]-p2[1]) < 1.05:
80
  key = tuple(sorted([n1, n2]))
81
- self.springs[key] = random.uniform(0.4, 0.6)
82
  self.fisher[key] = 0.0
83
  self.anchor_k[key] = self.springs[key]
84
 
85
- self.c_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'C']
86
- self.a_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'A']
87
- self.b_nodes = [n for n in self.nodes if self.nodes[n]['kind'] == 'B']
88
 
89
  def set_inputs(self, a_vec, b_vec):
90
  for i, nid in enumerate(self.a_nodes): self.nodes[nid]['x'] = a_vec[i]
@@ -94,7 +94,7 @@ class PristineMesh:
94
  data['x'] = 0.0
95
  data['vel'] = 0.0
96
 
97
- def settle(self, steps=40):
98
  for _ in range(steps):
99
  forces = {n: 0.0 for n in self.nodes}
100
  for (u, v), K in self.springs.items():
@@ -104,9 +104,12 @@ class PristineMesh:
104
 
105
  for nid, data in self.nodes.items():
106
  if not data['anchored']:
107
- f = forces[nid] - (0.05 * data['x'])
108
  data['vel'] = data['vel'] * DAMPING + f * DT
109
  data['x'] += data['vel'] * DT
 
 
 
110
 
111
  def get_predictions(self):
112
  return [self.nodes[n]['x'] for n in self.c_nodes]
@@ -116,25 +119,31 @@ class PristineMesh:
116
  for i, nid in enumerate(self.c_nodes):
117
  errors[nid] = self.nodes[nid]['x'] - target_vec[i]
118
 
119
- for _ in range(3):
 
120
  next_err = dict(errors)
121
  for (u, v), K in self.springs.items():
122
- next_err[u] += K * errors[v] * 0.1
123
- next_err[v] += K * errors[u] * 0.1
 
 
124
  errors = next_err
125
 
126
  for key in self.springs:
127
  u, v = key
 
128
  err_gradient = (errors[u] - errors[v]) * (self.nodes[u]['x'] - self.nodes[v]['x'])
129
 
130
  if mode == 'train':
131
- self.springs[key] -= LR * err_gradient
 
132
  self.fisher[key] = FISHER_DECAY * self.fisher[key] + (1 - FISHER_DECAY) * (err_gradient ** 2)
133
  elif mode == 'infer':
134
  penalty = EWC_LAMBDA * self.fisher[key] * (self.springs[key] - self.anchor_k[key])
135
- self.springs[key] -= (LR * 0.5 * err_gradient) + penalty
136
 
137
- self.springs[key] = max(-4.0, min(8.0, self.springs[key]))
 
138
 
139
  def save_anchors(self):
140
  self.anchor_k = dict(self.springs)
@@ -168,11 +177,14 @@ class Engine:
168
  self.current_type = sample['type']
169
 
170
  self.mesh.set_inputs(sample['a'], sample['b'])
171
- self.mesh.settle(steps=30)
172
  preds = self.mesh.get_predictions()
173
 
174
  if sample['type'] != 'manual':
 
175
  err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
 
 
176
  self.current_err = err
177
  self.error_hist.append(err)
178
  if len(self.error_hist) > 100: self.error_hist.pop(0)
@@ -180,10 +192,9 @@ class Engine:
180
  if self.mode == 'infer':
181
  self.test_results.append({'type': self.current_type, 'err': err})
182
 
183
- # Apply LMS + EWC
184
  self.mesh.lms_update(sample['c'], mode=self.mode)
185
  else:
186
- # For manual mode, just show prediction without training
187
  self.current_err = 0.0
188
 
189
  self.iter += 1
@@ -199,11 +210,16 @@ class Engine:
199
  errs = []
200
  for sample in self.train_data:
201
  self.mesh.set_inputs(sample['a'], sample['b'])
202
- self.mesh.settle(15)
203
  self.mesh.lms_update(sample['c'], mode='train')
204
  preds = self.mesh.get_predictions()
205
- errs.append(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
206
- self.add_log(f"Ep {ep+1} | Avg Err: {np.mean(errs):.4f}")
 
 
 
 
 
207
  self.mesh.save_anchors()
208
  self.add_log("✓ Training Complete. EWC Anchors saved.")
209
  self.mode = 'idle'
@@ -228,7 +244,7 @@ except Exception as e:
228
  def loop():
229
  while True:
230
  if engine.running: engine.run_step()
231
- time.sleep(0.08) # Slowed down slightly so UI visualizes it nicely
232
  threading.Thread(target=loop, daemon=True).start()
233
 
234
  @app.get("/", response_class=HTMLResponse)
 
8
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
9
 
10
  DIM = 8
11
+ LR = 0.02 # Slower, highly stable learning rate
12
  EWC_LAMBDA = 0.8
13
  FISHER_DECAY = 0.95
14
+ DAMPING = 0.6 # Increased damping for visual stability
15
  DT = 0.2
16
 
17
  # --- AUTO DATA GENERATOR ---
 
78
  p1, p2 = self.nodes[n1]['pos'], self.nodes[n2]['pos']
79
  if math.hypot(p1[0]-p2[0], p1[1]-p2[1]) < 1.05:
80
  key = tuple(sorted([n1, n2]))
81
+ self.springs[key] = random.uniform(0.1, 0.4)
82
  self.fisher[key] = 0.0
83
  self.anchor_k[key] = self.springs[key]
84
 
85
+ self.c_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'C'], key=lambda k: self.nodes[k]['col'])
86
+ self.a_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'A'], key=lambda k: self.nodes[k]['col'])
87
+ self.b_nodes = sorted([n for n in self.nodes if self.nodes[n]['kind'] == 'B'], key=lambda k: self.nodes[k]['col'])
88
 
89
  def set_inputs(self, a_vec, b_vec):
90
  for i, nid in enumerate(self.a_nodes): self.nodes[nid]['x'] = a_vec[i]
 
94
  data['x'] = 0.0
95
  data['vel'] = 0.0
96
 
97
+ def settle(self, steps=30):
98
  for _ in range(steps):
99
  forces = {n: 0.0 for n in self.nodes}
100
  for (u, v), K in self.springs.items():
 
104
 
105
  for nid, data in self.nodes.items():
106
  if not data['anchored']:
107
+ f = forces[nid] - (0.05 * data['x']) # Weak grounding
108
  data['vel'] = data['vel'] * DAMPING + f * DT
109
  data['x'] += data['vel'] * DT
110
+
111
+ # MAGICAL FIX: Bounding prevents mathematical explosion
112
+ data['x'] = max(-2.0, min(2.0, data['x']))
113
 
114
  def get_predictions(self):
115
  return [self.nodes[n]['x'] for n in self.c_nodes]
 
119
  for i, nid in enumerate(self.c_nodes):
120
  errors[nid] = self.nodes[nid]['x'] - target_vec[i]
121
 
122
+ # Deep Error Diffusion
123
+ for _ in range(5):
124
  next_err = dict(errors)
125
  for (u, v), K in self.springs.items():
126
+ weight = min(abs(K) * 0.1, 0.4)
127
+ next_err[u] += weight * errors[v]
128
+ next_err[v] += weight * errors[u]
129
+ for n in errors: next_err[n] *= 0.85 # Decay to prevent error explosion
130
  errors = next_err
131
 
132
  for key in self.springs:
133
  u, v = key
134
+ # True gradient formulation
135
  err_gradient = (errors[u] - errors[v]) * (self.nodes[u]['x'] - self.nodes[v]['x'])
136
 
137
  if mode == 'train':
138
+ # FIXED: MUST BE += TO MINIMIZE ERROR!
139
+ self.springs[key] += LR * err_gradient
140
  self.fisher[key] = FISHER_DECAY * self.fisher[key] + (1 - FISHER_DECAY) * (err_gradient ** 2)
141
  elif mode == 'infer':
142
  penalty = EWC_LAMBDA * self.fisher[key] * (self.springs[key] - self.anchor_k[key])
143
+ self.springs[key] += (LR * 0.5 * err_gradient) - penalty
144
 
145
+ # Stable K bounds
146
+ self.springs[key] = max(-1.0, min(3.0, self.springs[key]))
147
 
148
  def save_anchors(self):
149
  self.anchor_k = dict(self.springs)
 
177
  self.current_type = sample['type']
178
 
179
  self.mesh.set_inputs(sample['a'], sample['b'])
180
+ self.mesh.settle(steps=25)
181
  preds = self.mesh.get_predictions()
182
 
183
  if sample['type'] != 'manual':
184
+ # Clean Mean Absolute Error
185
  err = float(np.mean(np.abs(np.array(preds) - np.array(sample['c']))))
186
+ if math.isnan(err): err = 1.0 # Safety catch
187
+
188
  self.current_err = err
189
  self.error_hist.append(err)
190
  if len(self.error_hist) > 100: self.error_hist.pop(0)
 
192
  if self.mode == 'infer':
193
  self.test_results.append({'type': self.current_type, 'err': err})
194
 
195
+ # Apply correct LMS + EWC
196
  self.mesh.lms_update(sample['c'], mode=self.mode)
197
  else:
 
198
  self.current_err = 0.0
199
 
200
  self.iter += 1
 
210
  errs = []
211
  for sample in self.train_data:
212
  self.mesh.set_inputs(sample['a'], sample['b'])
213
+ self.mesh.settle(20)
214
  self.mesh.lms_update(sample['c'], mode='train')
215
  preds = self.mesh.get_predictions()
216
+ e = np.mean(np.abs(np.array(preds) - np.array(sample['c'])))
217
+ if not math.isnan(e): errs.append(e)
218
+
219
+ avg_e = np.mean(errs) if errs else 0.0
220
+ self.add_log(f"Ep {ep+1} | Avg Err: {avg_e:.4f}")
221
+ print(f"Epoch {ep+1} | Avg Err: {avg_e:.4f}")
222
+
223
  self.mesh.save_anchors()
224
  self.add_log("✓ Training Complete. EWC Anchors saved.")
225
  self.mode = 'idle'
 
244
  def loop():
245
  while True:
246
  if engine.running: engine.run_step()
247
+ time.sleep(0.06)
248
  threading.Thread(target=loop, daemon=True).start()
249
 
250
  @app.get("/", response_class=HTMLResponse)