Commit c557815 · Parent(s): c9edbe0

Upgrade to CTM v2.0 - Full PyTorch implementation with online training

Files changed:
- app.py +586 -186
- app_v1_backup.py +464 -0
- requirements.txt +20 -1
- requirements_v1.txt +2 -0
- utils/dendrite_losses.py +228 -0
app.py
CHANGED
@@ -1,85 +1,250 @@
@@ -91,50 +256,151 @@ class SimplifiedCTM:
@@ -165,49 +431,54 @@ def validate_physics(trajectory: List[float], params: Dict) -> Dict:
@@ -217,38 +488,40 @@ def reason_hypergraph(context_json: str) -> str:
@@ -265,135 +538,238 @@ def validate_physics_endpoint(physics_json: str) -> str:
@@ -401,7 +777,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
@@ -412,7 +788,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
@@ -423,7 +799,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
@@ -434,7 +810,7 @@ with gr.Blocks(title="CTM Nervous System") as demo:
@@ -444,16 +820,40 @@ with gr.Blocks(title="CTM Nervous System") as demo:
@@ -461,4 +861,4 @@ if __name__ == "__main__":
[deleted side of these hunks: the v1 SimplifiedCTM implementation and its endpoint wrappers, preserved verbatim below as app_v1_backup.py; the added side (the rewritten app.py) follows]
|
| 1 |
"""
|
| 2 |
+
CTM Nervous System Server v2.0 - Full PyTorch Implementation
|
| 3 |
+
=============================================================
|
| 4 |
+
Continuous Thought Machine for ART-17 Hypergraph Coherence Generation
|
| 5 |
|
| 6 |
+
PURPOSE (from skills):
|
| 7 |
+
1. REGULACIÓN: Calibrar pesos STDP de las 16 dendritas
|
| 8 |
+
2. COHERENCIA: Generar hipergrafos deterministas
|
| 9 |
+
3. RAZONAMIENTO: Motor de inferencia activa (internal ticks)
|
| 10 |
+
4. SINCRONIZACIÓN: Representación via Neural Synchronization
|
| 11 |
+
|
| 12 |
+
TRAINING STRATEGY:
|
| 13 |
+
- Progressive online learning as the system is used
|
| 14 |
+
- Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding
|
| 15 |
+
- Automatic checkpoint saving
|
| 16 |
|
| 17 |
Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
|
| 18 |
+
Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation
|
| 19 |
"""
|
| 20 |
|
| 21 |
import gradio as gr
|
| 22 |
import numpy as np
|
| 23 |
import json
|
|
|
|
| 24 |
import os
|
| 25 |
+
from typing import List, Dict, Any, Optional
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
|
| 28 |
+
# ============================================================================
|
| 29 |
+
# PYTORCH IMPORTS WITH FALLBACK
|
| 30 |
+
# ============================================================================
|
| 31 |
+
try:
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
TORCH_AVAILABLE = True
|
| 36 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
+
print(f"🔧 PyTorch available. Device: {DEVICE}")
|
| 38 |
+
except ImportError:
|
| 39 |
+
TORCH_AVAILABLE = False
|
| 40 |
+
DEVICE = "cpu"
|
| 41 |
+
print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
|
| 42 |
+
|
| 43 |
+
# ============================================================================
|
| 44 |
+
# FULL CTM IMPORT (with fallback to simplified)
|
| 45 |
+
# ============================================================================
|
| 46 |
+
if TORCH_AVAILABLE:
|
| 47 |
+
try:
|
| 48 |
+
from models.ctm import ContinuousThoughtMachine
|
| 49 |
+
from models.modules import SynapseUNET, SuperLinear
|
| 50 |
+
from utils.losses import image_classification_loss
|
| 51 |
+
CTM_FULL = True
|
| 52 |
+
print("✅ Full CTM model loaded from models/ctm.py")
|
| 53 |
+
except ImportError as e:
|
| 54 |
+
CTM_FULL = False
|
| 55 |
+
print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
|
| 56 |
+
else:
|
| 57 |
+
CTM_FULL = False
|
| 58 |
|
| 59 |
# ============================================================================
|
| 60 |
+
# CONFIGURATION FOR ART-17 INTEGRATION
|
| 61 |
# ============================================================================
|
| 62 |
+
CONFIG = {
|
| 63 |
+
# CTM Architecture (matching ART-17)
|
| 64 |
+
"iterations": 50, # T internal ticks
|
| 65 |
+
"d_model": 256, # Latent dimension
|
| 66 |
+
"d_input": 72, # Input from SNN (72D)
|
| 67 |
+
"memory_length": 16, # History length (16 dendrites)
|
| 68 |
+
"n_synch_out": 32, # Output sync neurons
|
| 69 |
+
"n_synch_action": 16, # Action sync neurons
|
| 70 |
+
"out_dims": 16, # Output: 16 dendrite adjustments
|
| 71 |
+
|
| 72 |
+
# Training
|
| 73 |
+
"learning_rate": 1e-4,
|
| 74 |
+
"weight_decay": 1e-5,
|
| 75 |
+
"checkpoint_dir": "checkpoints",
|
| 76 |
+
"auto_save_every": 100, # Save every N forward passes
|
| 77 |
+
|
| 78 |
+
# Integration
|
| 79 |
+
"brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
|
| 80 |
+
|
| 81 |
+
# Physics validation
|
| 82 |
+
"physics_thresholds": {
|
| 83 |
+
"P_max": 1000.0,
|
| 84 |
+
"v_max": 100.0,
|
| 85 |
+
"T_dew": 15.0,
|
| 86 |
+
"T_amb": 25.0
|
| 87 |
+
}
|
| 88 |
+
}
|
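As a small illustration of how the `d_input: 72` above is assembled downstream (the new `/regulate` endpoint concatenates the 16 dendrite values with the first 56 entries of the 256D latent), a minimal sketch with placeholder arrays:

```python
import numpy as np

# Placeholder inputs standing in for the local ART-17 system state.
dendrites = np.zeros(16)     # 16 dendrite values
latent_256 = np.zeros(256)   # 256D latent representation

# 16 + 56 = 72 = CONFIG["d_input"]
input_72 = np.concatenate([dendrites, latent_256[:56]])
assert input_72.shape == (72,)
```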
| 89 |
|
| 90 |
+
# ============================================================================
|
| 91 |
+
# FULL CTM WRAPPER FOR ART-17
|
| 92 |
+
# ============================================================================
|
| 93 |
+
class CTM_ART17:
|
| 94 |
"""
|
| 95 |
+
Full Continuous Thought Machine adapted for ART-17.
|
| 96 |
+
|
| 97 |
+
Key mechanisms from paper:
|
| 98 |
+
1. NLMs (Neuron-Level Models) - Each neuron processes its own history
|
| 99 |
+
2. Neural Synchronization - Representation is S = Z·Z^T
|
| 100 |
+
3. Adaptive Compute - Can halt early when confident
|
| 101 |
+
|
| 102 |
+
Purpose in ART-17:
|
| 103 |
+
- Regulate 16 dendrite STDP weights
|
| 104 |
+
- Generate coherent hypergraph edges
|
| 105 |
+
- Serve as "nervous system" for the whole system
|
| 106 |
"""
|
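A minimal NumPy sketch of the two quantities named in this docstring, mirroring the formulas used by the simplified fallback (`compute_sync_matrix` and `compute_certainty` in app_v1_backup.py); `z` is a hypothetical activated-state vector:

```python
import numpy as np

def sync_matrix(z: np.ndarray) -> np.ndarray:
    """Neural synchronization S = Z · Z^T on the L2-normalized activations."""
    z_norm = z / (np.linalg.norm(z) + 1e-8)
    return np.outer(z_norm, z_norm)

def certainty(z: np.ndarray) -> float:
    """Certainty = 1 - normalized entropy of the activation magnitudes."""
    probs = np.abs(z) / (np.sum(np.abs(z)) + 1e-8)
    probs = np.clip(probs, 1e-10, 1.0)
    entropy = -np.sum(probs * np.log(probs))
    return float(1.0 - entropy / (np.log(len(probs)) + 1e-8))

z = np.random.randn(256) * 0.1   # hypothetical activated state (d_model = 256)
S = sync_matrix(z)               # 256x256 synchronization matrix
print(S.shape, round(certainty(z), 3))
```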
| 107 |
|
| 108 |
+
def __init__(self, config: dict):
|
| 109 |
+
self.config = config
|
| 110 |
+
self.forward_count = 0
|
| 111 |
+
self.training_samples = []
|
| 112 |
+
|
| 113 |
+
if CTM_FULL and TORCH_AVAILABLE:
|
| 114 |
+
self._init_full_ctm()
|
| 115 |
else:
|
| 116 |
+
self._init_simplified_ctm()
|
| 117 |
+
|
| 118 |
+
def _init_full_ctm(self):
|
| 119 |
+
"""Initialize full PyTorch CTM model."""
|
| 120 |
+
self.model = ContinuousThoughtMachine(
|
| 121 |
+
iterations=self.config["iterations"],
|
| 122 |
+
d_model=self.config["d_model"],
|
| 123 |
+
d_input=self.config["d_input"],
|
| 124 |
+
heads=4,
|
| 125 |
+
n_synch_out=self.config["n_synch_out"],
|
| 126 |
+
n_synch_action=self.config["n_synch_action"],
|
| 127 |
+
synapse_depth=2,
|
| 128 |
+
memory_length=self.config["memory_length"],
|
| 129 |
+
deep_nlms=True,
|
| 130 |
+
memory_hidden_dims=32,
|
| 131 |
+
do_layernorm_nlm=False,
|
| 132 |
+
backbone_type='none',
|
| 133 |
+
positional_embedding_type='none',
|
| 134 |
+
out_dims=self.config["out_dims"],
|
| 135 |
+
prediction_reshaper=[self.config["out_dims"]],
|
| 136 |
+
dropout=0.1,
|
| 137 |
+
neuron_select_type='random-pairing'
|
| 138 |
+
).to(DEVICE)
|
| 139 |
+
|
| 140 |
+
# Dummy forward to initialize lazy modules
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
|
| 143 |
+
dummy = dummy.unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
|
| 144 |
+
try:
|
| 145 |
+
_ = self.model(dummy)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"⚠️ Lazy init failed: {e}")
|
| 148 |
+
|
| 149 |
+
self.model.eval()
|
| 150 |
+
self.optimizer = torch.optim.AdamW(
|
| 151 |
+
self.model.parameters(),
|
| 152 |
+
lr=self.config["learning_rate"],
|
| 153 |
+
weight_decay=self.config["weight_decay"]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.is_full = True
|
| 157 |
+
param_count = sum(p.numel() for p in self.model.parameters())
|
| 158 |
+
print(f"✅ Full CTM initialized: {param_count:,} parameters")
|
| 159 |
+
|
| 160 |
+
# Try to load existing checkpoint
|
| 161 |
+
self._load_checkpoint()
|
| 162 |
+
|
| 163 |
+
def _init_simplified_ctm(self):
|
| 164 |
+
"""Initialize simplified NumPy CTM (fallback)."""
|
| 165 |
+
self.d_model = self.config["d_model"]
|
| 166 |
+
self.memory_length = self.config["memory_length"]
|
| 167 |
+
self.n_ticks = self.config["iterations"]
|
| 168 |
+
|
| 169 |
+
# State traces
|
| 170 |
+
self.state_trace = np.zeros((self.d_model, self.memory_length))
|
| 171 |
+
self.activated_state = np.random.randn(self.d_model) * 0.1
|
| 172 |
+
|
| 173 |
+
# NLM weights (simplified: 16 groups for 16 dendrites)
|
| 174 |
+
self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
|
| 175 |
+
|
| 176 |
+
self.is_full = False
|
| 177 |
+
print("✅ Simplified CTM initialized (NumPy fallback)")
|
| 178 |
+
|
| 179 |
+
def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
|
| 180 |
+
"""
|
| 181 |
+
Process input through CTM.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
input_72d: 72D input from SNN
|
| 185 |
+
n_ticks: Override number of internal ticks
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Dict with predictions, certainty, sync matrix
|
| 189 |
+
"""
|
| 190 |
+
n_ticks = n_ticks or self.config["iterations"]
|
| 191 |
+
self.forward_count += 1
|
| 192 |
+
|
| 193 |
+
if self.is_full:
|
| 194 |
+
return self._forward_full(input_72d, n_ticks)
|
| 195 |
+
else:
|
| 196 |
+
return self._forward_simplified(input_72d, n_ticks)
|
| 197 |
+
|
| 198 |
+
def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
|
| 199 |
+
"""Forward pass with full PyTorch CTM."""
|
| 200 |
+
# Prepare tensor
|
| 201 |
+
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
|
| 202 |
+
if len(x.shape) == 1:
|
| 203 |
+
x = x.unsqueeze(0) # Add batch dim
|
| 204 |
+
x = x.unsqueeze(-1).unsqueeze(-1) # [B, 72, 1, 1]
|
| 205 |
+
|
| 206 |
+
with torch.no_grad():
|
| 207 |
+
predictions, certainties, sync_out = self.model(x)
|
| 208 |
+
|
| 209 |
+
# Extract results
|
| 210 |
+
final_pred = predictions[:, :, -1].cpu().numpy()[0] # Last tick [16]
|
| 211 |
+
final_cert = certainties[:, 1, -1].cpu().numpy()[0] # 1-entropy
|
| 212 |
+
|
| 213 |
+
# Find tick with highest certainty
|
| 214 |
+
best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
|
| 215 |
+
best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
|
| 216 |
+
|
| 217 |
+
# Sync matrix for hypergraph edge proposals
|
| 218 |
+
sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
|
| 219 |
+
|
| 220 |
+
return {
|
| 221 |
+
"predictions": final_pred.tolist(),
|
| 222 |
+
"best_predictions": best_pred.tolist(),
|
| 223 |
+
"certainty": float(final_cert),
|
| 224 |
+
"best_tick": int(best_tick_idx),
|
| 225 |
+
"ticks_used": n_ticks,
|
| 226 |
+
"sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
|
| 227 |
+
"model": "ContinuousThoughtMachine (Full PyTorch)"
|
| 228 |
+
}
|
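For readers following the indexing above, a sketch of the assumed tensor layout: `predictions` is `[B, out_dims, T]` (one 16D output per internal tick) and `certainties` is `[B, 2, T]` with channel 1 holding the 1-entropy score, as the indexing in `_forward_full` and `train_step` implies; the tensors here are random stand-ins:

```python
import torch

B, out_dims, T = 1, 16, 50
predictions = torch.randn(B, out_dims, T)   # one 16D prediction per internal tick
certainties = torch.rand(B, 2, T)           # channel 1 assumed to be 1 - entropy

best_tick = certainties[:, 1, :].argmax(dim=-1)          # most certain tick per sample
best_pred = predictions[torch.arange(B), :, best_tick]   # [B, 16] prediction at that tick
print(best_tick.item(), best_pred.shape)
```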
| 229 |
+
|
| 230 |
+
def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
|
| 231 |
+
"""Forward pass with simplified NumPy CTM."""
|
| 232 |
+
# Pad to d_model
|
| 233 |
+
input_256 = np.zeros(self.d_model)
|
| 234 |
+
input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
|
| 235 |
|
| 236 |
certainties = []
|
|
|
|
| 237 |
|
| 238 |
for t in range(n_ticks):
|
| 239 |
+
# Synapse update (simplified)
|
| 240 |
+
combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
|
| 241 |
pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
|
| 242 |
|
| 243 |
+
# Update trace (memory)
|
| 244 |
self.state_trace = np.roll(self.state_trace, -1, axis=1)
|
| 245 |
self.state_trace[:, -1] = pre_activation
|
| 246 |
|
| 247 |
+
# NLM processing (simplified: 16 groups)
|
| 248 |
post_activation = np.zeros(self.d_model)
|
| 249 |
group_size = self.d_model // 16
|
| 250 |
for g in range(16):
|
|
|
|
| 256 |
|
| 257 |
self.activated_state = post_activation
|
| 258 |
|
| 259 |
+
# Compute certainty
|
| 260 |
+
probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
|
| 261 |
+
probs = np.clip(probs, 1e-10, 1.0)
|
| 262 |
+
entropy = -np.sum(probs * np.log(probs))
|
| 263 |
+
max_entropy = np.log(len(probs))
|
| 264 |
+
certainties.append(float(1.0 - entropy / (max_entropy + 1e-8)))
|
| 265 |
+
|
| 266 |
+
# Best tick
|
| 267 |
+
best_tick_idx = int(np.argmax(certainties))
|
| 268 |
+
|
| 269 |
+
# Sync matrix
|
| 270 |
+
z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
|
| 271 |
+
sync_matrix = np.outer(z_norm, z_norm)
|
| 272 |
+
|
| 273 |
+
# Predictions (first 16 elements of activated state)
|
| 274 |
+
predictions = self.activated_state[:16].tolist()
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
"predictions": predictions,
|
| 278 |
+
"best_predictions": predictions,
|
| 279 |
+
"certainty": certainties[-1],
|
| 280 |
+
"best_tick": best_tick_idx,
|
| 281 |
+
"ticks_used": n_ticks,
|
| 282 |
+
"sync_matrix": sync_matrix[:16, :16].tolist(),
|
| 283 |
+
"model": "SimplifiedCTM (NumPy fallback)"
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
|
| 287 |
+
physics_loss: float = 0.0) -> Dict:
|
| 288 |
+
"""
|
| 289 |
+
Online training step.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
input_72d: Input from SNN
|
| 293 |
+
target_16d: Target dendrite adjustments (ground truth)
|
| 294 |
+
physics_loss: Current physics loss for weighting
|
| 295 |
|
| 296 |
+
Returns:
|
| 297 |
+
Dict with loss and gradient info
|
| 298 |
+
"""
|
| 299 |
+
if not self.is_full or not TORCH_AVAILABLE:
|
| 300 |
+
return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
|
| 301 |
+
|
| 302 |
+
self.model.train()
|
| 303 |
+
|
| 304 |
+
# Prepare tensors
|
| 305 |
+
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
|
| 306 |
+
x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
|
| 307 |
+
y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
|
| 308 |
+
|
| 309 |
+
# Forward
|
| 310 |
+
predictions, certainties, _ = self.model(x)
|
| 311 |
|
| 312 |
+
# Loss: dendrite_regulation_loss
|
| 313 |
+
# predictions: [B, 16, T], y: [B, 16]
|
| 314 |
+
y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1)) # [B, 16, T]
|
| 315 |
+
mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1) # [B, T]
|
| 316 |
+
|
| 317 |
+
# Select best tick (min loss) and most certain tick
|
| 318 |
+
loss_min_idx = mse_per_tick.argmin(dim=1) # [B]
|
| 319 |
+
loss_cert_idx = certainties[:, 1, :].argmax(dim=1) # [B]
|
| 320 |
+
|
| 321 |
+
batch_idx = torch.arange(predictions.size(0), device=DEVICE)
|
| 322 |
+
loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
|
| 323 |
+
loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
|
| 324 |
+
|
| 325 |
+
# Combined loss with physics penalty
|
| 326 |
+
mse_loss = (loss_min + loss_cert) / 2
|
| 327 |
+
physics_penalty = physics_loss * 0.1
|
| 328 |
+
total_loss = mse_loss + physics_penalty
|
| 329 |
+
|
| 330 |
+
# Backward
|
| 331 |
+
self.optimizer.zero_grad()
|
| 332 |
+
total_loss.backward()
|
| 333 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 334 |
+
self.optimizer.step()
|
| 335 |
+
|
| 336 |
+
self.model.eval()
|
| 337 |
+
|
| 338 |
+
# Auto-save checkpoint
|
| 339 |
+
if self.forward_count % self.config["auto_save_every"] == 0:
|
| 340 |
+
self._save_checkpoint()
|
| 341 |
|
| 342 |
return {
|
| 343 |
+
"status": "trained",
|
| 344 |
+
"loss": float(total_loss.item()),
|
| 345 |
+
"mse_loss": float(mse_loss.item()),
|
| 346 |
+
"physics_penalty": float(physics_penalty),
|
| 347 |
+
"best_tick": int(loss_cert_idx[0].item())
|
|
|
|
|
|
|
| 348 |
}
|
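A hedged usage sketch of the online loop `train_step` enables; `get_snn_vector` and `get_target_adjustments` are hypothetical stand-ins for whatever the local system supplies as experience:

```python
import numpy as np

def get_snn_vector() -> np.ndarray:          # hypothetical 72D sensory input
    return np.random.randn(72)

def get_target_adjustments() -> np.ndarray:  # hypothetical 16D target dendrite deltas
    return np.random.rand(16)

for step in range(10):
    out = ctm.train_step(get_snn_vector(), get_target_adjustments(), physics_loss=0.0)
    if out["status"] == "trained":
        print(f"step {step}: loss={out['loss']:.4f}, best_tick={out['best_tick']}")
    else:
        print("skipped:", out.get("reason"))  # e.g. running on the NumPy fallback
```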
| 349 |
+
|
| 350 |
+
def _save_checkpoint(self):
|
| 351 |
+
"""Save model checkpoint."""
|
| 352 |
+
if not self.is_full:
|
| 353 |
+
return
|
| 354 |
+
|
| 355 |
+
os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
|
| 356 |
+
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
|
| 357 |
+
|
| 358 |
+
torch.save({
|
| 359 |
+
"model_state_dict": self.model.state_dict(),
|
| 360 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 361 |
+
"forward_count": self.forward_count,
|
| 362 |
+
"timestamp": datetime.now().isoformat()
|
| 363 |
+
}, path)
|
| 364 |
+
print(f"💾 Checkpoint saved: {path}")
|
| 365 |
+
|
| 366 |
+
def _load_checkpoint(self):
|
| 367 |
+
"""Load model checkpoint if exists."""
|
| 368 |
+
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
|
| 369 |
+
if os.path.exists(path):
|
| 370 |
+
try:
|
| 371 |
+
checkpoint = torch.load(path, map_location=DEVICE)
|
| 372 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 373 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 374 |
+
self.forward_count = checkpoint.get("forward_count", 0)
|
| 375 |
+
print(f"✅ Checkpoint loaded: {path}")
|
| 376 |
+
except Exception as e:
|
| 377 |
+
print(f"⚠️ Could not load checkpoint: {e}")
|
| 378 |
|
| 379 |
+
# ============================================================================
|
| 380 |
+
# GLOBAL CTM INSTANCE
|
| 381 |
+
# ============================================================================
|
| 382 |
+
ctm = CTM_ART17(CONFIG)
|
| 383 |
|
| 384 |
# ============================================================================
|
| 385 |
# PHYSICS VALIDATION (from SNN Omega-21)
|
| 386 |
# ============================================================================
|
|
|
|
| 387 |
def validate_physics(trajectory: List[float], params: Dict) -> Dict:
|
| 388 |
+
"""Validate against 5 physics losses from SNN Omega-21."""
|
| 389 |
trajectory = np.array(trajectory)
|
| 390 |
|
| 391 |
# L_energy: Energy conservation
|
| 392 |
energy = np.sum(trajectory ** 2)
|
| 393 |
+
P_max = params.get("P_max", CONFIG["physics_thresholds"]["P_max"])
|
| 394 |
L_energy = float(max(0, energy - P_max) ** 2)
|
| 395 |
|
| 396 |
# L_thermo: Thermodynamics (dew point check)
|
| 397 |
+
T_dew = params.get("T_dew", CONFIG["physics_thresholds"]["T_dew"])
|
| 398 |
+
T_amb = params.get("T_amb", CONFIG["physics_thresholds"]["T_amb"])
|
| 399 |
L_thermo = float(max(0, T_dew - T_amb) ** 2)
|
| 400 |
|
| 401 |
# L_causal: Causality (velocity limit)
|
| 402 |
velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
|
| 403 |
+
v_max = params.get("v_max", CONFIG["physics_thresholds"]["v_max"])
|
| 404 |
L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))
|
| 405 |
|
| 406 |
# L_conserv: Flux conservation
|
|
|
|
| 431 |
|
| 432 |
def sense_snn(snn_json: str) -> str:
|
| 433 |
"""
|
| 434 |
+
/sense_snn - Process 72D SNN input through CTM
|
| 435 |
+
|
| 436 |
+
Input: JSON with dendrite values or 72D vector
|
| 437 |
+
Output: Coherent features, certainty, sync matrix
|
| 438 |
"""
|
| 439 |
try:
|
| 440 |
data = json.loads(snn_json)
|
| 441 |
|
| 442 |
+
# Extract 72D vector
|
| 443 |
if "vector_72d" in data:
|
| 444 |
input_vec = np.array(data["vector_72d"])
|
| 445 |
elif "dendrites" in data:
|
| 446 |
+
input_vec = np.array(list(data["dendrites"].values()))
|
|
|
|
| 447 |
else:
|
| 448 |
+
input_vec = np.random.randn(72)
|
| 449 |
|
| 450 |
+
# Pad to 72D if needed
|
| 451 |
+
if len(input_vec) < 72:
|
| 452 |
+
input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
|
| 453 |
|
| 454 |
# Process through CTM
|
| 455 |
+
n_ticks = data.get("ticks", 25)
|
| 456 |
+
result = ctm.forward(input_vec[:72], n_ticks)
|
| 457 |
|
| 458 |
+
# Detect anomalies (low certainty)
|
| 459 |
anomalies = []
|
| 460 |
+
if result["certainty"] < 0.5:
|
| 461 |
+
anomalies.append("Low overall certainty - consider retraining")
|
| 462 |
|
| 463 |
return json.dumps({
|
| 464 |
"status": "success",
|
| 465 |
+
"coherent_features": result["predictions"],
|
| 466 |
+
"certainty": result["certainty"],
|
| 467 |
+
"best_tick": result["best_tick"],
|
| 468 |
"anomalies": anomalies,
|
| 469 |
+
"ticks_used": result["ticks_used"],
|
| 470 |
+
"model": result["model"]
|
| 471 |
}, indent=2)
|
| 472 |
except Exception as e:
|
| 473 |
return json.dumps({"status": "error", "message": str(e)})
|
| 474 |
|
| 475 |
+
|
| 476 |
def reason_hypergraph(context_json: str) -> str:
|
| 477 |
"""
|
| 478 |
+
/reason_hypergraph - Reason about hypergraph context, propose edges
|
| 479 |
+
|
| 480 |
+
Uses CTM synchronization matrix to find strongly correlated node pairs.
|
| 481 |
+
These become proposed hyperedges.
|
| 482 |
"""
|
| 483 |
try:
|
| 484 |
data = json.loads(context_json)
|
|
|
|
| 488 |
n_ticks = data.get("ticks", 50)
|
| 489 |
|
| 490 |
# Flatten node features for CTM input
|
| 491 |
+
input_vec = node_features.flatten()[:72]
|
| 492 |
|
| 493 |
# Process through CTM with more ticks for reasoning
|
| 494 |
+
result = ctm.forward(input_vec, n_ticks)
|
| 495 |
|
| 496 |
+
# Extract proposed edges from sync matrix (S_ij > 0.7)
|
|
|
|
| 497 |
proposed_edges = []
|
| 498 |
+
if result["sync_matrix"] is not None:
|
| 499 |
+
sync = np.array(result["sync_matrix"])
|
| 500 |
+
n_nodes = min(len(node_features), sync.shape[0])
|
| 501 |
+
|
| 502 |
+
for i in range(n_nodes):
|
| 503 |
+
for j in range(i+1, n_nodes):
|
| 504 |
+
sync_ij = sync[i, j]
|
| 505 |
+
if sync_ij > 0.7: # Threshold for edge proposal
|
| 506 |
+
edge_exists = any(
|
| 507 |
+
(e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
|
| 508 |
+
for e in existing_edges
|
| 509 |
+
)
|
| 510 |
+
if not edge_exists:
|
| 511 |
+
proposed_edges.append([i, j, float(sync_ij)])
|
| 512 |
|
| 513 |
return json.dumps({
|
| 514 |
"status": "success",
|
| 515 |
"proposed_edges": proposed_edges,
|
| 516 |
+
"certainty": result["certainty"],
|
| 517 |
"best_tick": result["best_tick"],
|
| 518 |
+
"ticks_used": result["ticks_used"],
|
| 519 |
+
"model": result["model"]
|
| 520 |
}, indent=2)
|
| 521 |
except Exception as e:
|
| 522 |
return json.dumps({"status": "error", "message": str(e)})
|
| 523 |
|
| 524 |
+
|
| 525 |
def validate_physics_endpoint(physics_json: str) -> str:
|
| 526 |
"""
|
| 527 |
/validate_physics - Validate trajectory against 5 physics losses
|
|
|
|
| 538 |
except Exception as e:
|
| 539 |
return json.dumps({"status": "error", "message": str(e)})
|
| 540 |
|
| 541 |
+
|
| 542 |
def dream_endpoint(dream_json: str) -> str:
|
| 543 |
"""
|
| 544 |
+
/dream - Offline consolidation with many ticks
|
| 545 |
+
|
| 546 |
+
Discovers patterns, proposes new edges, identifies edges to prune.
|
| 547 |
"""
|
| 548 |
try:
|
| 549 |
data = json.loads(dream_json)
|
| 550 |
|
| 551 |
snapshot = data.get("hypergraph_snapshot", {})
|
| 552 |
+
n_ticks = min(data.get("ticks", 100), 100) # Cap at 100 for CPU
|
| 553 |
|
| 554 |
# Extract features from snapshot
|
| 555 |
nodes = snapshot.get("nodes", [])
|
| 556 |
if nodes:
|
| 557 |
+
input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()[:72]
|
| 558 |
else:
|
| 559 |
+
input_vec = np.random.randn(72)
|
|
|
|
|
|
|
|
|
|
| 560 |
|
| 561 |
+
# Dream: run CTM with many ticks
|
| 562 |
+
result = ctm.forward(input_vec, n_ticks)
|
| 563 |
|
| 564 |
+
# Analyze sync for patterns
|
| 565 |
new_edges = []
|
| 566 |
pruned_edges = []
|
| 567 |
+
|
| 568 |
+
if result["sync_matrix"] is not None:
|
| 569 |
+
sync = np.array(result["sync_matrix"])
|
| 570 |
+
n = min(len(nodes), sync.shape[0]) if nodes else 16
|
| 571 |
+
|
| 572 |
+
for i in range(n):
|
| 573 |
+
for j in range(i+1, n):
|
| 574 |
+
if sync[i, j] > 0.85:
|
| 575 |
+
new_edges.append([i, j, float(sync[i, j])])
|
| 576 |
+
elif sync[i, j] < 0.1:
|
| 577 |
+
pruned_edges.append([i, j])
|
| 578 |
|
| 579 |
return json.dumps({
|
| 580 |
"status": "success",
|
| 581 |
"discovered_patterns": len(new_edges),
|
| 582 |
+
"new_edges": new_edges[:10],
|
| 583 |
+
"pruned_edges": pruned_edges[:10],
|
| 584 |
+
"consolidation_certainty": result["certainty"],
|
| 585 |
+
"ticks_used": result["ticks_used"],
|
| 586 |
+
"model": result["model"]
|
| 587 |
}, indent=2)
|
| 588 |
except Exception as e:
|
| 589 |
return json.dumps({"status": "error", "message": str(e)})
|
| 590 |
|
| 591 |
+
|
| 592 |
def calibrate_stdp_endpoint(stdp_json: str) -> str:
|
| 593 |
"""
|
| 594 |
+
/calibrate_stdp - Suggest STDP weight adjustments
|
| 595 |
+
|
| 596 |
+
This is the CORE regulatory function:
|
| 597 |
+
- Receives current 16 dendrite weights
|
| 598 |
+
- Processes through CTM to get sync patterns
|
| 599 |
+
- Returns suggested weight adjustments
|
| 600 |
"""
|
| 601 |
try:
|
| 602 |
data = json.loads(stdp_json)
|
| 603 |
|
| 604 |
current_weights = np.array(data.get("current_weights", [1.0]*16))
|
| 605 |
+
node_features = np.array(data.get("node_features", [[0]*16]*4))
|
| 606 |
|
| 607 |
+
# Flatten features for CTM input
|
| 608 |
+
input_vec = node_features.flatten()[:72]
|
|
|
|
| 609 |
|
| 610 |
+
# Process through CTM
|
| 611 |
+
result = ctm.forward(input_vec, n_ticks=25)
|
| 612 |
+
|
| 613 |
+
# Use predictions as weight adjustments
|
| 614 |
+
predictions = np.array(result["best_predictions"])
|
| 615 |
|
| 616 |
+
# Scale based on certainty
|
| 617 |
+
confidence = result["certainty"]
|
| 618 |
+
weight_changes = (predictions - 0.5) * confidence * 0.1
|
| 619 |
+
|
| 620 |
+
new_weights = current_weights + weight_changes
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
return json.dumps({
|
| 623 |
"status": "success",
|
| 624 |
+
"suggested_weights": new_weights.tolist(),
|
| 625 |
+
"weight_changes": weight_changes.tolist(),
|
| 626 |
+
"confidence": confidence,
|
| 627 |
+
"best_tick": result["best_tick"],
|
| 628 |
+
"model": result["model"]
|
| 629 |
}, indent=2)
|
| 630 |
except Exception as e:
|
| 631 |
return json.dumps({"status": "error", "message": str(e)})
|
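A worked numeric example of the weight-update rule above (illustrative values only):

```python
import numpy as np

current_weights = np.ones(16)      # incoming STDP weights
predictions = np.full(16, 0.7)     # hypothetical best-tick CTM output
confidence = 0.6                   # certainty reported by the CTM

weight_changes = (predictions - 0.5) * confidence * 0.1   # 0.012 per dendrite
new_weights = current_weights + weight_changes             # each weight -> 1.012
print(weight_changes[0], new_weights[0])
```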
| 632 |
|
| 633 |
+
|
| 634 |
+
def regulate_endpoint(regulate_json: str) -> str:
|
| 635 |
+
"""
|
| 636 |
+
/regulate - Full feedback loop for ART-17 regulation (NEW)
|
| 637 |
+
|
| 638 |
+
Combines all signals to provide comprehensive regulation:
|
| 639 |
+
- Dendrite state
|
| 640 |
+
- Latent representation
|
| 641 |
+
- Physics loss
|
| 642 |
+
- Anomaly score
|
| 643 |
+
|
| 644 |
+
Returns action recommendation with confidence.
|
| 645 |
+
"""
|
| 646 |
+
try:
|
| 647 |
+
data = json.loads(regulate_json)
|
| 648 |
+
|
| 649 |
+
# Inputs from local system
|
| 650 |
+
dendrites = np.array(data.get("dendrites", [0.0]*16))
|
| 651 |
+
latent_256 = np.array(data.get("latent_256", [0.0]*256))
|
| 652 |
+
physics_loss = data.get("physics_loss", 0.0)
|
| 653 |
+
anomaly_score = data.get("anomaly_score", 0.0)
|
| 654 |
+
|
| 655 |
+
# Combine into 72D input
|
| 656 |
+
input_72 = np.concatenate([
|
| 657 |
+
dendrites, # 16D
|
| 658 |
+
latent_256[:56] # 56D from latent
|
| 659 |
+
])
|
| 660 |
+
|
| 661 |
+
# Process through CTM
|
| 662 |
+
result = ctm.forward(input_72, n_ticks=50)
|
| 663 |
+
|
| 664 |
+
# Compute regulation signals
|
| 665 |
+
predictions = np.array(result["best_predictions"])
|
| 666 |
+
certainty = result["certainty"]
|
| 667 |
+
|
| 668 |
+
# Urgency based on physics and anomaly
|
| 669 |
+
urgency = min(1.0, physics_loss + anomaly_score)
|
| 670 |
+
regulation_strength = urgency * certainty
|
| 671 |
+
|
| 672 |
+
# Weight adjustments
|
| 673 |
+
dendrite_deltas = predictions * regulation_strength * 0.05
|
| 674 |
+
|
| 675 |
+
# Determine if intervention needed
|
| 676 |
+
needs_intervention = urgency > 0.5 or certainty < 0.3
|
| 677 |
+
|
| 678 |
+
return json.dumps({
|
| 679 |
+
"status": "success",
|
| 680 |
+
"dendrite_deltas": dendrite_deltas.tolist(),
|
| 681 |
+
"regulation_strength": float(regulation_strength),
|
| 682 |
+
"confidence": certainty,
|
| 683 |
+
"urgency": float(urgency),
|
| 684 |
+
"needs_intervention": needs_intervention,
|
| 685 |
+
"recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
|
| 686 |
+
"best_tick": result["best_tick"],
|
| 687 |
+
"model": result["model"]
|
| 688 |
+
}, indent=2)
|
| 689 |
+
except Exception as e:
|
| 690 |
+
return json.dumps({"status": "error", "message": str(e)})
|
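The regulation arithmetic above, traced with illustrative numbers:

```python
physics_loss, anomaly_score, certainty = 0.4, 0.3, 0.8

urgency = min(1.0, physics_loss + anomaly_score)         # 0.7
regulation_strength = urgency * certainty                 # 0.56
needs_intervention = urgency > 0.5 or certainty < 0.3     # True -> "ADJUST"
print(urgency, regulation_strength, needs_intervention)
```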
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def train_online_endpoint(train_json: str) -> str:
|
| 694 |
+
"""
|
| 695 |
+
/train_online - Progressive online training (NEW)
|
| 696 |
+
|
| 697 |
+
Allows the local system to train the CTM with experience.
|
| 698 |
+
Sends input-output pairs and receives training feedback.
|
| 699 |
+
"""
|
| 700 |
+
try:
|
| 701 |
+
data = json.loads(train_json)
|
| 702 |
+
|
| 703 |
+
input_72d = np.array(data.get("input_72d", [0.0]*72))
|
| 704 |
+
target_16d = np.array(data.get("target_16d", [0.0]*16))
|
| 705 |
+
physics_loss = data.get("physics_loss", 0.0)
|
| 706 |
+
|
| 707 |
+
# Perform training step
|
| 708 |
+
result = ctm.train_step(input_72d, target_16d, physics_loss)
|
| 709 |
+
|
| 710 |
+
return json.dumps({
|
| 711 |
+
"status": result["status"],
|
| 712 |
+
"loss": result.get("loss"),
|
| 713 |
+
"mse_loss": result.get("mse_loss"),
|
| 714 |
+
"physics_penalty": result.get("physics_penalty"),
|
| 715 |
+
"best_tick": result.get("best_tick"),
|
| 716 |
+
"forward_count": ctm.forward_count,
|
| 717 |
+
"message": "Training step completed" if result["status"] == "trained" else result.get("reason")
|
| 718 |
+
}, indent=2)
|
| 719 |
+
except Exception as e:
|
| 720 |
+
return json.dumps({"status": "error", "message": str(e)})
|
| 721 |
+
|
| 722 |
+
|
| 723 |
def health_check() -> str:
|
| 724 |
+
"""Health check with model info."""
|
| 725 |
return json.dumps({
|
| 726 |
"status": "healthy",
|
| 727 |
+
"model": f"CTM Nervous System v2.0 ({'Full PyTorch' if ctm.is_full else 'NumPy Fallback'})",
|
| 728 |
+
"device": DEVICE,
|
| 729 |
+
"d_model": CONFIG["d_model"],
|
| 730 |
+
"iterations": CONFIG["iterations"],
|
| 731 |
+
"memory_length": CONFIG["memory_length"],
|
| 732 |
+
"forward_count": ctm.forward_count,
|
| 733 |
"endpoints": [
|
| 734 |
"/sense_snn",
|
| 735 |
+
"/reason_hypergraph",
|
| 736 |
"/validate_physics",
|
| 737 |
"/dream",
|
| 738 |
+
"/calibrate_stdp",
|
| 739 |
+
"/regulate", # NEW
|
| 740 |
+
"/train_online" # NEW
|
| 741 |
]
|
| 742 |
}, indent=2)
|
| 743 |
|
| 744 |
+
|
| 745 |
# ============================================================================
|
| 746 |
# GRADIO INTERFACE
|
| 747 |
# ============================================================================
|
| 748 |
|
| 749 |
+
with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
|
| 750 |
gr.Markdown("""
|
| 751 |
+
# 🧬 CTM Nervous System v2.0
|
| 752 |
+
**Continuous Thought Machine for ART-17 Hypergraph Coherence**
|
| 753 |
|
| 754 |
Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
|
| 755 |
|
| 756 |
---
|
| 757 |
|
| 758 |
+
## Key Innovations
|
| 759 |
+
- **NLMs (Neuron-Level Models)**: Each neuron processes its own history
|
| 760 |
+
- **Neural Synchronization**: Representation via S = Z·Z^T
|
| 761 |
+
- **Adaptive Compute**: Halts when confident
|
| 762 |
+
- **Online Training**: Progressive learning with use
|
| 763 |
+
|
| 764 |
+
---
|
| 765 |
""")
|
| 766 |
|
| 767 |
with gr.Tabs():
|
| 768 |
with gr.Tab("🔌 /sense_snn"):
|
| 769 |
+
gr.Markdown("Process 72D SNN input through CTM")
|
| 770 |
snn_input = gr.Textbox(
|
| 771 |
label="SNN JSON Input",
|
| 772 |
+
value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
|
| 773 |
lines=5
|
| 774 |
)
|
| 775 |
snn_output = gr.Textbox(label="Output", lines=10)
|
|
|
|
| 777 |
snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
|
| 778 |
|
| 779 |
with gr.Tab("🧠 /reason_hypergraph"):
|
| 780 |
+
gr.Markdown("Reason about hypergraph context, propose edges")
|
| 781 |
reason_input = gr.Textbox(
|
| 782 |
label="Context JSON",
|
| 783 |
value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
|
|
|
|
| 788 |
reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
|
| 789 |
|
| 790 |
with gr.Tab("⚡ /validate_physics"):
|
| 791 |
+
gr.Markdown("Validate trajectory against 5 physics losses")
|
| 792 |
physics_input = gr.Textbox(
|
| 793 |
label="Physics JSON",
|
| 794 |
value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
|
|
|
|
| 799 |
physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
|
| 800 |
|
| 801 |
with gr.Tab("💤 /dream"):
|
| 802 |
+
gr.Markdown("Offline consolidation - discover patterns")
|
| 803 |
dream_input = gr.Textbox(
|
| 804 |
label="Dream JSON",
|
| 805 |
value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
|
|
|
|
| 810 |
dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
|
| 811 |
|
| 812 |
with gr.Tab("🔧 /calibrate_stdp"):
|
| 813 |
+
gr.Markdown("Calibrate STDP weights (Core regulatory function)")
|
| 814 |
stdp_input = gr.Textbox(
|
| 815 |
label="STDP JSON",
|
| 816 |
value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
|
|
|
|
| 820 |
stdp_btn = gr.Button("Calibrate", variant="primary")
|
| 821 |
stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
|
| 822 |
|
| 823 |
+
with gr.Tab("🎯 /regulate [NEW]"):
|
| 824 |
+
gr.Markdown("Full feedback loop for ART-17 regulation")
|
| 825 |
+
regulate_input = gr.Textbox(
|
| 826 |
+
label="Regulate JSON",
|
| 827 |
+
value='{"dendrites": [0.5]*16, "latent_256": [0.1]*256, "physics_loss": 0.01, "anomaly_score": 0.05}',
|
| 828 |
+
lines=5
|
| 829 |
+
)
|
| 830 |
+
regulate_output = gr.Textbox(label="Output", lines=10)
|
| 831 |
+
regulate_btn = gr.Button("Regulate", variant="primary")
|
| 832 |
+
regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
|
| 833 |
+
|
| 834 |
+
with gr.Tab("📚 /train_online [NEW]"):
|
| 835 |
+
gr.Markdown("Progressive online training with experience")
|
| 836 |
+
train_input = gr.Textbox(
|
| 837 |
+
label="Training JSON",
|
| 838 |
+
value='{"input_72d": [0.1]*72, "target_16d": [0.5]*16, "physics_loss": 0.01}',
|
| 839 |
+
lines=5
|
| 840 |
+
)
|
| 841 |
+
train_output = gr.Textbox(label="Output", lines=10)
|
| 842 |
+
train_btn = gr.Button("Train Step", variant="primary")
|
| 843 |
+
train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
|
| 844 |
+
|
| 845 |
with gr.Tab("❤️ Health"):
|
| 846 |
+
health_output = gr.Textbox(label="Health Status", lines=15)
|
| 847 |
health_btn = gr.Button("Check Health", variant="secondary")
|
| 848 |
health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
|
| 849 |
|
| 850 |
gr.Markdown("""
|
| 851 |
---
|
| 852 |
+
**Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
|
| 853 |
+
|
| 854 |
+
**Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
|
| 855 |
|
| 856 |
+
**Training**: Progressive online learning + Physics-Informed Loss
|
| 857 |
""")
|
| 858 |
|
| 859 |
if __name__ == "__main__":
|
|
|
|
| 861 |
server_name="0.0.0.0",
|
| 862 |
server_port=7860,
|
| 863 |
show_error=True
|
| 864 |
+
)
|
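A hedged client-side sketch of calling the new endpoints from the local system via gradio_client; the Space id is a placeholder, and the payloads are built with json.dumps so that the list-valued fields arrive as valid JSON (the Python-shorthand textbox defaults such as "[0.5]*16" are not themselves parseable by json.loads):

```python
import json
from gradio_client import Client  # pip install gradio_client

client = Client("your-username/ctm-nervous-system")  # placeholder Space id

# /regulate: 16 dendrites + first 56 latent dims -> 72D CTM input
regulate_payload = json.dumps({
    "dendrites": [0.5] * 16,
    "latent_256": [0.1] * 256,
    "physics_loss": 0.01,
    "anomaly_score": 0.05,
})
print(client.predict(regulate_payload, api_name="/regulate"))

# /train_online: one progressive training step
train_payload = json.dumps({
    "input_72d": [0.1] * 72,
    "target_16d": [0.5] * 16,
    "physics_loss": 0.01,
})
print(client.predict(train_payload, api_name="/train_online"))
```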
app_v1_backup.py
ADDED
|
@@ -0,0 +1,464 @@
| 1 |
+
"""
|
| 2 |
+
CTM Nervous System Server - Continuous Thought Machine for Hypergraph Maintenance
|
| 3 |
+
===================================================================================
|
| 4 |
+
Implementation of the definitive proposal: CTM as Nervous System for ART-17 Hypergraph.
|
| 5 |
+
|
| 6 |
+
Endpoints:
|
| 7 |
+
- /sense_snn: Process 72D SNN input with NLM-style processing
|
| 8 |
+
- /reason_hypergraph: Reason about hypergraph context, propose edges
|
| 9 |
+
- /validate_physics: Validate proposals against 5 physics losses
|
| 10 |
+
- /dream: Offline consolidation with T=500+ ticks
|
| 11 |
+
- /calibrate_stdp: Suggest STDP weight adjustments from sync matrix
|
| 12 |
+
- /health: Health check endpoint
|
| 13 |
+
|
| 14 |
+
Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import gradio as gr
|
| 18 |
+
import numpy as np
|
| 19 |
+
import json
|
| 20 |
+
from typing import List, Dict, Any, Optional
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
# ============================================================================
|
| 24 |
+
# SIMPLIFIED CTM SIMULATION (CPU-only for Hugging Face free tier)
|
| 25 |
+
# ============================================================================
|
| 26 |
+
|
| 27 |
+
class SimplifiedCTM:
|
| 28 |
+
"""
|
| 29 |
+
Simplified CTM for CPU-only environment.
|
| 30 |
+
Simulates the key mechanisms without full PyTorch model.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, d_model: int = 256, memory_length: int = 16, n_ticks: int = 50):
|
| 34 |
+
self.d_model = d_model
|
| 35 |
+
self.memory_length = memory_length
|
| 36 |
+
self.n_ticks = n_ticks
|
| 37 |
+
|
| 38 |
+
# Initialize state
|
| 39 |
+
self.state_trace = np.zeros((d_model, memory_length))
|
| 40 |
+
self.activated_state = np.random.randn(d_model) * 0.1
|
| 41 |
+
|
| 42 |
+
# NLM weights (simplified: one weight matrix per "neuron group")
|
| 43 |
+
self.nlm_weights = np.random.randn(16, memory_length) * 0.1 # 16 groups for 16 dendrites
|
| 44 |
+
|
| 45 |
+
def compute_sync_matrix(self, z: np.ndarray) -> np.ndarray:
|
| 46 |
+
"""S^t = Z · Z^T (normalized)"""
|
| 47 |
+
z_norm = z / (np.linalg.norm(z) + 1e-8)
|
| 48 |
+
S = np.outer(z_norm, z_norm)
|
| 49 |
+
return S
|
| 50 |
+
|
| 51 |
+
def compute_certainty(self, predictions: np.ndarray) -> float:
|
| 52 |
+
"""Certainty = 1 - normalized entropy"""
|
| 53 |
+
probs = np.abs(predictions) / (np.sum(np.abs(predictions)) + 1e-8)
|
| 54 |
+
probs = np.clip(probs, 1e-10, 1.0)
|
| 55 |
+
entropy = -np.sum(probs * np.log(probs))
|
| 56 |
+
max_entropy = np.log(len(probs))
|
| 57 |
+
normalized_entropy = entropy / (max_entropy + 1e-8)
|
| 58 |
+
return float(1.0 - normalized_entropy)
|
| 59 |
+
|
    def process_ticks(self, input_features: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
        """Run T internal ticks and return sync matrix + certainty"""
        n_ticks = n_ticks or self.n_ticks

        # Ensure input is right size
        if len(input_features) < self.d_model:
            input_features = np.pad(input_features, (0, self.d_model - len(input_features)))
        else:
            input_features = input_features[:self.d_model]

        certainties = []
        sync_matrices = []

        for t in range(n_ticks):
            # Simulate synapse update: mix the recurrent state with the external input.
            # (The previous concatenation truncated the input away before it could act,
            # so the external features never influenced the state.)
            combined = np.concatenate([self.activated_state[:self.d_model // 2],
                                       input_features[:self.d_model // 2]])
            pre_activation = np.tanh(combined * 0.1 + np.random.randn(self.d_model) * 0.01)

            # Update trace
            self.state_trace = np.roll(self.state_trace, -1, axis=1)
            self.state_trace[:, -1] = pre_activation

            # Simulate NLM (simplified)
            post_activation = np.zeros(self.d_model)
            group_size = self.d_model // 16
            for g in range(16):
                start = g * group_size
                end = start + group_size
                group_trace = self.state_trace[start:end, :]
                group_output = np.mean(group_trace @ self.nlm_weights[g])
                post_activation[start:end] = np.tanh(group_output)

            self.activated_state = post_activation

            # Compute sync and certainty
            sync = self.compute_sync_matrix(self.activated_state)
            cert = self.compute_certainty(self.activated_state)

            sync_matrices.append(sync)
            certainties.append(cert)

        # Find best tick (min-loss proxy: max certainty)
        best_tick = int(np.argmax(certainties))

        return {
            "final_sync_matrix": sync_matrices[-1].tolist(),
            "best_sync_matrix": sync_matrices[best_tick].tolist(),
            "certainties": certainties,
            "final_certainty": float(certainties[-1]),
            "max_certainty": float(max(certainties)),
            "best_tick": best_tick,
            "ticks_used": n_ticks
        }

# Global CTM instance
ctm = SimplifiedCTM(d_model=256, memory_length=16, n_ticks=50)

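# Illustrative usage (not executed anywhere in this module): process_ticks pads or
# truncates its input to d_model=256 internally and returns a plain dict. The
# 72-dimensional vector below is a stand-in for a real SNN feature vector.
#
#   example_input = np.random.randn(72)
#   example_result = ctm.process_ticks(example_input, n_ticks=25)
#   example_result["final_certainty"]                     # float in [0, 1]
#   example_result["best_tick"]                           # tick index with maximum certainty
#   np.array(example_result["best_sync_matrix"]).shape    # (256, 256)
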
# ============================================================================
# PHYSICS VALIDATION (from SNN Omega-21)
# ============================================================================

def validate_physics(trajectory: List[float], params: Dict) -> Dict:
    """Validate against 5 physics losses from SNN Omega-21"""
    trajectory = np.array(trajectory)

    # L_energy: Energy conservation
    energy = np.sum(trajectory ** 2)
    P_max = params.get("P_max", 1000.0)
    L_energy = float(max(0, energy - P_max) ** 2)

    # L_thermo: Thermodynamics (dew point check)
    T_dew = params.get("T_dew", 15.0)
    T_amb = params.get("T_amb", 25.0)
    L_thermo = float(max(0, T_dew - T_amb) ** 2)

    # L_causal: Causality (velocity limit)
    velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
    v_max = params.get("v_max", 100.0)
    L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))

    # L_conserv: Flux conservation
    flux_in = params.get("flux_in", 1.0)
    flux_out = params.get("flux_out", 1.0)
    L_conserv = float((flux_in - flux_out) ** 2)

    # L_entropy: 2nd Law (entropy must increase)
    entropy_change = params.get("entropy_change", 0.1)
    L_entropy = float(max(0, -entropy_change) ** 2)

    # Total physics loss
    L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy

    return {
        "valid": L_total < 0.01,
        "L_energy": L_energy,
        "L_thermo": L_thermo,
        "L_causal": L_causal,
        "L_conserv": L_conserv,
        "L_entropy": L_entropy,
        "L_total": L_total
    }

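# Worked example (illustrative, using the default parameters above): for the trajectory
# [0.1, 0.2, 0.3], the energy is 0.01 + 0.04 + 0.09 = 0.14 <= P_max, the velocities are
# [0.1, 0.1] <= v_max, flux_in == flux_out, T_dew < T_amb, and entropy_change > 0, so
# every term is 0.0 and the proposal is valid. A single jump larger than v_max (e.g.
# 0.0 -> 150.0) would instead give L_causal = (150 - 100)**2 = 2500 and mark it invalid.
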
# ============================================================================
# ENDPOINT FUNCTIONS
# ============================================================================

def sense_snn(snn_json: str) -> str:
    """
    /sense_snn - Process 72D SNN input
    Input: JSON with dendrite values
    Output: Coherent features + anomalies
    """
    try:
        data = json.loads(snn_json)

        # Extract 72D vector (or create from dendrites)
        if "vector_72d" in data:
            input_vec = np.array(data["vector_72d"])
        elif "dendrites" in data:
            dendrite_values = list(data["dendrites"].values())
            input_vec = np.array(dendrite_values)
        else:
            input_vec = np.random.randn(72)  # Fallback

        # Pad to 256D
        input_256 = np.zeros(256)
        input_256[:min(len(input_vec), 256)] = input_vec[:min(len(input_vec), 256)]

        # Process through CTM
        result = ctm.process_ticks(input_256, n_ticks=25)

        # Detect anomalies (low certainty regions)
        anomalies = []
        if result["final_certainty"] < 0.5:
            anomalies.append("Low overall certainty")

        return json.dumps({
            "status": "success",
            # Take a true 16x16 subset (the previous [:16][:16] only sliced rows twice)
            "coherent_features": [row[:16] for row in result["final_sync_matrix"][:16]],
            "certainty": result["final_certainty"],
            "anomalies": anomalies,
            "ticks_used": result["ticks_used"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})

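# Example request payload (illustrative; either key is accepted by sense_snn):
#   {"vector_72d": [0.1, 0.2, 0.3]}                   # up to 72 floats, padded to 256 internally
#   {"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}  # values are read in dict order
# The response carries "certainty", "anomalies", "ticks_used" and a 16x16 slice of the
# synchronization matrix under "coherent_features".
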
def reason_hypergraph(context_json: str) -> str:
    """
    /reason_hypergraph - Reason about hypergraph, propose edges
    Input: Node features + existing edges
    Output: Proposed new edges + certainty
    """
    try:
        data = json.loads(context_json)

        node_features = np.array(data.get("node_features", [[0]*16]*8))
        existing_edges = data.get("existing_edges", [])
        n_ticks = data.get("ticks", 50)

        # Flatten node features for CTM input
        input_vec = node_features.flatten()

        # Process through CTM with more ticks for reasoning
        result = ctm.process_ticks(input_vec, n_ticks=n_ticks)

        # Extract proposed edges from sync matrix (S_ij > 0.8)
        sync = np.array(result["best_sync_matrix"])
        proposed_edges = []

        n_nodes = min(len(node_features), sync.shape[0])
        for i in range(n_nodes):
            for j in range(i+1, n_nodes):
                sync_ij = sync[i, j]
                if sync_ij > 0.8:
                    # Check if edge already exists
                    edge_exists = any(
                        (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
                        for e in existing_edges
                    )
                    if not edge_exists:
                        proposed_edges.append([i, j, float(sync_ij)])

        return json.dumps({
            "status": "success",
            "proposed_edges": proposed_edges,
            "certainty": result["max_certainty"],
            "best_tick": result["best_tick"],
            "ticks_used": result["ticks_used"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})

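# Example context payload (illustrative): node_features is a list of per-node feature
# vectors and existing_edges a list of [i, j] pairs already present in the hypergraph.
#   {"node_features": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
#    "existing_edges": [[0, 1]],
#    "ticks": 50}
# Each proposed edge in the response is a triple [i, j, sync_strength] with
# sync_strength > 0.8, skipping pairs already listed in existing_edges.
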
def validate_physics_endpoint(physics_json: str) -> str:
    """
    /validate_physics - Validate trajectory against 5 physics losses
    """
    try:
        data = json.loads(physics_json)
        trajectory = data.get("trajectory", [0.0])
        params = data.get("physics_params", {})

        result = validate_physics(trajectory, params)
        result["status"] = "success"

        return json.dumps(result, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})

def dream_endpoint(dream_json: str) -> str:
    """
    /dream - Offline consolidation with T=500+ ticks
    Input: Hypergraph snapshot
    Output: Discovered patterns + new edges
    """
    try:
        data = json.loads(dream_json)

        snapshot = data.get("hypergraph_snapshot", {})
        n_ticks = data.get("ticks", 500)

        # Extract features from snapshot
        nodes = snapshot.get("nodes", [])
        if nodes:
            input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()
        else:
            input_vec = np.random.randn(256)  # Random dream if no nodes

        # Dream: run CTM with many ticks and no external input after initial
        result = ctm.process_ticks(input_vec, n_ticks=min(n_ticks, 100))  # Cap at 100 for CPU

        # Analyze sync evolution to find patterns
        sync = np.array(result["final_sync_matrix"])

        # Find strong sync pairs (new edges)
        new_edges = []
        n = min(len(nodes), sync.shape[0]) if nodes else 16
        for i in range(n):
            for j in range(i+1, n):
                if sync[i, j] > 0.85:
                    new_edges.append([i, j, float(sync[i, j])])

        # Find weak sync pairs (edges to prune)
        pruned_edges = []
        for i in range(n):
            for j in range(i+1, n):
                if sync[i, j] < 0.1:
                    pruned_edges.append([i, j])

        return json.dumps({
            "status": "success",
            "discovered_patterns": len(new_edges),
            "new_edges": new_edges[:10],  # Top 10
            "pruned_edges": pruned_edges[:10],  # Top 10
            "consolidation_certainty": result["max_certainty"],
            "ticks_used": result["ticks_used"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})

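# Example snapshot payload (illustrative, written as Python-style shorthand): each node
# carries a 16-dimensional feature vector; requested ticks above 100 are capped on this
# CPU-only Space.
#   {"hypergraph_snapshot": {"nodes": [{"features": [0.1]*16}, {"features": [0.2]*16}]},
#    "ticks": 500}
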
def calibrate_stdp_endpoint(stdp_json: str) -> str:
    """
    /calibrate_stdp - Suggest STDP weight adjustments from sync
    """
    try:
        data = json.loads(stdp_json)

        current_weights = np.array(data.get("current_weights", [1.0]*16))
        node_features = np.array(data.get("node_features", [[0]*16]*8))

        # Process to get sync matrix
        input_vec = node_features.flatten()
        result = ctm.process_ticks(input_vec, n_ticks=25)

        sync = np.array(result["final_sync_matrix"])

        # Suggest weight adjustments based on sync patterns:
        # each neuron's average synchronization with the others scales its weight
        suggested = np.zeros(16)
        for i in range(16):
            # Average sync of neuron i with others
            avg_sync = np.mean(sync[i, :])
            # Scale current weight by sync
            suggested[i] = current_weights[i] * (0.5 + avg_sync)

        return json.dumps({
            "status": "success",
            "suggested_weights": suggested.tolist(),
            "weight_changes": (suggested - current_weights).tolist(),
            "confidence": result["final_certainty"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})

def health_check() -> str:
    """Health check for the CTM server"""
    return json.dumps({
        "status": "healthy",
        "model": "CTM Nervous System v1.0",
        "d_model": ctm.d_model,
        "memory_length": ctm.memory_length,
        "default_ticks": ctm.n_ticks,
        "endpoints": [
            "/sense_snn",
            "/reason_hypergraph",
            "/validate_physics",
            "/dream",
            "/calibrate_stdp"
        ]
    }, indent=2)

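# Client-side sketch (illustrative, not used by this module): because each button below
# registers an api_name, the endpoints can also be called remotely with gradio_client.
# The Space id here is a placeholder.
#
#   from gradio_client import Client
#   client = Client("user/ctm-nervous-system")   # placeholder Space id
#   raw = client.predict('{"dendrites": {"d1": 0.1}}', api_name="/sense_snn")
#   response = json.loads(raw)
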
# ============================================================================
# GRADIO INTERFACE
# ============================================================================

with gr.Blocks(title="CTM Nervous System") as demo:
    gr.Markdown("""
    # 🧬 CTM Nervous System
    **Continuous Thought Machine for Hypergraph Maintenance**

    Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI

    ---

    ## Endpoints
    - **/sense_snn**: Process 72D SNN input
    - **/reason_hypergraph**: Reason about context, propose edges
    - **/validate_physics**: Validate against 5 physics losses
    - **/dream**: Offline consolidation (T=500+)
    - **/calibrate_stdp**: Suggest STDP weight adjustments
    """)

    with gr.Tabs():
        with gr.Tab("🔌 /sense_snn"):
            gr.Markdown("Process 72D SNN input vector")
            snn_input = gr.Textbox(
                label="SNN JSON Input",
                value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}}',
                lines=5
            )
            snn_output = gr.Textbox(label="Output", lines=10)
            snn_btn = gr.Button("Process", variant="primary")
            snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")

        with gr.Tab("🧠 /reason_hypergraph"):
            gr.Markdown("Reason about hypergraph context")
            reason_input = gr.Textbox(
                label="Context JSON",
                value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
                lines=5
            )
            reason_output = gr.Textbox(label="Output", lines=10)
            reason_btn = gr.Button("Reason", variant="primary")
            reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")

        with gr.Tab("⚡ /validate_physics"):
            gr.Markdown("Validate against 5 physics losses")
            physics_input = gr.Textbox(
                label="Physics JSON",
                value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
                lines=5
            )
            physics_output = gr.Textbox(label="Output", lines=10)
            physics_btn = gr.Button("Validate", variant="primary")
            physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")

        with gr.Tab("💤 /dream"):
            gr.Markdown("Offline consolidation")
            dream_input = gr.Textbox(
                label="Dream JSON",
                value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
                lines=5
            )
            dream_output = gr.Textbox(label="Output", lines=10)
            dream_btn = gr.Button("Dream", variant="primary")
            dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")

        with gr.Tab("🔧 /calibrate_stdp"):
            gr.Markdown("Calibrate STDP weights")
            stdp_input = gr.Textbox(
                label="STDP JSON",
                value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
                lines=5
            )
            stdp_output = gr.Textbox(label="Output", lines=10)
            stdp_btn = gr.Button("Calibrate", variant="primary")
            stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")

        with gr.Tab("❤️ Health"):
            health_output = gr.Textbox(label="Health Status", lines=10)
            health_btn = gr.Button("Check Health", variant="secondary")
            health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")

    gr.Markdown("""
    ---
    **Architecture**: CTM as Nervous System → Hypergraph as Thought

    **Training**: Min-Loss + Max-Certainty + Physics Regularization
    """)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
requirements.txt
CHANGED
@@ -1,2 +1,21 @@
# Requirements for CTM Nervous System v2.0 (Full PyTorch)
# =========================================================
# Upgraded to use full ContinuousThoughtMachine implementation

# Core Framework
gradio>=5.0.0
numpy

# PyTorch (CPU-only for HuggingFace free tier)
# Use torch-cpu to minimize memory footprint
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.0.0

# For checkpoint saving
safetensors

# For potential model weights download
huggingface_hub

# HTTP client for Brain server integration
requests
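# Installation sketch (an illustrative addition, assuming a plain CPU-only machine):
#   pip install -r requirements.txt
# pip honors the --extra-index-url line above, so the CPU build of torch should be selected.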
requirements_v1.txt
ADDED
@@ -0,0 +1,2 @@
gradio>=5.0.0
numpy
utils/dendrite_losses.py
ADDED
@@ -0,0 +1,228 @@
"""
Dendrite Regulation Loss Function for CTM
==========================================
Adapted from Sakana AI CTM losses for ART-17 dendrite regulation task.

This loss combines:
1. MSE between predicted and target dendrite adjustments
2. Physics loss penalty (from SNN Omega-21)
3. Certainty-weighted tick selection (CTM innovation)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def dendrite_regulation_loss(
    predictions: torch.Tensor,   # (B, 16, T) - predictions per tick
    certainties: torch.Tensor,   # (B, 2, T) - [entropy, 1-entropy]
    targets: torch.Tensor,       # (B, 16) - target deltas
    physics_loss: torch.Tensor,  # (B,) - from SNN physics validation
    use_most_certain: bool = True,
    physics_weight: float = 0.1
) -> tuple:
    """
    CTM Loss adapted for ART-17 dendrite regulation.

    Combines:
    1. MSE between prediction and target at min-loss tick
    2. MSE between prediction and target at max-certainty tick
    3. Physics loss penalty

    This dual-tick selection is a key CTM innovation that allows
    the model to "think" until confident.

    Args:
        predictions: Model predictions of shape (B, 16, T)
            - B: batch size
            - 16: dendrite dimensions
            - T: number of internal ticks
        certainties: Certainty values of shape (B, 2, T)
            - Second dim: [normalized_entropy, 1-normalized_entropy]
        targets: Ground truth deltas of shape (B, 16)
        physics_loss: Physics validation loss of shape (B,)
        use_most_certain: If True, also consider max-certainty tick
        physics_weight: Weight for physics loss penalty

    Returns:
        total_loss: Combined loss value
        loss_index: Indices of selected ticks (B,)
    """
    B, D, T = predictions.shape  # B, 16, 50

    # Expand targets to all ticks: (B, 16) -> (B, 16, T)
    targets_exp = targets.unsqueeze(-1).expand(-1, -1, T)

    # MSE loss per tick: (B, 16, T) -> (B, T) after mean over dendrites
    mse_per_tick = F.mse_loss(predictions, targets_exp, reduction='none')
    mse_per_tick = mse_per_tick.mean(dim=1)  # Average over 16 dendrites

    # Tick with minimum loss
    loss_index_1 = mse_per_tick.argmin(dim=1)  # (B,)

    # Tick with maximum certainty (1 - normalized_entropy)
    loss_index_2 = certainties[:, 1, :].argmax(dim=1)  # (B,)

    if not use_most_certain:
        # Fall back to final tick
        loss_index_2 = torch.full_like(loss_index_2, T - 1)

    # Select losses at chosen ticks
    batch_idx = torch.arange(B, device=predictions.device)
    loss_min = mse_per_tick[batch_idx, loss_index_1].mean()
    loss_certain = mse_per_tick[batch_idx, loss_index_2].mean()

    # Combined MSE loss (average of min-loss and max-certainty)
    mse_loss = (loss_min + loss_certain) / 2

    # Physics penalty
    physics_penalty = physics_loss.mean() * physics_weight

    # Total loss
    total_loss = mse_loss + physics_penalty

    return total_loss, loss_index_2


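# Shape walk-through (illustrative): with B=2 and T=50, predictions is (2, 16, 50) and
# targets is (2, 16). targets_exp broadcasts to (2, 16, 50), mse_per_tick reduces to
# (2, 50), and the argmin/argmax selections each yield one tick index per sample, so
# the returned loss_index_2 has shape (2,).

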
def hypergraph_edge_loss(
    sync_matrix: torch.Tensor,   # (B, N, N) - synchronization matrix
    edge_labels: torch.Tensor,   # (B, N, N) - binary edge labels
    certainties: torch.Tensor,   # (B, 2, T)
    use_most_certain: bool = True
) -> tuple:
    """
    Loss for predicting hypergraph edges from synchronization matrix.

    The CTM's sync matrix S = Z·Z^T captures neural synchronization,
    which we use to predict edges in the hypergraph.

    Args:
        sync_matrix: Predicted sync values (B, N, N)
        edge_labels: Ground truth edges (B, N, N), binary
        certainties: Certainty values (B, 2, T)
        use_most_certain: Whether to weight by certainty

    Returns:
        loss: BCE loss for edge prediction
        certainty: Average certainty at best tick
    """
    # Binary cross entropy for edge prediction
    bce_loss = F.binary_cross_entropy_with_logits(
        sync_matrix,
        edge_labels.float(),
        reduction='mean'
    )

    # Weight by certainty if enabled
    if use_most_certain and certainties is not None:
        max_certainty = certainties[:, 1, :].max(dim=-1)[0].mean()
        # Higher certainty -> lower loss weight (model is confident)
        certainty_weight = 2.0 - max_certainty
        loss = bce_loss * certainty_weight
    else:
        loss = bce_loss
        max_certainty = torch.tensor(0.5)

    return loss, max_certainty


def combined_art17_loss(
    predictions: torch.Tensor,
    certainties: torch.Tensor,
    sync_matrix: torch.Tensor,
    dendrite_targets: torch.Tensor,
    edge_labels: torch.Tensor,
    physics_loss: torch.Tensor,
    dendrite_weight: float = 1.0,
    edge_weight: float = 0.5,
    physics_weight: float = 0.1
) -> dict:
    """
    Combined loss for ART-17 CTM training.

    Combines:
    1. Dendrite regulation loss (primary task)
    2. Hypergraph edge prediction loss (auxiliary)
    3. Physics constraint penalty (regularization)

    Args:
        predictions: Dendrite predictions (B, 16, T)
        certainties: Certainty values (B, 2, T)
        sync_matrix: Sync matrix for edge prediction (B, N, N)
        dendrite_targets: Target deltas (B, 16)
        edge_labels: Edge ground truth (B, N, N)
        physics_loss: Physics validation (B,)
        dendrite_weight: Weight for dendrite loss
        edge_weight: Weight for edge loss
        physics_weight: Weight for physics penalty

    Returns:
        Dict with total_loss and component losses
    """
    # Dendrite regulation loss
    dend_loss, best_tick = dendrite_regulation_loss(
        predictions, certainties, dendrite_targets,
        physics_loss, physics_weight=0.0  # Physics handled separately
    )

    # Edge prediction loss
    edge_loss, certainty = hypergraph_edge_loss(
        sync_matrix, edge_labels, certainties
    )

    # Physics penalty
    phys_penalty = physics_loss.mean() * physics_weight

    # Total weighted loss
    total_loss = (
        dendrite_weight * dend_loss +
        edge_weight * edge_loss +
        phys_penalty
    )

    return {
        "total_loss": total_loss,
        "dendrite_loss": dend_loss,
        "edge_loss": edge_loss,
        "physics_penalty": phys_penalty,
        "best_tick": best_tick,
        "certainty": certainty
    }


# Keep original loss functions for compatibility
def compute_ctc_loss(predictions, targets, blank_label=0):
    """CTC loss for sequence tasks (original from Sakana)."""
    batch_size, num_classes, prediction_length = predictions.shape
    _, target_length = targets.shape

    log_probs = F.log_softmax(predictions, dim=1)
    log_probs = log_probs.permute(2, 0, 1)

    input_lengths = torch.full(size=(batch_size,), fill_value=prediction_length, dtype=torch.long)
    target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)

    ctc_loss = torch.nn.CTCLoss(blank=blank_label, reduction='mean')
    concatenated_targets = torch.cat(list(targets))
    loss = ctc_loss(log_probs, concatenated_targets, input_lengths, target_lengths)

    return loss


def image_classification_loss(predictions, certainties, targets, use_most_certain=True):
    """Image classification loss (original from Sakana)."""
    targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)
    losses = nn.CrossEntropyLoss(reduction='none')(predictions, targets_expanded)

    loss_index_1 = losses.argmin(dim=1)
    loss_index_2 = certainties[:, 1].argmax(-1)
    if not use_most_certain:
        loss_index_2[:] = -1

    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
    loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
    loss_selected = losses[batch_indexer, loss_index_2].mean()

    loss = (loss_minimum_ce + loss_selected) / 2
    return loss, loss_index_2