Elliotasdasdasfasas committed
Commit c557815 · 1 Parent(s): c9edbe0

Upgrade to CTM v2.0 - Full PyTorch implementation with online training
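The two new endpoints, /regulate and /train_online, sit alongside the original five. A minimal client sketch for exercising them once the Space is running; the gradio_client call pattern is standard, but the Space URL below is a placeholder and not part of this commit:

from gradio_client import Client
import json

# Placeholder URL - substitute the actual Space once deployed.
client = Client("https://<your-ctm-space>.hf.space")

# One online training step: 72D input, 16D target dendrite adjustments.
train_payload = json.dumps({
    "input_72d": [0.1] * 72,
    "target_16d": [0.5] * 16,
    "physics_loss": 0.01,
})
print(client.predict(train_payload, api_name="/train_online"))

# Full regulation loop: dendrite state plus a latent slice in, deltas out.
regulate_payload = json.dumps({
    "dendrites": [0.5] * 16,
    "latent_256": [0.1] * 256,
    "physics_loss": 0.01,
    "anomaly_score": 0.05,
})
print(client.predict(regulate_payload, api_name="/regulate"))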

Files changed (5)
  1. app.py +586 -186
  2. app_v1_backup.py +464 -0
  3. requirements.txt +20 -1
  4. requirements_v1.txt +2 -0
  5. utils/dendrite_losses.py +228 -0
app.py CHANGED
@@ -1,85 +1,250 @@
  """
- CTM Nervous System Server - Continuous Thought Machine for Hypergraph Maintenance
- ===================================================================================
- Implementation of the definitive proposal: CTM as Nervous System for ART-17 Hypergraph.
-
- Endpoints:
- - /sense_snn: Process 72D SNN input with NLM-style processing
- - /reason_hypergraph: Reason about hypergraph context, propose edges
- - /validate_physics: Validate proposals against 5 physics losses
- - /dream: Offline consolidation with T=500+ ticks
- - /calibrate_stdp: Suggest STDP weight adjustments from sync matrix
- - /health: Health check endpoint
  
  Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
  """
  
  import gradio as gr
  import numpy as np
  import json
- from typing import List, Dict, Any, Optional
  import os
  
  # ============================================================================
- # SIMPLIFIED CTM SIMULATION (CPU-only for Hugging Face free tier)
  # ============================================================================
  
- class SimplifiedCTM:
      """
-     Simplified CTM for CPU-only environment.
-     Simulates the key mechanisms without full PyTorch model.
      """
  
-     def __init__(self, d_model: int = 256, memory_length: int = 16, n_ticks: int = 50):
-         self.d_model = d_model
-         self.memory_length = memory_length
-         self.n_ticks = n_ticks
-
-         # Initialize state
-         self.state_trace = np.zeros((d_model, memory_length))
-         self.activated_state = np.random.randn(d_model) * 0.1
-
-         # NLM weights (simplified: one weight matrix per "neuron group")
-         self.nlm_weights = np.random.randn(16, memory_length) * 0.1  # 16 groups for 16 dendrites
-
-     def compute_sync_matrix(self, z: np.ndarray) -> np.ndarray:
-         """S^t = Z · Z^T (normalized)"""
-         z_norm = z / (np.linalg.norm(z) + 1e-8)
-         S = np.outer(z_norm, z_norm)
-         return S
-
-     def compute_certainty(self, predictions: np.ndarray) -> float:
-         """Certainty = 1 - normalized entropy"""
-         probs = np.abs(predictions) / (np.sum(np.abs(predictions)) + 1e-8)
-         probs = np.clip(probs, 1e-10, 1.0)
-         entropy = -np.sum(probs * np.log(probs))
-         max_entropy = np.log(len(probs))
-         normalized_entropy = entropy / (max_entropy + 1e-8)
-         return float(1.0 - normalized_entropy)
-
-     def process_ticks(self, input_features: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
-         """Run T internal ticks and return sync matrix + certainty"""
-         n_ticks = n_ticks or self.n_ticks
-
-         # Ensure input is right size
-         if len(input_features) < self.d_model:
-             input_features = np.pad(input_features, (0, self.d_model - len(input_features)))
          else:
-             input_features = input_features[:self.d_model]
  
          certainties = []
-         sync_matrices = []
  
          for t in range(n_ticks):
-             # Simulate synapse update
-             combined = np.concatenate([self.activated_state, input_features[:self.d_model//2]])
              pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
  
-             # Update trace
              self.state_trace = np.roll(self.state_trace, -1, axis=1)
              self.state_trace[:, -1] = pre_activation
  
-             # Simulate NLM (simplified)
              post_activation = np.zeros(self.d_model)
             group_size = self.d_model // 16
             for g in range(16):
@@ -91,50 +256,151 @@ class SimplifiedCTM:
  
              self.activated_state = post_activation
  
-             # Compute sync and certainty
-             sync = self.compute_sync_matrix(self.activated_state)
-             cert = self.compute_certainty(self.activated_state)
  
-             sync_matrices.append(sync)
-             certainties.append(cert)
  
-         # Find best ticks (min-loss proxy: max certainty)
-         best_tick = int(np.argmax(certainties))
  
          return {
-             "final_sync_matrix": sync_matrices[-1].tolist(),
-             "best_sync_matrix": sync_matrices[best_tick].tolist(),
-             "certainties": certainties,
-             "final_certainty": float(certainties[-1]),
-             "max_certainty": float(max(certainties)),
-             "best_tick": best_tick,
-             "ticks_used": n_ticks
          }
  
- # Global CTM instance
- ctm = SimplifiedCTM(d_model=256, memory_length=16, n_ticks=50)
  
  # ============================================================================
  # PHYSICS VALIDATION (from SNN Omega-21)
  # ============================================================================
-
  def validate_physics(trajectory: List[float], params: Dict) -> Dict:
-     """Validate against 5 physics losses from SNN Omega-21"""
      trajectory = np.array(trajectory)
  
      # L_energy: Energy conservation
      energy = np.sum(trajectory ** 2)
-     P_max = params.get("P_max", 1000.0)
      L_energy = float(max(0, energy - P_max) ** 2)
  
      # L_thermo: Thermodynamics (dew point check)
-     T_dew = params.get("T_dew", 15.0)
-     T_amb = params.get("T_amb", 25.0)
      L_thermo = float(max(0, T_dew - T_amb) ** 2)
  
      # L_causal: Causality (velocity limit)
      velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
-     v_max = params.get("v_max", 100.0)
      L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
  
      # L_conserv: Flux conservation
@@ -165,49 +431,54 @@ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
  
  def sense_snn(snn_json: str) -> str:
      """
-     /sense_snn - Process 72D SNN input
-     Input: JSON with dendrite values
-     Output: Coherent features + anomalies
      """
      try:
          data = json.loads(snn_json)
  
-         # Extract 72D vector (or create from dendrites)
          if "vector_72d" in data:
              input_vec = np.array(data["vector_72d"])
          elif "dendrites" in data:
-             dendrite_values = list(data["dendrites"].values())
-             input_vec = np.array(dendrite_values)
          else:
-             input_vec = np.random.randn(72)  # Fallback
  
-         # Pad to 256D
-         input_256 = np.zeros(256)
-         input_256[:min(len(input_vec), 256)] = input_vec[:min(len(input_vec), 256)]
  
          # Process through CTM
-         result = ctm.process_ticks(input_256, n_ticks=25)
  
-         # Detect anomalies (low certainty regions)
          anomalies = []
-         if result["final_certainty"] < 0.5:
-             anomalies.append("Low overall certainty")
  
          return json.dumps({
              "status": "success",
-             "coherent_features": result["final_sync_matrix"][:16][:16],  # 16x16 subset
-             "certainty": result["final_certainty"],
              "anomalies": anomalies,
-             "ticks_used": result["ticks_used"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
  def reason_hypergraph(context_json: str) -> str:
      """
-     /reason_hypergraph - Reason about hypergraph, propose edges
-     Input: Node features + existing edges
-     Output: Proposed new edges + certainty
      """
      try:
          data = json.loads(context_json)
@@ -217,38 +488,40 @@ def reason_hypergraph(context_json: str) -> str:
          n_ticks = data.get("ticks", 50)
  
          # Flatten node features for CTM input
-         input_vec = node_features.flatten()
  
          # Process through CTM with more ticks for reasoning
-         result = ctm.process_ticks(input_vec, n_ticks=n_ticks)
  
-         # Extract proposed edges from sync matrix (S_ij > 0.8)
-         sync = np.array(result["best_sync_matrix"])
          proposed_edges = []
-
-         n_nodes = min(len(node_features), sync.shape[0])
-         for i in range(n_nodes):
-             for j in range(i+1, n_nodes):
-                 sync_ij = sync[i, j]
-                 if sync_ij > 0.8:
-                     # Check if edge already exists
-                     edge_exists = any(
-                         (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
-                         for e in existing_edges
-                     )
-                     if not edge_exists:
-                         proposed_edges.append([i, j, float(sync_ij)])
  
          return json.dumps({
              "status": "success",
              "proposed_edges": proposed_edges,
-             "certainty": result["max_certainty"],
              "best_tick": result["best_tick"],
-             "ticks_used": result["ticks_used"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
  def validate_physics_endpoint(physics_json: str) -> str:
      """
      /validate_physics - Validate trajectory against 5 physics losses
@@ -265,135 +538,238 @@ def validate_physics_endpoint(physics_json: str) -> str:
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
  def dream_endpoint(dream_json: str) -> str:
      """
-     /dream - Offline consolidation with T=500+ ticks
-     Input: Hypergraph snapshot
-     Output: Discovered patterns + new edges
      """
      try:
          data = json.loads(dream_json)
  
          snapshot = data.get("hypergraph_snapshot", {})
-         n_ticks = data.get("ticks", 500)
  
          # Extract features from snapshot
          nodes = snapshot.get("nodes", [])
          if nodes:
-             input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()
          else:
-             input_vec = np.random.randn(256)  # Random dream if no nodes
-
-         # Dream: run CTM with many ticks and no external input after initial
-         result = ctm.process_ticks(input_vec, n_ticks=min(n_ticks, 100))  # Cap at 100 for CPU
  
-         # Analyze sync evolution to find patterns
-         sync = np.array(result["final_sync_matrix"])
  
-         # Find strong sync pairs (new edges)
          new_edges = []
-         n = min(len(nodes), sync.shape[0]) if nodes else 16
-         for i in range(n):
-             for j in range(i+1, n):
-                 if sync[i, j] > 0.85:
-                     new_edges.append([i, j, float(sync[i, j])])
-
-         # Find weak sync pairs (edges to prune)
          pruned_edges = []
-         for i in range(n):
-             for j in range(i+1, n):
-                 if sync[i, j] < 0.1:
-                     pruned_edges.append([i, j])
  
          return json.dumps({
              "status": "success",
              "discovered_patterns": len(new_edges),
-             "new_edges": new_edges[:10],  # Top 10
-             "pruned_edges": pruned_edges[:10],  # Top 10
-             "consolidation_certainty": result["max_certainty"],
-             "ticks_used": result["ticks_used"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
  def calibrate_stdp_endpoint(stdp_json: str) -> str:
      """
-     /calibrate_stdp - Suggest STDP weight adjustments from sync
      """
      try:
          data = json.loads(stdp_json)
  
          current_weights = np.array(data.get("current_weights", [1.0]*16))
-         node_features = np.array(data.get("node_features", [[0]*16]*8))
  
-         # Process to get sync matrix
-         input_vec = node_features.flatten()
-         result = ctm.process_ticks(input_vec, n_ticks=25)
  
-         sync = np.array(result["final_sync_matrix"])
  
-         # Suggest weight adjustments based on sync patterns
-         # Uses diagonal of sync (self-similarity) to scale weights
-         suggested = np.zeros(16)
-         for i in range(16):
-             # Average sync of neuron i with others
-             avg_sync = np.mean(sync[i, :])
-             # Scale current weight by sync
-             suggested[i] = current_weights[i] * (0.5 + avg_sync)
  
          return json.dumps({
              "status": "success",
-             "suggested_weights": suggested.tolist(),
-             "weight_changes": (suggested - current_weights).tolist(),
-             "confidence": result["final_certainty"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
  def health_check() -> str:
-     """Health check for the CTM server"""
      return json.dumps({
          "status": "healthy",
-         "model": "CTM Nervous System v1.0",
-         "d_model": ctm.d_model,
-         "memory_length": ctm.memory_length,
-         "default_ticks": ctm.n_ticks,
          "endpoints": [
              "/sense_snn",
-             "/reason_hypergraph",
              "/validate_physics",
              "/dream",
-             "/calibrate_stdp"
          ]
      }, indent=2)
  
  # ============================================================================
  # GRADIO INTERFACE
  # ============================================================================
  
- with gr.Blocks(title="CTM Nervous System") as demo:
      gr.Markdown("""
-     # 🧬 CTM Nervous System
-     **Continuous Thought Machine for Hypergraph Maintenance**
  
      Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
  
      ---
  
-     ## Endpoints
-     - **/sense_snn**: Process 72D SNN input
-     - **/reason_hypergraph**: Reason about context, propose edges
-     - **/validate_physics**: Validate against 5 physics losses
-     - **/dream**: Offline consolidation (T=500+)
-     - **/calibrate_stdp**: Suggest STDP weight adjustments
      """)
  
      with gr.Tabs():
          with gr.Tab("🔌 /sense_snn"):
-             gr.Markdown("Process 72D SNN input vector")
              snn_input = gr.Textbox(
                  label="SNN JSON Input",
-                 value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}',
                  lines=5
              )
              snn_output = gr.Textbox(label="Output", lines=10)
@@ -401,7 +777,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
  
          with gr.Tab("🧠 /reason_hypergraph"):
-             gr.Markdown("Reason about hypergraph context")
              reason_input = gr.Textbox(
                  label="Context JSON",
                  value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
@@ -412,7 +788,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
  
          with gr.Tab("⚡ /validate_physics"):
-             gr.Markdown("Validate against 5 physics losses")
              physics_input = gr.Textbox(
                  label="Physics JSON",
                  value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
@@ -423,7 +799,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
  
          with gr.Tab("💤 /dream"):
-             gr.Markdown("Offline consolidation")
              dream_input = gr.Textbox(
                  label="Dream JSON",
                  value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
@@ -434,7 +810,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
  
          with gr.Tab("🔧 /calibrate_stdp"):
-             gr.Markdown("Calibrate STDP weights")
              stdp_input = gr.Textbox(
                  label="STDP JSON",
                  value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
@@ -444,16 +820,40 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              stdp_btn = gr.Button("Calibrate", variant="primary")
              stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
  
          with gr.Tab("❤️ Health"):
-             health_output = gr.Textbox(label="Health Status", lines=10)
              health_btn = gr.Button("Check Health", variant="secondary")
              health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
  
      gr.Markdown("""
      ---
-     **Architecture**: CTM as Nervous System → Hypergraph as Thought
  
-     **Training**: Min-Loss + Max-Certainty + Physics Regularization
      """)
  
  if __name__ == "__main__":
@@ -461,4 +861,4 @@ if __name__ == "__main__":
      server_name="0.0.0.0",
      server_port=7860,
      show_error=True
- )
@@ -1,85 +1,250 @@
  """
+ CTM Nervous System Server v2.0 - Full PyTorch Implementation
+ =============================================================
+ Continuous Thought Machine for ART-17 Hypergraph Coherence Generation
+
+ PURPOSE (from skills):
+ 1. REGULATION: Calibrate the STDP weights of the 16 dendrites
+ 2. COHERENCE: Generate deterministic hypergraphs
+ 3. REASONING: Active inference engine (internal ticks)
+ 4. SYNCHRONIZATION: Representation via Neural Synchronization
+
+ TRAINING STRATEGY:
+ - Progressive online learning with use
+ - Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding
+ - Automatic checkpoint saving
  
  Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
+ Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation
  """
  
  import gradio as gr
  import numpy as np
  import json
  import os
+ from typing import List, Dict, Any, Optional
+ from datetime import datetime
+
+ # ============================================================================
+ # PYTORCH IMPORTS WITH FALLBACK
+ # ============================================================================
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     TORCH_AVAILABLE = True
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"🔧 PyTorch available. Device: {DEVICE}")
+ except ImportError:
+     TORCH_AVAILABLE = False
+     DEVICE = "cpu"
+     print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
+
+ # ============================================================================
+ # FULL CTM IMPORT (with fallback to simplified)
+ # ============================================================================
+ if TORCH_AVAILABLE:
+     try:
+         from models.ctm import ContinuousThoughtMachine
+         from models.modules import SynapseUNET, SuperLinear
+         from utils.losses import image_classification_loss
+         CTM_FULL = True
+         print("✅ Full CTM model loaded from models/ctm.py")
+     except ImportError as e:
+         CTM_FULL = False
+         print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
+ else:
+     CTM_FULL = False
  
  # ============================================================================
+ # CONFIGURATION FOR ART-17 INTEGRATION
  # ============================================================================
+ CONFIG = {
+     # CTM Architecture (matching ART-17)
+     "iterations": 50,       # T internal ticks
+     "d_model": 256,         # Latent dimension
+     "d_input": 72,          # Input from SNN (72D)
+     "memory_length": 16,    # History length (16 dendrites)
+     "n_synch_out": 32,      # Output sync neurons
+     "n_synch_action": 16,   # Action sync neurons
+     "out_dims": 16,         # Output: 16 dendrite adjustments
+
+     # Training
+     "learning_rate": 1e-4,
+     "weight_decay": 1e-5,
+     "checkpoint_dir": "checkpoints",
+     "auto_save_every": 100,  # Save every N forward passes
+
+     # Integration
+     "brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
+
+     # Physics validation
+     "physics_thresholds": {
+         "P_max": 1000.0,
+         "v_max": 100.0,
+         "T_dew": 15.0,
+         "T_amb": 25.0
+     }
+ }
  
+ # ============================================================================
+ # FULL CTM WRAPPER FOR ART-17
+ # ============================================================================
+ class CTM_ART17:
      """
+     Full Continuous Thought Machine adapted for ART-17.
+
+     Key mechanisms from paper:
+     1. NLMs (Neuron-Level Models) - Each neuron processes its own history
+     2. Neural Synchronization - Representation is S = Z·Z^T
+     3. Adaptive Compute - Can halt early when confident
+
+     Purpose in ART-17:
+     - Regulate 16 dendrite STDP weights
+     - Generate coherent hypergraph edges
+     - Serve as "nervous system" for the whole system
      """
  
+     def __init__(self, config: dict):
+         self.config = config
+         self.forward_count = 0
+         self.training_samples = []
+
+         if CTM_FULL and TORCH_AVAILABLE:
+             self._init_full_ctm()
          else:
+             self._init_simplified_ctm()
+
+     def _init_full_ctm(self):
+         """Initialize full PyTorch CTM model."""
+         self.model = ContinuousThoughtMachine(
+             iterations=self.config["iterations"],
+             d_model=self.config["d_model"],
+             d_input=self.config["d_input"],
+             heads=4,
+             n_synch_out=self.config["n_synch_out"],
+             n_synch_action=self.config["n_synch_action"],
+             synapse_depth=2,
+             memory_length=self.config["memory_length"],
+             deep_nlms=True,
+             memory_hidden_dims=32,
+             do_layernorm_nlm=False,
+             backbone_type='none',
+             positional_embedding_type='none',
+             out_dims=self.config["out_dims"],
+             prediction_reshaper=[self.config["out_dims"]],
+             dropout=0.1,
+             neuron_select_type='random-pairing'
+         ).to(DEVICE)
+
+         # Dummy forward to initialize lazy modules
+         with torch.no_grad():
+             dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
+             dummy = dummy.unsqueeze(-1).unsqueeze(-1)  # [1, 72, 1, 1]
+             try:
+                 _ = self.model(dummy)
+             except Exception as e:
+                 print(f"⚠️ Lazy init failed: {e}")
+
+         self.model.eval()
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=self.config["learning_rate"],
+             weight_decay=self.config["weight_decay"]
+         )
+
+         self.is_full = True
+         param_count = sum(p.numel() for p in self.model.parameters())
+         print(f"✅ Full CTM initialized: {param_count:,} parameters")
+
+         # Try to load existing checkpoint
+         self._load_checkpoint()
+
+     def _init_simplified_ctm(self):
+         """Initialize simplified NumPy CTM (fallback)."""
+         self.d_model = self.config["d_model"]
+         self.memory_length = self.config["memory_length"]
+         self.n_ticks = self.config["iterations"]
+
+         # State traces
+         self.state_trace = np.zeros((self.d_model, self.memory_length))
+         self.activated_state = np.random.randn(self.d_model) * 0.1
+
+         # NLM weights (simplified: 16 groups for 16 dendrites)
+         self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
+
+         self.is_full = False
+         print("✅ Simplified CTM initialized (NumPy fallback)")
+
+     def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
+         """
+         Process input through CTM.
+
+         Args:
+             input_72d: 72D input from SNN
+             n_ticks: Override number of internal ticks
+
+         Returns:
+             Dict with predictions, certainty, sync matrix
+         """
+         n_ticks = n_ticks or self.config["iterations"]
+         self.forward_count += 1
+
+         if self.is_full:
+             return self._forward_full(input_72d, n_ticks)
+         else:
+             return self._forward_simplified(input_72d, n_ticks)
+
+     def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
+         """Forward pass with full PyTorch CTM."""
+         # Prepare tensor
+         x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
+         if len(x.shape) == 1:
+             x = x.unsqueeze(0)  # Add batch dim
+         x = x.unsqueeze(-1).unsqueeze(-1)  # [B, 72, 1, 1]
+
+         with torch.no_grad():
+             predictions, certainties, sync_out = self.model(x)
+
+         # Extract results
+         final_pred = predictions[:, :, -1].cpu().numpy()[0]  # Last tick [16]
+         final_cert = certainties[:, 1, -1].cpu().numpy()[0]  # 1-entropy
+
+         # Find tick with highest certainty
+         best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
+         best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
+
+         # Sync matrix for hypergraph edge proposals
+         sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
+
+         return {
+             "predictions": final_pred.tolist(),
+             "best_predictions": best_pred.tolist(),
+             "certainty": float(final_cert),
+             "best_tick": int(best_tick_idx),
+             "ticks_used": n_ticks,
+             "sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
+             "model": "ContinuousThoughtMachine (Full PyTorch)"
+         }
+
+     def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
+         """Forward pass with simplified NumPy CTM."""
+         # Pad to d_model
+         input_256 = np.zeros(self.d_model)
+         input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
  
          certainties = []
  
          for t in range(n_ticks):
+             # Synapse update (simplified)
+             combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
              pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
  
+             # Update trace (memory)
              self.state_trace = np.roll(self.state_trace, -1, axis=1)
              self.state_trace[:, -1] = pre_activation
  
+             # NLM processing (simplified: 16 groups)
              post_activation = np.zeros(self.d_model)
             group_size = self.d_model // 16
             for g in range(16):
@@ -91,50 +256,151 @@ class SimplifiedCTM:
  
              self.activated_state = post_activation
  
+             # Compute certainty
+             probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
+             probs = np.clip(probs, 1e-10, 1.0)
+             entropy = -np.sum(probs * np.log(probs))
+             max_entropy = np.log(len(probs))
+             certainties.append(float(1.0 - entropy / (max_entropy + 1e-8)))
+
+         # Best tick
+         best_tick_idx = int(np.argmax(certainties))
+
+         # Sync matrix
+         z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
+         sync_matrix = np.outer(z_norm, z_norm)
+
+         # Predictions (first 16 elements of activated state)
+         predictions = self.activated_state[:16].tolist()
+
+         return {
+             "predictions": predictions,
+             "best_predictions": predictions,
+             "certainty": certainties[-1],
+             "best_tick": best_tick_idx,
+             "ticks_used": n_ticks,
+             "sync_matrix": sync_matrix[:16, :16].tolist(),
+             "model": "SimplifiedCTM (NumPy fallback)"
+         }
+
+     def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
+                    physics_loss: float = 0.0) -> Dict:
+         """
+         Online training step.
+
+         Args:
+             input_72d: Input from SNN
+             target_16d: Target dendrite adjustments (ground truth)
+             physics_loss: Current physics loss for weighting
+
+         Returns:
+             Dict with loss and gradient info
+         """
+         if not self.is_full or not TORCH_AVAILABLE:
+             return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
+
+         self.model.train()
+
+         # Prepare tensors
+         x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
+         x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # [1, 72, 1, 1]
+         y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
+
+         # Forward
+         predictions, certainties, _ = self.model(x)
+
+         # Loss: dendrite_regulation_loss
+         # predictions: [B, 16, T], y: [B, 16]
+         y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1))  # [B, 16, T]
+         mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1)  # [B, T]
+
+         # Select best tick (min loss) and most certain tick
+         loss_min_idx = mse_per_tick.argmin(dim=1)  # [B]
+         loss_cert_idx = certainties[:, 1, :].argmax(dim=1)  # [B]
+
+         batch_idx = torch.arange(predictions.size(0), device=DEVICE)
+         loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
+         loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
+
+         # Combined loss with physics penalty
+         mse_loss = (loss_min + loss_cert) / 2
+         physics_penalty = physics_loss * 0.1
+         total_loss = mse_loss + physics_penalty
+
+         # Backward
+         self.optimizer.zero_grad()
+         total_loss.backward()
+         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+         self.optimizer.step()
+
+         self.model.eval()
+
+         # Auto-save checkpoint
+         if self.forward_count % self.config["auto_save_every"] == 0:
+             self._save_checkpoint()
  
          return {
+             "status": "trained",
+             "loss": float(total_loss.item()),
+             "mse_loss": float(mse_loss.item()),
+             "physics_penalty": float(physics_penalty),
+             "best_tick": int(loss_cert_idx[0].item())
          }
+
+     def _save_checkpoint(self):
+         """Save model checkpoint."""
+         if not self.is_full:
+             return
+
+         os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
+         path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
+
+         torch.save({
+             "model_state_dict": self.model.state_dict(),
+             "optimizer_state_dict": self.optimizer.state_dict(),
+             "forward_count": self.forward_count,
+             "timestamp": datetime.now().isoformat()
+         }, path)
+         print(f"💾 Checkpoint saved: {path}")
+
+     def _load_checkpoint(self):
+         """Load model checkpoint if exists."""
+         path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
+         if os.path.exists(path):
+             try:
+                 checkpoint = torch.load(path, map_location=DEVICE)
+                 self.model.load_state_dict(checkpoint["model_state_dict"])
+                 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+                 self.forward_count = checkpoint.get("forward_count", 0)
+                 print(f"✅ Checkpoint loaded: {path}")
+             except Exception as e:
+                 print(f"⚠️ Could not load checkpoint: {e}")
  
+ # ============================================================================
+ # GLOBAL CTM INSTANCE
+ # ============================================================================
+ ctm = CTM_ART17(CONFIG)
  
  # ============================================================================
  # PHYSICS VALIDATION (from SNN Omega-21)
  # ============================================================================
  def validate_physics(trajectory: List[float], params: Dict) -> Dict:
+     """Validate against 5 physics losses from SNN Omega-21."""
      trajectory = np.array(trajectory)
  
      # L_energy: Energy conservation
      energy = np.sum(trajectory ** 2)
+     P_max = params.get("P_max", CONFIG["physics_thresholds"]["P_max"])
      L_energy = float(max(0, energy - P_max) ** 2)
  
      # L_thermo: Thermodynamics (dew point check)
+     T_dew = params.get("T_dew", CONFIG["physics_thresholds"]["T_dew"])
+     T_amb = params.get("T_amb", CONFIG["physics_thresholds"]["T_amb"])
      L_thermo = float(max(0, T_dew - T_amb) ** 2)
  
      # L_causal: Causality (velocity limit)
      velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
+     v_max = params.get("v_max", CONFIG["physics_thresholds"]["v_max"])
      L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
  
      # L_conserv: Flux conservation
@@ -165,49 +431,54 @@ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
  
  def sense_snn(snn_json: str) -> str:
      """
+     /sense_snn - Process 72D SNN input through CTM
+
+     Input: JSON with dendrite values or 72D vector
+     Output: Coherent features, certainty, sync matrix
      """
      try:
          data = json.loads(snn_json)
  
+         # Extract 72D vector
          if "vector_72d" in data:
              input_vec = np.array(data["vector_72d"])
          elif "dendrites" in data:
+             input_vec = np.array(list(data["dendrites"].values()))
          else:
+             input_vec = np.random.randn(72)
  
+         # Pad to 72D if needed
+         if len(input_vec) < 72:
+             input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
  
          # Process through CTM
+         n_ticks = data.get("ticks", 25)
+         result = ctm.forward(input_vec[:72], n_ticks)
  
+         # Detect anomalies (low certainty)
          anomalies = []
+         if result["certainty"] < 0.5:
+             anomalies.append("Low overall certainty - consider retraining")
  
          return json.dumps({
              "status": "success",
+             "coherent_features": result["predictions"],
+             "certainty": result["certainty"],
+             "best_tick": result["best_tick"],
              "anomalies": anomalies,
+             "ticks_used": result["ticks_used"],
+             "model": result["model"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
+
  def reason_hypergraph(context_json: str) -> str:
      """
+     /reason_hypergraph - Reason about hypergraph context, propose edges
+
+     Uses CTM synchronization matrix to find strongly correlated node pairs.
+     These become proposed hyperedges.
      """
      try:
          data = json.loads(context_json)
@@ -217,38 +488,40 @@ def reason_hypergraph(context_json: str) -> str:
          n_ticks = data.get("ticks", 50)
  
          # Flatten node features for CTM input
+         input_vec = node_features.flatten()[:72]
  
          # Process through CTM with more ticks for reasoning
+         result = ctm.forward(input_vec, n_ticks)
  
+         # Extract proposed edges from sync matrix (S_ij > 0.7)
          proposed_edges = []
+         if result["sync_matrix"] is not None:
+             sync = np.array(result["sync_matrix"])
+             n_nodes = min(len(node_features), sync.shape[0])
+
+             for i in range(n_nodes):
+                 for j in range(i+1, n_nodes):
+                     sync_ij = sync[i, j]
+                     if sync_ij > 0.7:  # Threshold for edge proposal
+                         edge_exists = any(
+                             (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
+                             for e in existing_edges
+                         )
+                         if not edge_exists:
+                             proposed_edges.append([i, j, float(sync_ij)])
  
          return json.dumps({
              "status": "success",
              "proposed_edges": proposed_edges,
+             "certainty": result["certainty"],
              "best_tick": result["best_tick"],
+             "ticks_used": result["ticks_used"],
+             "model": result["model"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
+
  def validate_physics_endpoint(physics_json: str) -> str:
      """
      /validate_physics - Validate trajectory against 5 physics losses
@@ -265,135 +538,238 @@ def validate_physics_endpoint(physics_json: str) -> str:
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
+
  def dream_endpoint(dream_json: str) -> str:
      """
+     /dream - Offline consolidation with many ticks
+
+     Discovers patterns, proposes new edges, identifies edges to prune.
      """
      try:
          data = json.loads(dream_json)
  
          snapshot = data.get("hypergraph_snapshot", {})
+         n_ticks = min(data.get("ticks", 100), 100)  # Cap at 100 for CPU
  
          # Extract features from snapshot
          nodes = snapshot.get("nodes", [])
          if nodes:
+             input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()[:72]
          else:
+             input_vec = np.random.randn(72)
  
+         # Dream: run CTM with many ticks
+         result = ctm.forward(input_vec, n_ticks)
  
+         # Analyze sync for patterns
          new_edges = []
          pruned_edges = []
+
+         if result["sync_matrix"] is not None:
+             sync = np.array(result["sync_matrix"])
+             n = min(len(nodes), sync.shape[0]) if nodes else 16
+
+             for i in range(n):
+                 for j in range(i+1, n):
+                     if sync[i, j] > 0.85:
+                         new_edges.append([i, j, float(sync[i, j])])
+                     elif sync[i, j] < 0.1:
+                         pruned_edges.append([i, j])
  
          return json.dumps({
              "status": "success",
              "discovered_patterns": len(new_edges),
+             "new_edges": new_edges[:10],
+             "pruned_edges": pruned_edges[:10],
+             "consolidation_certainty": result["certainty"],
+             "ticks_used": result["ticks_used"],
+             "model": result["model"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
+
  def calibrate_stdp_endpoint(stdp_json: str) -> str:
      """
+     /calibrate_stdp - Suggest STDP weight adjustments
+
+     This is the CORE regulatory function:
+     - Receives current 16 dendrite weights
+     - Processes through CTM to get sync patterns
+     - Returns suggested weight adjustments
      """
      try:
          data = json.loads(stdp_json)
  
          current_weights = np.array(data.get("current_weights", [1.0]*16))
+         node_features = np.array(data.get("node_features", [[0]*16]*4))
  
+         # Flatten features for CTM input
+         input_vec = node_features.flatten()[:72]
  
+         # Process through CTM
+         result = ctm.forward(input_vec, n_ticks=25)
+
+         # Use predictions as weight adjustments
+         predictions = np.array(result["best_predictions"])
  
+         # Scale based on certainty
+         confidence = result["certainty"]
+         weight_changes = (predictions - 0.5) * confidence * 0.1
+
+         new_weights = current_weights + weight_changes
  
          return json.dumps({
              "status": "success",
+             "suggested_weights": new_weights.tolist(),
+             "weight_changes": weight_changes.tolist(),
+             "confidence": confidence,
+             "best_tick": result["best_tick"],
+             "model": result["model"]
          }, indent=2)
      except Exception as e:
          return json.dumps({"status": "error", "message": str(e)})
  
+
+ def regulate_endpoint(regulate_json: str) -> str:
+     """
+     /regulate - Full feedback loop for ART-17 regulation (NEW)
+
+     Combines all signals to provide comprehensive regulation:
+     - Dendrite state
+     - Latent representation
+     - Physics loss
+     - Anomaly score
+
+     Returns action recommendation with confidence.
+     """
+     try:
+         data = json.loads(regulate_json)
+
+         # Inputs from local system
+         dendrites = np.array(data.get("dendrites", [0.0]*16))
+         latent_256 = np.array(data.get("latent_256", [0.0]*256))
+         physics_loss = data.get("physics_loss", 0.0)
+         anomaly_score = data.get("anomaly_score", 0.0)
+
+         # Combine into 72D input
+         input_72 = np.concatenate([
+             dendrites,        # 16D
+             latent_256[:56]   # 56D from latent
+         ])
+
+         # Process through CTM
+         result = ctm.forward(input_72, n_ticks=50)
+
+         # Compute regulation signals
+         predictions = np.array(result["best_predictions"])
+         certainty = result["certainty"]
+
+         # Urgency based on physics and anomaly
+         urgency = min(1.0, physics_loss + anomaly_score)
+         regulation_strength = urgency * certainty
+
+         # Weight adjustments
+         dendrite_deltas = predictions * regulation_strength * 0.05
+
+         # Determine if intervention needed
+         needs_intervention = urgency > 0.5 or certainty < 0.3
+
+         return json.dumps({
+             "status": "success",
+             "dendrite_deltas": dendrite_deltas.tolist(),
+             "regulation_strength": float(regulation_strength),
+             "confidence": certainty,
+             "urgency": float(urgency),
+             "needs_intervention": needs_intervention,
+             "recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
+             "best_tick": result["best_tick"],
+             "model": result["model"]
+         }, indent=2)
+     except Exception as e:
+         return json.dumps({"status": "error", "message": str(e)})
+
+
+ def train_online_endpoint(train_json: str) -> str:
+     """
+     /train_online - Progressive online training (NEW)
+
+     Allows the local system to train the CTM with experience.
+     Sends input-output pairs and receives training feedback.
+     """
+     try:
+         data = json.loads(train_json)
+
+         input_72d = np.array(data.get("input_72d", [0.0]*72))
+         target_16d = np.array(data.get("target_16d", [0.0]*16))
+         physics_loss = data.get("physics_loss", 0.0)
+
+         # Perform training step
+         result = ctm.train_step(input_72d, target_16d, physics_loss)
+
+         return json.dumps({
+             "status": result["status"],
+             "loss": result.get("loss"),
+             "mse_loss": result.get("mse_loss"),
+             "physics_penalty": result.get("physics_penalty"),
+             "best_tick": result.get("best_tick"),
+             "forward_count": ctm.forward_count,
+             "message": "Training step completed" if result["status"] == "trained" else result.get("reason")
+         }, indent=2)
+     except Exception as e:
+         return json.dumps({"status": "error", "message": str(e)})
+
+
  def health_check() -> str:
+     """Health check with model info."""
      return json.dumps({
          "status": "healthy",
+         "model": f"CTM Nervous System v2.0 ({'Full PyTorch' if ctm.is_full else 'NumPy Fallback'})",
+         "device": DEVICE,
+         "d_model": CONFIG["d_model"],
+         "iterations": CONFIG["iterations"],
+         "memory_length": CONFIG["memory_length"],
+         "forward_count": ctm.forward_count,
          "endpoints": [
              "/sense_snn",
+             "/reason_hypergraph",
              "/validate_physics",
              "/dream",
+             "/calibrate_stdp",
+             "/regulate",      # NEW
+             "/train_online"   # NEW
          ]
      }, indent=2)
  
+
  # ============================================================================
  # GRADIO INTERFACE
  # ============================================================================
  
+ with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
      gr.Markdown("""
+     # 🧬 CTM Nervous System v2.0
+     **Continuous Thought Machine for ART-17 Hypergraph Coherence**
  
      Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
  
      ---
  
+     ## Key Innovations
+     - **NLMs (Neuron-Level Models)**: Each neuron processes its own history
+     - **Neural Synchronization**: Representation via S = Z·Z^T
+     - **Adaptive Compute**: Halts when confident
+     - **Online Training**: Progressive learning with use
+
+     ---
      """)
  
      with gr.Tabs():
          with gr.Tab("🔌 /sense_snn"):
+             gr.Markdown("Process 72D SNN input through CTM")
              snn_input = gr.Textbox(
                  label="SNN JSON Input",
+                 value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
                  lines=5
              )
              snn_output = gr.Textbox(label="Output", lines=10)
@@ -401,7 +777,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
  
          with gr.Tab("🧠 /reason_hypergraph"):
+             gr.Markdown("Reason about hypergraph context, propose edges")
              reason_input = gr.Textbox(
                  label="Context JSON",
                  value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
@@ -412,7 +788,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
  
          with gr.Tab("⚡ /validate_physics"):
+             gr.Markdown("Validate trajectory against 5 physics losses")
              physics_input = gr.Textbox(
                  label="Physics JSON",
                  value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
@@ -423,7 +799,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
  
          with gr.Tab("💤 /dream"):
+             gr.Markdown("Offline consolidation - discover patterns")
              dream_input = gr.Textbox(
                  label="Dream JSON",
                  value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
@@ -434,7 +810,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
  
          with gr.Tab("🔧 /calibrate_stdp"):
+             gr.Markdown("Calibrate STDP weights (Core regulatory function)")
              stdp_input = gr.Textbox(
                  label="STDP JSON",
                  value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
@@ -444,16 +820,40 @@ with gr.Blocks(title="CTM Nervous System") as demo:
              stdp_btn = gr.Button("Calibrate", variant="primary")
              stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
  
+         with gr.Tab("🎯 /regulate [NEW]"):
+             gr.Markdown("Full feedback loop for ART-17 regulation")
+             regulate_input = gr.Textbox(
+                 label="Regulate JSON",
+                 value='{"dendrites": [0.5]*16, "latent_256": [0.1]*256, "physics_loss": 0.01, "anomaly_score": 0.05}',
+                 lines=5
+             )
+             regulate_output = gr.Textbox(label="Output", lines=10)
+             regulate_btn = gr.Button("Regulate", variant="primary")
+             regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
+
+         with gr.Tab("📚 /train_online [NEW]"):
+             gr.Markdown("Progressive online training with experience")
+             train_input = gr.Textbox(
+                 label="Training JSON",
+                 value='{"input_72d": [0.1]*72, "target_16d": [0.5]*16, "physics_loss": 0.01}',
+                 lines=5
+             )
+             train_output = gr.Textbox(label="Output", lines=10)
+             train_btn = gr.Button("Train Step", variant="primary")
+             train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
+
          with gr.Tab("❤️ Health"):
+             health_output = gr.Textbox(label="Health Status", lines=15)
              health_btn = gr.Button("Check Health", variant="secondary")
              health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
  
      gr.Markdown("""
      ---
+     **Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
+
+     **Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
  
+     **Training**: Progressive online learning + Physics-Informed Loss
      """)
  
  if __name__ == "__main__":
@@ -461,4 +861,4 @@ if __name__ == "__main__":
      server_name="0.0.0.0",
      server_port=7860,
      show_error=True
+ )
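The edge proposals in /reason_hypergraph and /dream come from thresholding a synchronization matrix. Below is a standalone sketch of that rule, computing S = Z·Z^T over row-normalized per-neuron activation traces as in the CTM paper (the NumPy fallback above approximates this with a single-state outer product); it is an illustration, not code from the commit:

import numpy as np

# Rows of Z are per-neuron activation histories over the internal ticks.
rng = np.random.default_rng(0)
n_neurons, n_ticks = 16, 50
Z = rng.standard_normal((n_neurons, n_ticks))
Z[1] = Z[0] + 0.1 * rng.standard_normal(n_ticks)  # force neurons 0 and 1 to synchronize

# Cosine-normalized synchronization matrix S = Z·Z^T.
norms = np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8
S = (Z / norms) @ (Z / norms).T

# Propose an edge for every pair above the same 0.7 cutoff used by /reason_hypergraph.
threshold = 0.7
proposed = [(i, j, float(S[i, j]))
            for i in range(n_neurons)
            for j in range(i + 1, n_neurons)
            if S[i, j] > threshold]
print(proposed)  # with this seed, only the deliberately correlated pair (0, 1) survives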
app_v1_backup.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CTM Nervous System Server - Continuous Thought Machine for Hypergraph Maintenance
3
+ ===================================================================================
4
+ Implementation of the definitive proposal: CTM as Nervous System for ART-17 Hypergraph.
5
+
6
+ Endpoints:
7
+ - /sense_snn: Process 72D SNN input with NLM-style processing
8
+ - /reason_hypergraph: Reason about hypergraph context, propose edges
9
+ - /validate_physics: Validate proposals against 5 physics losses
10
+ - /dream: Offline consolidation with T=500+ ticks
11
+ - /calibrate_stdp: Suggest STDP weight adjustments from sync matrix
12
+ - /health: Health check endpoint
13
+
14
+ Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
15
+ """
16
+
17
+ import gradio as gr
18
+ import numpy as np
19
+ import json
20
+ from typing import List, Dict, Any, Optional
21
+ import os
22
+
23
+ # ============================================================================
24
+ # SIMPLIFIED CTM SIMULATION (CPU-only for Hugging Face free tier)
25
+ # ============================================================================
26
+
27
+ class SimplifiedCTM:
28
+ """
29
+ Simplified CTM for CPU-only environment.
30
+ Simulates the key mechanisms without full PyTorch model.
31
+ """
32
+
33
+ def __init__(self, d_model: int = 256, memory_length: int = 16, n_ticks: int = 50):
34
+ self.d_model = d_model
35
+ self.memory_length = memory_length
36
+ self.n_ticks = n_ticks
37
+
38
+ # Initialize state
39
+ self.state_trace = np.zeros((d_model, memory_length))
40
+ self.activated_state = np.random.randn(d_model) * 0.1
41
+
42
+ # NLM weights (simplified: one weight matrix per "neuron group")
43
+ self.nlm_weights = np.random.randn(16, memory_length) * 0.1 # 16 groups for 16 dendrites
44
+
45
+ def compute_sync_matrix(self, z: np.ndarray) -> np.ndarray:
46
+ """S^t = Z · Z^T (normalized)"""
47
+ z_norm = z / (np.linalg.norm(z) + 1e-8)
48
+ S = np.outer(z_norm, z_norm)
49
+ return S
50
+
51
+ def compute_certainty(self, predictions: np.ndarray) -> float:
52
+ """Certainty = 1 - normalized entropy"""
53
+ probs = np.abs(predictions) / (np.sum(np.abs(predictions)) + 1e-8)
54
+ probs = np.clip(probs, 1e-10, 1.0)
55
+ entropy = -np.sum(probs * np.log(probs))
56
+ max_entropy = np.log(len(probs))
57
+ normalized_entropy = entropy / (max_entropy + 1e-8)
58
+ return float(1.0 - normalized_entropy)
59
+
60
+ def process_ticks(self, input_features: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
61
+ """Run T internal ticks and return sync matrix + certainty"""
62
+ n_ticks = n_ticks or self.n_ticks
63
+
64
+ # Ensure input is right size
65
+ if len(input_features) < self.d_model:
66
+ input_features = np.pad(input_features, (0, self.d_model - len(input_features)))
67
+ else:
68
+ input_features = input_features[:self.d_model]
69
+
70
+ certainties = []
71
+ sync_matrices = []
72
+
73
+ for t in range(n_ticks):
74
+ # Simulate synapse update
75
+ combined = np.concatenate([self.activated_state, input_features[:self.d_model//2]])
76
+ pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
77
+
78
+ # Update trace
79
+ self.state_trace = np.roll(self.state_trace, -1, axis=1)
80
+ self.state_trace[:, -1] = pre_activation
81
+
82
+ # Simulate NLM (simplified)
83
+ post_activation = np.zeros(self.d_model)
84
+ group_size = self.d_model // 16
85
+ for g in range(16):
86
+ start = g * group_size
87
+ end = start + group_size
88
+ group_trace = self.state_trace[start:end, :]
89
+ group_output = np.mean(group_trace @ self.nlm_weights[g])
90
+ post_activation[start:end] = np.tanh(group_output)
91
+
92
+ self.activated_state = post_activation
93
+
94
+ # Compute sync and certainty
95
+ sync = self.compute_sync_matrix(self.activated_state)
96
+ cert = self.compute_certainty(self.activated_state)
97
+
98
+ sync_matrices.append(sync)
99
+ certainties.append(cert)
100
+
101
+ # Find best ticks (min-loss proxy: max certainty)
102
+ best_tick = int(np.argmax(certainties))
103
+
104
+ return {
105
+ "final_sync_matrix": sync_matrices[-1].tolist(),
106
+ "best_sync_matrix": sync_matrices[best_tick].tolist(),
107
+ "certainties": certainties,
108
+ "final_certainty": float(certainties[-1]),
109
+ "max_certainty": float(max(certainties)),
110
+ "best_tick": best_tick,
111
+ "ticks_used": n_ticks
112
+ }
113
+
114
+ # Global CTM instance
115
+ ctm = SimplifiedCTM(d_model=256, memory_length=16, n_ticks=50)
116
+
117
+ # ============================================================================
118
+ # PHYSICS VALIDATION (from SNN Omega-21)
119
+ # ============================================================================
120
+
121
+ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
122
+ """Validate against 5 physics losses from SNN Omega-21"""
123
+ trajectory = np.array(trajectory)
124
+
125
+ # L_energy: Energy conservation
126
+ energy = np.sum(trajectory ** 2)
127
+ P_max = params.get("P_max", 1000.0)
128
+ L_energy = float(max(0, energy - P_max) ** 2)
129
+
130
+ # L_thermo: Thermodynamics (dew point check)
131
+ T_dew = params.get("T_dew", 15.0)
132
+ T_amb = params.get("T_amb", 25.0)
133
+ L_thermo = float(max(0, T_dew - T_amb) ** 2)
134
+
135
+ # L_causal: Causality (velocity limit)
136
+ velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
137
+ v_max = params.get("v_max", 100.0)
138
+ L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
139
+
140
+ # L_conserv: Flux conservation
141
+ flux_in = params.get("flux_in", 1.0)
142
+ flux_out = params.get("flux_out", 1.0)
143
+ L_conserv = float((flux_in - flux_out) ** 2)
144
+
145
+ # L_entropy: 2nd Law (entropy must increase)
146
+ entropy_change = params.get("entropy_change", 0.1)
147
+ L_entropy = float(max(0, -entropy_change) ** 2)
148
+
149
+ # Total physics loss
150
+ L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
151
+
152
+ return {
153
+ "valid": L_total < 0.01,
154
+ "L_energy": L_energy,
155
+ "L_thermo": L_thermo,
156
+ "L_causal": L_causal,
157
+ "L_conserv": L_conserv,
158
+ "L_entropy": L_entropy,
159
+ "L_total": L_total
160
+ }
161
+
162
+ # ============================================================================
163
+ # ENDPOINT FUNCTIONS
164
+ # ============================================================================
165
+
166
+ def sense_snn(snn_json: str) -> str:
167
+ """
168
+ /sense_snn - Process 72D SNN input
169
+ Input: JSON with dendrite values
170
+ Output: Coherent features + anomalies
171
+ """
172
+ try:
173
+ data = json.loads(snn_json)
174
+
175
+ # Extract 72D vector (or create from dendrites)
176
+ if "vector_72d" in data:
177
+ input_vec = np.array(data["vector_72d"])
178
+ elif "dendrites" in data:
179
+ dendrite_values = list(data["dendrites"].values())
180
+ input_vec = np.array(dendrite_values)
181
+ else:
182
+ input_vec = np.random.randn(72) # Fallback
183
+
184
+ # Pad to 256D
185
+ input_256 = np.zeros(256)
186
+ input_256[:min(len(input_vec), 256)] = input_vec[:min(len(input_vec), 256)]
187
+
188
+ # Process through CTM
189
+ result = ctm.process_ticks(input_256, n_ticks=25)
190
+
191
+ # Detect anomalies (low certainty regions)
192
+ anomalies = []
193
+ if result["final_certainty"] < 0.5:
194
+ anomalies.append("Low overall certainty")
195
+
196
+ return json.dumps({
197
+ "status": "success",
198
+ "coherent_features": result["final_sync_matrix"][:16][:16], # 16x16 subset
199
+ "certainty": result["final_certainty"],
200
+ "anomalies": anomalies,
201
+ "ticks_used": result["ticks_used"]
202
+ }, indent=2)
203
+ except Exception as e:
204
+ return json.dumps({"status": "error", "message": str(e)})
205
+
206
+ def reason_hypergraph(context_json: str) -> str:
207
+ """
208
+ /reason_hypergraph - Reason about hypergraph, propose edges
209
+ Input: Node features + existing edges
210
+ Output: Proposed new edges + certainty
211
+ """
212
+ try:
213
+ data = json.loads(context_json)
214
+
215
+ node_features = np.array(data.get("node_features", [[0]*16]*8))
216
+ existing_edges = data.get("existing_edges", [])
217
+ n_ticks = data.get("ticks", 50)
218
+
219
+ # Flatten node features for CTM input
220
+ input_vec = node_features.flatten()
221
+
222
+ # Process through CTM with more ticks for reasoning
223
+ result = ctm.process_ticks(input_vec, n_ticks=n_ticks)
224
+
225
+ # Extract proposed edges from sync matrix (S_ij > 0.8)
226
+ sync = np.array(result["best_sync_matrix"])
227
+ proposed_edges = []
228
+
229
+ n_nodes = min(len(node_features), sync.shape[0])
230
+ for i in range(n_nodes):
231
+ for j in range(i+1, n_nodes):
232
+ sync_ij = sync[i, j]
233
+ if sync_ij > 0.8:
234
+ # Check if edge already exists
235
+ edge_exists = any(
236
+ (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
237
+ for e in existing_edges
238
+ )
239
+ if not edge_exists:
240
+ proposed_edges.append([i, j, float(sync_ij)])
241
+
242
+ return json.dumps({
243
+ "status": "success",
244
+ "proposed_edges": proposed_edges,
245
+ "certainty": result["max_certainty"],
246
+ "best_tick": result["best_tick"],
247
+ "ticks_used": result["ticks_used"]
248
+ }, indent=2)
249
+ except Exception as e:
250
+ return json.dumps({"status": "error", "message": str(e)})
251
+
+ def validate_physics_endpoint(physics_json: str) -> str:
+     """
+     /validate_physics - Validate trajectory against 5 physics losses
+     """
+     try:
+         data = json.loads(physics_json)
+         trajectory = data.get("trajectory", [0.0])
+         params = data.get("physics_params", {})
+
+         result = validate_physics(trajectory, params)
+         result["status"] = "success"
+
+         return json.dumps(result, indent=2)
+     except Exception as e:
+         return json.dumps({"status": "error", "message": str(e)})
+
+ def dream_endpoint(dream_json: str) -> str:
+     """
+     /dream - Offline consolidation with T=500+ ticks
+     (capped at 100 ticks in this CPU-only build)
+     Input: Hypergraph snapshot
+     Output: Discovered patterns + new edges
+     """
+     try:
+         data = json.loads(dream_json)
+
+         snapshot = data.get("hypergraph_snapshot", {})
+         n_ticks = data.get("ticks", 500)
+
+         # Extract features from snapshot
+         nodes = snapshot.get("nodes", [])
+         if nodes:
+             input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()
+         else:
+             input_vec = np.random.randn(256)  # Random dream if no nodes
+
+         # Dream: run CTM with many ticks and no external input after initial
+         result = ctm.process_ticks(input_vec, n_ticks=min(n_ticks, 100))  # Cap at 100 for CPU
+
+         # Analyze sync evolution to find patterns
+         sync = np.array(result["final_sync_matrix"])
+
+         # Find strong sync pairs (new edges)
+         new_edges = []
+         n = min(len(nodes), sync.shape[0]) if nodes else 16
+         for i in range(n):
+             for j in range(i+1, n):
+                 if sync[i, j] > 0.85:
+                     new_edges.append([i, j, float(sync[i, j])])
+
+         # Find weak sync pairs (edges to prune)
+         pruned_edges = []
+         for i in range(n):
+             for j in range(i+1, n):
+                 if sync[i, j] < 0.1:
+                     pruned_edges.append([i, j])
+
+         return json.dumps({
+             "status": "success",
+             "discovered_patterns": len(new_edges),
+             "new_edges": new_edges[:10],  # Top 10
+             "pruned_edges": pruned_edges[:10],  # Top 10
+             "consolidation_certainty": result["max_certainty"],
+             "ticks_used": result["ticks_used"]
+         }, indent=2)
+     except Exception as e:
+         return json.dumps({"status": "error", "message": str(e)})
+
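+ # Illustrative call: an empty snapshot falls back to a random 256-D "dream";
+ # with nodes present, pairs with sync > 0.85 become new_edges and pairs with
+ # sync < 0.1 become pruned_edges:
+ #   dream_endpoint('{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}')
+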
+ def calibrate_stdp_endpoint(stdp_json: str) -> str:
+     """
+     /calibrate_stdp - Suggest STDP weight adjustments from sync
+     """
+     try:
+         data = json.loads(stdp_json)
+
+         current_weights = np.array(data.get("current_weights", [1.0]*16))
+         node_features = np.array(data.get("node_features", [[0]*16]*8))
+
+         # Process to get sync matrix
+         input_vec = node_features.flatten()
+         result = ctm.process_ticks(input_vec, n_ticks=25)
+
+         sync = np.array(result["final_sync_matrix"])
+
+         # Suggest weight adjustments based on sync patterns:
+         # scale each weight by that neuron's average synchronization
+         # with all neurons (row mean of the sync matrix)
+         suggested = np.zeros(16)
+         for i in range(16):
+             # Average sync of neuron i with others
+             avg_sync = np.mean(sync[i, :])
+             # Scale current weight by sync
+             suggested[i] = current_weights[i] * (0.5 + avg_sync)
+
+         return json.dumps({
+             "status": "success",
+             "suggested_weights": suggested.tolist(),
+             "weight_changes": (suggested - current_weights).tolist(),
+             "confidence": result["final_certainty"]
+         }, indent=2)
+     except Exception as e:
+         return json.dumps({"status": "error", "message": str(e)})
+
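+ # Worked example of the rule above: suggested[i] = w[i] * (0.5 + avg_sync).
+ # With w[i] = 1.0: avg_sync = 0.5 leaves the weight unchanged (factor 1.0),
+ # avg_sync = 0.9 raises it to 1.4, and avg_sync = 0.1 lowers it to 0.6.
+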
+ def health_check() -> str:
+     """Health check for the CTM server"""
+     return json.dumps({
+         "status": "healthy",
+         "model": "CTM Nervous System v1.0",
+         "d_model": ctm.d_model,
+         "memory_length": ctm.memory_length,
+         "default_ticks": ctm.n_ticks,
+         "endpoints": [
+             "/sense_snn",
+             "/reason_hypergraph",
+             "/validate_physics",
+             "/dream",
+             "/calibrate_stdp"
+         ]
+     }, indent=2)
+
+ # ============================================================================
+ # GRADIO INTERFACE
+ # ============================================================================
+
+ with gr.Blocks(title="CTM Nervous System") as demo:
+     gr.Markdown("""
+     # 🧬 CTM Nervous System
+     **Continuous Thought Machine for Hypergraph Maintenance**
+
+     Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
+
+     ---
+
+     ## Endpoints
+     - **/sense_snn**: Process 72D SNN input
+     - **/reason_hypergraph**: Reason about context, propose edges
+     - **/validate_physics**: Validate against 5 physics losses
+     - **/dream**: Offline consolidation (T=500+)
+     - **/calibrate_stdp**: Suggest STDP weight adjustments
+     """)
+
+     with gr.Tabs():
+         with gr.Tab("🔌 /sense_snn"):
+             gr.Markdown("Process 72D SNN input vector")
+             snn_input = gr.Textbox(
+                 label="SNN JSON Input",
+                 value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}',
+                 lines=5
+             )
+             snn_output = gr.Textbox(label="Output", lines=10)
+             snn_btn = gr.Button("Process", variant="primary")
+             snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
+
+         with gr.Tab("🧠 /reason_hypergraph"):
+             gr.Markdown("Reason about hypergraph context")
+             reason_input = gr.Textbox(
+                 label="Context JSON",
+                 value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
+                 lines=5
+             )
+             reason_output = gr.Textbox(label="Output", lines=10)
+             reason_btn = gr.Button("Reason", variant="primary")
+             reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
+
+         with gr.Tab("⚡ /validate_physics"):
+             gr.Markdown("Validate against 5 physics losses")
+             physics_input = gr.Textbox(
+                 label="Physics JSON",
+                 value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
+                 lines=5
+             )
+             physics_output = gr.Textbox(label="Output", lines=10)
+             physics_btn = gr.Button("Validate", variant="primary")
+             physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
+
+         with gr.Tab("💤 /dream"):
+             gr.Markdown("Offline consolidation")
+             dream_input = gr.Textbox(
+                 label="Dream JSON",
+                 value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
+                 lines=5
+             )
+             dream_output = gr.Textbox(label="Output", lines=10)
+             dream_btn = gr.Button("Dream", variant="primary")
+             dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
+
+         with gr.Tab("🔧 /calibrate_stdp"):
+             gr.Markdown("Calibrate STDP weights")
+             stdp_input = gr.Textbox(
+                 label="STDP JSON",
+                 value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
+                 lines=5
+             )
+             stdp_output = gr.Textbox(label="Output", lines=10)
+             stdp_btn = gr.Button("Calibrate", variant="primary")
+             stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
+
+         with gr.Tab("❤️ Health"):
+             health_output = gr.Textbox(label="Health Status", lines=10)
+             health_btn = gr.Button("Check Health", variant="secondary")
+             health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
+
+     gr.Markdown("""
+     ---
+     **Architecture**: CTM as Nervous System → Hypergraph as Thought
+
+     **Training**: Min-Loss + Max-Certainty + Physics Regularization
+     """)
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True
+     )
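+
+ # Illustrative client sketch (the URL is a placeholder): with the app running,
+ # the named endpoints above can be called through the gradio_client package:
+ #   from gradio_client import Client
+ #   client = Client("http://localhost:7860")
+ #   print(client.predict('{"dendrites": {"d1": 0.1}}', api_name="/sense_snn"))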
requirements.txt CHANGED
@@ -1,2 +1,21 @@
+ # Requirements for CTM Nervous System v2.0 (Full PyTorch)
+ # =========================================================
+ # Upgraded to use full ContinuousThoughtMachine implementation
+
+ # Core Framework
  gradio>=5.0.0
- numpy
+ numpy
+
+ # PyTorch (CPU-only for HuggingFace free tier)
+ # Use torch-cpu to minimize memory footprint
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch>=2.0.0
+
+ # For checkpoint saving
+ safetensors
+
+ # For potential model weights download
+ huggingface_hub
+
+ # HTTP client for Brain server integration
+ requests
requirements_v1.txt ADDED
@@ -0,0 +1,2 @@
+ gradio>=5.0.0
+ numpy
utils/dendrite_losses.py ADDED
@@ -0,0 +1,228 @@
+ """
+ Dendrite Regulation Loss Function for CTM
+ ==========================================
+ Adapted from Sakana AI CTM losses for ART-17 dendrite regulation task.
+
+ This loss combines:
+ 1. MSE between predicted and target dendrite adjustments
+ 2. Physics loss penalty (from SNN Omega-21)
+ 3. Certainty-weighted tick selection (CTM innovation)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def dendrite_regulation_loss(
+     predictions: torch.Tensor,   # (B, 16, T) - predictions per tick
+     certainties: torch.Tensor,   # (B, 2, T) - [entropy, 1-entropy]
+     targets: torch.Tensor,       # (B, 16) - target deltas
+     physics_loss: torch.Tensor,  # (B,) - from SNN physics validation
+     use_most_certain: bool = True,
+     physics_weight: float = 0.1
+ ) -> tuple:
+     """
+     CTM loss adapted for ART-17 dendrite regulation.
+
+     Combines:
+     1. MSE between prediction and target at the min-loss tick
+     2. MSE between prediction and target at the max-certainty tick
+     3. Physics loss penalty
+
+     This dual-tick selection is a key CTM innovation that allows
+     the model to "think" until confident.
+
+     Args:
+         predictions: Model predictions of shape (B, 16, T)
+             - B: batch size
+             - 16: dendrite dimensions
+             - T: number of internal ticks
+         certainties: Certainty values of shape (B, 2, T)
+             - Second dim: [normalized_entropy, 1 - normalized_entropy]
+         targets: Ground truth deltas of shape (B, 16)
+         physics_loss: Physics validation loss of shape (B,)
+         use_most_certain: If True, select the max-certainty tick;
+             if False, fall back to the final tick
+         physics_weight: Weight for physics loss penalty
+
+     Returns:
+         total_loss: Combined loss value
+         loss_index_2: Index of the selected (max-certainty or final)
+             tick per sample, shape (B,)
+     """
+     B, D, T = predictions.shape  # e.g. (B, 16, 50)
+
+     # Expand targets to all ticks: (B, 16) -> (B, 16, T)
+     targets_exp = targets.unsqueeze(-1).expand(-1, -1, T)
+
+     # MSE loss per tick: (B, 16, T) -> (B, T) after mean over dendrites
+     mse_per_tick = F.mse_loss(predictions, targets_exp, reduction='none')
+     mse_per_tick = mse_per_tick.mean(dim=1)  # Average over 16 dendrites
+
+     # Tick with minimum loss
+     loss_index_1 = mse_per_tick.argmin(dim=1)  # (B,)
+
+     # Tick with maximum certainty (1 - normalized_entropy)
+     loss_index_2 = certainties[:, 1, :].argmax(dim=1)  # (B,)
+
+     if not use_most_certain:
+         # Fall back to final tick
+         loss_index_2 = torch.full_like(loss_index_2, T - 1)
+
+     # Select losses at chosen ticks
+     batch_idx = torch.arange(B, device=predictions.device)
+     loss_min = mse_per_tick[batch_idx, loss_index_1].mean()
+     loss_certain = mse_per_tick[batch_idx, loss_index_2].mean()
+
+     # Combined MSE loss (average of min-loss and max-certainty)
+     mse_loss = (loss_min + loss_certain) / 2
+
+     # Physics penalty
+     physics_penalty = physics_loss.mean() * physics_weight
+
+     # Total loss
+     total_loss = mse_loss + physics_penalty
+
+     return total_loss, loss_index_2
+
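+ # Minimal smoke sketch (illustrative shapes; see also the __main__ block at
+ # the end of this file):
+ #   preds = torch.randn(2, 16, 50)
+ #   ent = torch.rand(2, 50); certs = torch.stack([ent, 1 - ent], dim=1)
+ #   loss, ticks = dendrite_regulation_loss(preds, certs, torch.randn(2, 16), torch.rand(2))
+
+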
+ def hypergraph_edge_loss(
+     sync_matrix: torch.Tensor,  # (B, N, N) - synchronization matrix
+     edge_labels: torch.Tensor,  # (B, N, N) - binary edge labels
+     certainties: torch.Tensor,  # (B, 2, T)
+     use_most_certain: bool = True
+ ) -> tuple:
+     """
+     Loss for predicting hypergraph edges from the synchronization matrix.
+
+     The CTM's sync matrix S = Z·Z^T captures neural synchronization,
+     which we use to predict edges in the hypergraph.
+
+     Args:
+         sync_matrix: Predicted sync values (B, N, N)
+         edge_labels: Ground truth edges (B, N, N), binary
+         certainties: Certainty values (B, 2, T)
+         use_most_certain: Whether to weight by certainty
+
+     Returns:
+         loss: BCE loss for edge prediction
+         certainty: Average certainty at best tick
+     """
+     # Binary cross entropy for edge prediction
+     bce_loss = F.binary_cross_entropy_with_logits(
+         sync_matrix,
+         edge_labels.float(),
+         reduction='mean'
+     )
+
+     # Weight by certainty if enabled
+     if use_most_certain and certainties is not None:
+         max_certainty = certainties[:, 1, :].max(dim=-1)[0].mean()
+         # Higher certainty -> lower loss weight (model is confident)
+         certainty_weight = 2.0 - max_certainty
+         loss = bce_loss * certainty_weight
+     else:
+         loss = bce_loss
+         max_certainty = torch.tensor(0.5)
+
+     return loss, max_certainty
+
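+ # Illustrative: the BCE above treats raw sync values as logits, so an edge is
+ # effectively predicted wherever sigmoid(S_ij) > 0.5. For example:
+ #   sync = torch.randn(2, 8, 8)
+ #   labels = (torch.rand(2, 8, 8) > 0.5).float()
+ #   loss, cert = hypergraph_edge_loss(sync, labels, None, use_most_certain=False)
+
+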
+ def combined_art17_loss(
+     predictions: torch.Tensor,
+     certainties: torch.Tensor,
+     sync_matrix: torch.Tensor,
+     dendrite_targets: torch.Tensor,
+     edge_labels: torch.Tensor,
+     physics_loss: torch.Tensor,
+     dendrite_weight: float = 1.0,
+     edge_weight: float = 0.5,
+     physics_weight: float = 0.1
+ ) -> dict:
+     """
+     Combined loss for ART-17 CTM training.
+
+     Combines:
+     1. Dendrite regulation loss (primary task)
+     2. Hypergraph edge prediction loss (auxiliary)
+     3. Physics constraint penalty (regularization)
+
+     Args:
+         predictions: Dendrite predictions (B, 16, T)
+         certainties: Certainty values (B, 2, T)
+         sync_matrix: Sync matrix for edge prediction (B, N, N)
+         dendrite_targets: Target deltas (B, 16)
+         edge_labels: Edge ground truth (B, N, N)
+         physics_loss: Physics validation (B,)
+         dendrite_weight: Weight for dendrite loss
+         edge_weight: Weight for edge loss
+         physics_weight: Weight for physics penalty
+
+     Returns:
+         Dict with total_loss and component losses
+     """
+     # Dendrite regulation loss
+     dend_loss, best_tick = dendrite_regulation_loss(
+         predictions, certainties, dendrite_targets,
+         physics_loss, physics_weight=0.0  # Physics handled separately
+     )
+
+     # Edge prediction loss
+     edge_loss, certainty = hypergraph_edge_loss(
+         sync_matrix, edge_labels, certainties
+     )
+
+     # Physics penalty
+     phys_penalty = physics_loss.mean() * physics_weight
+
+     # Total weighted loss
+     total_loss = (
+         dendrite_weight * dend_loss +
+         edge_weight * edge_loss +
+         phys_penalty
+     )
+
+     return {
+         "total_loss": total_loss,
+         "dendrite_loss": dend_loss,
+         "edge_loss": edge_loss,
+         "physics_penalty": phys_penalty,
+         "best_tick": best_tick,
+         "certainty": certainty
+     }
+
+
+ # Keep original loss functions for compatibility
+ def compute_ctc_loss(predictions, targets, blank_label=0):
+     """CTC loss for sequence tasks (original from Sakana)."""
+     batch_size, num_classes, prediction_length = predictions.shape
+     _, target_length = targets.shape
+
+     log_probs = F.log_softmax(predictions, dim=1)
+     log_probs = log_probs.permute(2, 0, 1)
+
+     input_lengths = torch.full(size=(batch_size,), fill_value=prediction_length, dtype=torch.long)
+     target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)
+
+     ctc_loss = torch.nn.CTCLoss(blank=blank_label, reduction='mean')
+     concatenated_targets = torch.cat(list(targets))
+     loss = ctc_loss(log_probs, concatenated_targets, input_lengths, target_lengths)
+
+     return loss
+
+
+ def image_classification_loss(predictions, certainties, targets, use_most_certain=True):
+     """Image classification loss (original from Sakana)."""
+     targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)
+     losses = nn.CrossEntropyLoss(reduction='none')(predictions, targets_expanded)
+
+     loss_index_1 = losses.argmin(dim=1)
+     loss_index_2 = certainties[:, 1].argmax(-1)
+     if not use_most_certain:
+         loss_index_2[:] = -1
+
+     batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
+     loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
+     loss_selected = losses[batch_indexer, loss_index_2].mean()
+
+     loss = (loss_minimum_ce + loss_selected) / 2
+     return loss, loss_index_2