# gansta / app.py — Hugging Face Space entry point
# Deploy: "CTM Codebase bypass FUSE 503" (commit ed89628)
"""
CTM Nervous System Server v2.0 - Full PyTorch Implementation
=============================================================
Continuous Thought Machine for ART-17 Hypergraph Coherence Generation
PURPOSE (from skills):
1. REGULACIÓN: Calibrar pesos STDP de las 16 dendritas
2. COHERENCIA: Generar hipergrafos deterministas
3. RAZONAMIENTO: Motor de inferencia activa (internal ticks)
4. SINCRONIZACIÓN: Representación via Neural Synchronization
TRAINING STRATEGY:
- Progressive online learning with use
- Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding
- Automatic checkpoint saving
Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI)
Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation
"""
import gradio as gr
import numpy as np
import json
import os
from typing import List, Dict, Any, Optional
from datetime import datetime
from utils.bunker_client import BunkerClient
# ============================================================================
# PYTORCH IMPORTS WITH FALLBACK
# ============================================================================
# PyTorch is optional: when it is missing the server degrades to a pure
# NumPy implementation of the CTM (see CTM_ART17._init_simplified_ctm).
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except ImportError:
    TORCH_AVAILABLE = False
    DEVICE = "cpu"
    print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
else:
    TORCH_AVAILABLE = True
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🔧 PyTorch available. Device: {DEVICE}")
# ============================================================================
# FULL CTM IMPORT (with fallback to simplified)
# ============================================================================
# CTM_FULL is True only when PyTorch is present AND the project-local model
# modules can be imported; otherwise the NumPy fallback is used everywhere.
CTM_FULL = False
if TORCH_AVAILABLE:
    try:
        from models.ctm import ContinuousThoughtMachine
        from models.modules import SynapseUNET, SuperLinear
        from utils.losses import image_classification_loss
    except ImportError as e:
        print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
    else:
        CTM_FULL = True
        print("✅ Full CTM model loaded from models/ctm.py")
# ============================================================================
# CONFIGURATION FOR ART-17 INTEGRATION (v3.0)
# ============================================================================
CONFIG = {
    # --- CTM architecture (must match the ART-17 wiring) ---
    "iterations": 50,            # maximum number of internal "thought" ticks T
    "d_model": 256,              # latent (neuron) dimension
    "d_input": 72,               # input vector size coming from the SNN
    "memory_length": 16,         # per-neuron history length (16 dendrites)
    "n_synch_out": 32,           # neurons used for the output sync matrix
    "n_synch_action": 16,        # neurons used for the action sync matrix
    "out_dims": 16,              # output: one adjustment per dendrite
    # --- v3.0 improvements ---
    "adaptive_halting": True,    # allow early stopping of the tick loop
    "certainty_threshold": 0.85, # halt once certainty exceeds this value
    "sync_decay_alpha": 0.9,     # S_new = α*S_old + (1-α)*S_current
    "use_backbone": True,        # apply the 72D -> d_model embedding first
    # --- training ---
    "learning_rate": 1e-4,
    "weight_decay": 1e-5,
    "checkpoint_dir": "checkpoints",
    "auto_save_every": 100,      # checkpoint every N forward passes
    # --- integration ---
    "brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
    # --- physics validation thresholds (see validate_physics) ---
    "physics_thresholds": {
        "P_max": 1000.0,
        "v_max": 100.0,
        "T_dew": 15.0,
        "T_amb": 25.0
    }
}
# ============================================================================
# BACKBONE 72D (v3.0 - Transform input before CTM)
# ============================================================================
class Backbone72D(nn.Module if TORCH_AVAILABLE else object):
    """
    Embed the raw 72-D SNN vector into the CTM latent space (d_model).

    Paper insight: raw input needs a proper embedding for the CTM to work
    well. Without PyTorch this class degrades to an inert placeholder.
    """

    def __init__(self, d_input=72, d_model=256):
        # No-op when PyTorch is unavailable (class then derives from object).
        if not TORCH_AVAILABLE:
            return
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_input, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128, d_model),
            nn.LayerNorm(d_model)
        )

    def forward(self, x):
        """Map x: [B, 72] -> [B, d_model]."""
        return self.net(x)
# ============================================================================
# FULL CTM WRAPPER FOR ART-17
# ============================================================================
class CTM_ART17:
"""
Full Continuous Thought Machine adapted for ART-17.
Key mechanisms from paper:
1. NLMs (Neuron-Level Models) - Each neuron processes its own history
2. Neural Synchronization - Representation is S = Z·Z^T
3. Adaptive Compute - Can halt early when confident
Purpose in ART-17:
- Regulate 16 dendrite STDP weights
- Generate coherent hypergraph edges
- Serve as "nervous system" for the whole system
"""
def __init__(self, config: dict):
self.config = config
self.forward_count = 0
self.training_samples = []
self.bunker = BunkerClient(buffer_dir=config.get("buffer_dir", "_ctm_buffer"))
if CTM_FULL and TORCH_AVAILABLE:
self._init_full_ctm()
else:
self._init_simplified_ctm()
def _init_full_ctm(self):
"""Initialize full PyTorch CTM model."""
self.model = ContinuousThoughtMachine(
iterations=self.config["iterations"],
d_model=self.config["d_model"],
d_input=self.config["d_input"],
heads=4,
n_synch_out=self.config["n_synch_out"],
n_synch_action=self.config["n_synch_action"],
synapse_depth=2,
memory_length=self.config["memory_length"],
deep_nlms=True,
memory_hidden_dims=32,
do_layernorm_nlm=False,
backbone_type='none',
positional_embedding_type='none',
out_dims=self.config["out_dims"],
prediction_reshaper=[self.config["out_dims"]],
dropout=0.1,
neuron_select_type='random-pairing'
).to(DEVICE)
# Dummy forward to initialize lazy modules
with torch.no_grad():
dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
dummy = dummy.unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
try:
_ = self.model(dummy)
except Exception as e:
print(f"⚠️ Lazy init failed: {e}")
self.model.eval()
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config["learning_rate"],
weight_decay=self.config["weight_decay"]
)
self.is_full = True
param_count = sum(p.numel() for p in self.model.parameters())
print(f"✅ Full CTM initialized: {param_count:,} parameters")
# Try to load existing checkpoint
self._load_checkpoint()
def _init_simplified_ctm(self):
"""Initialize simplified NumPy CTM (fallback)."""
self.d_model = self.config["d_model"]
self.memory_length = self.config["memory_length"]
self.n_ticks = self.config["iterations"]
# State traces
self.state_trace = np.zeros((self.d_model, self.memory_length))
self.activated_state = np.random.randn(self.d_model) * 0.1
# NLM weights (simplified: 16 groups for 16 dendrites)
self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
self.is_full = False
print("✅ Simplified CTM initialized (NumPy fallback)")
def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
"""
Process input through CTM.
Args:
input_72d: 72D input from SNN
n_ticks: Override number of internal ticks
Returns:
Dict with predictions, certainty, sync matrix
"""
n_ticks = n_ticks or self.config["iterations"]
self.forward_count += 1
if self.is_full:
return self._forward_full(input_72d, n_ticks)
else:
return self._forward_simplified(input_72d, n_ticks)
def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
"""Forward pass with full PyTorch CTM."""
# Prepare tensor
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
if len(x.shape) == 1:
x = x.unsqueeze(0) # Add batch dim
x = x.unsqueeze(-1).unsqueeze(-1) # [B, 72, 1, 1]
with torch.no_grad():
predictions, certainties, sync_out = self.model(x)
# Extract results
final_pred = predictions[:, :, -1].cpu().numpy()[0] # Last tick [16]
final_cert = certainties[:, 1, -1].cpu().numpy()[0] # 1-entropy
# Find tick with highest certainty
best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
# Sync matrix for hypergraph edge proposals
sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
return {
"predictions": final_pred.tolist(),
"best_predictions": best_pred.tolist(),
"certainty": float(final_cert),
"best_tick": int(best_tick_idx),
"ticks_used": n_ticks,
"sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
"model": "ContinuousThoughtMachine (Full PyTorch)"
}
def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
"""
Forward pass with simplified NumPy CTM (v3.0).
v3.0 Features:
1. Backbone transformation (72D -> 256D)
2. Sync Decay (S = α*S_prev + (1-α)*S_current)
3. Adaptive Halting (stop if certainty > threshold)
"""
# v3.0: Backbone transformation (simple linear projection)
if self.config.get("use_backbone", True):
# Learned transformation: 72D -> 256D
input_256 = np.zeros(self.d_model)
# Simple linear projection + normalization (simulates Backbone72D)
projected = np.tanh(input_72d[:72] * np.random.randn(72) * 0.1) if len(input_72d) >= 72 else input_72d
input_256[:min(len(projected), self.d_model)] = projected[:min(len(projected), self.d_model)]
else:
input_256 = np.zeros(self.d_model)
input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
# v3.0: Sync Decay initialization
alpha = self.config.get("sync_decay_alpha", 0.9)
sync_matrix_prev = np.zeros((self.d_model, self.d_model))
# v3.0: Adaptive halting config
adaptive_halting = self.config.get("adaptive_halting", True)
certainty_threshold = self.config.get("certainty_threshold", 0.85)
certainties = []
all_predictions = []
ticks_actually_used = 0
for t in range(n_ticks):
ticks_actually_used = t + 1
# Synapse update (simplified global mixing)
combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
# Update trace (memory)
self.state_trace = np.roll(self.state_trace, -1, axis=1)
self.state_trace[:, -1] = pre_activation
# NLM processing (simplified: 16 groups for 16 dendrites)
post_activation = np.zeros(self.d_model)
group_size = self.d_model // 16
for g in range(16):
start = g * group_size
end = start + group_size
group_trace = self.state_trace[start:end, :]
group_output = np.mean(group_trace @ self.nlm_weights[g])
post_activation[start:end] = np.tanh(group_output)
self.activated_state = post_activation
# v3.0: Sync Decay - S = α*S_prev + (1-α)*Z·Z^T
z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
sync_current = np.outer(z_norm, z_norm)
sync_matrix = alpha * sync_matrix_prev + (1 - alpha) * sync_current
sync_matrix_prev = sync_matrix
# Store predictions at this tick
all_predictions.append(self.activated_state[:16].copy())
# Compute certainty
probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
probs = np.clip(probs, 1e-10, 1.0)
entropy = -np.sum(probs * np.log(probs))
max_entropy = np.log(len(probs))
certainty = float(1.0 - entropy / (max_entropy + 1e-8))
certainties.append(certainty)
# v3.0: Adaptive Halting - stop early if confident enough
if adaptive_halting and certainty > certainty_threshold:
break
# Best tick selection
best_tick_idx = int(np.argmax(certainties))
best_predictions = all_predictions[best_tick_idx].tolist()
return {
"predictions": self.activated_state[:16].tolist(),
"best_predictions": best_predictions,
"certainty": certainties[-1],
"best_tick": best_tick_idx,
"ticks_used": ticks_actually_used, # v3.0: Actual ticks, may be < n_ticks
"max_ticks": n_ticks,
"halted_early": ticks_actually_used < n_ticks, # v3.0: Flag
"sync_matrix": sync_matrix[:16, :16].tolist(),
"model": "SimplifiedCTM v3.0 (NumPy + AdaptiveHalt + SyncDecay)"
}
def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
physics_loss: float = 0.0) -> Dict:
"""
Online training step.
Args:
input_72d: Input from SNN
target_16d: Target dendrite adjustments (ground truth)
physics_loss: Current physics loss for weighting
Returns:
Dict with loss and gradient info
"""
if not self.is_full or not TORCH_AVAILABLE:
return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
self.model.train()
# Prepare tensors
x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, 72, 1, 1]
y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
# Forward
predictions, certainties, _ = self.model(x)
# Loss: dendrite_regulation_loss
# predictions: [B, 16, T], y: [B, 16]
y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1)) # [B, 16, T]
mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1) # [B, T]
# Select best tick (min loss) and most certain tick
loss_min_idx = mse_per_tick.argmin(dim=1) # [B]
loss_cert_idx = certainties[:, 1, :].argmax(dim=1) # [B]
batch_idx = torch.arange(predictions.size(0), device=DEVICE)
loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
# Combined loss with physics penalty
mse_loss = (loss_min + loss_cert) / 2
physics_penalty = physics_loss * 0.1
total_loss = mse_loss + physics_penalty
# Backward
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.model.eval()
# Auto-save checkpoint
if self.forward_count % self.config["auto_save_every"] == 0:
self._save_checkpoint()
return {
"status": "trained",
"loss": float(total_loss.item()),
"mse_loss": float(mse_loss.item()),
"physics_penalty": float(physics_penalty),
"best_tick": int(loss_cert_idx[0].item())
}
def _save_checkpoint(self):
"""Save model checkpoint."""
if not self.is_full:
return
os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
torch.save({
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"forward_count": self.forward_count,
"timestamp": datetime.now().isoformat()
}, path)
print(f"💾 Checkpoint saved: {path}")
# Upload to Bunker (Async/Fail-Safe)
self.bunker.save_file(path, remote_folder="ctm_backups")
def _load_checkpoint(self):
"""Load model checkpoint if exists."""
path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
if os.path.exists(path):
try:
checkpoint = torch.load(path, map_location=DEVICE)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.forward_count = checkpoint.get("forward_count", 0)
print(f"✅ Checkpoint loaded: {path}")
except Exception as e:
print(f"⚠️ Could not load checkpoint: {e}")
# ============================================================================
# GLOBAL CTM INSTANCE
# ============================================================================
# Module-level singleton shared by every Gradio endpoint below; constructed
# once at import time (may print init messages and attempt a checkpoint load).
ctm = CTM_ART17(CONFIG)
# ============================================================================
# PHYSICS VALIDATION (from SNN Omega-21)
# ============================================================================
def validate_physics(trajectory: List[float], params: Dict) -> Dict:
    """Validate a trajectory against the 5 physics losses from SNN Omega-21.

    Each loss is a squared hinge penalty; thresholds come from `params`,
    falling back to CONFIG["physics_thresholds"]. A trajectory is "valid"
    when the total loss is below 0.01.
    """
    traj = np.array(trajectory)
    thresholds = CONFIG["physics_thresholds"]

    # L_energy: total power (sum of squares) must stay under P_max.
    P_max = params.get("P_max", thresholds["P_max"])
    L_energy = float(max(0, np.sum(traj ** 2) - P_max) ** 2)

    # L_thermo: ambient temperature must stay above the dew point.
    T_dew = params.get("T_dew", thresholds["T_dew"])
    T_amb = params.get("T_amb", thresholds["T_amb"])
    L_thermo = float(max(0, T_dew - T_amb) ** 2)

    # L_causal: step-to-step velocity must respect v_max.
    velocity = np.diff(traj) if len(traj) > 1 else np.array([0])
    v_max = params.get("v_max", thresholds["v_max"])
    L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))

    # L_conserv: incoming and outgoing flux must balance.
    L_conserv = float((params.get("flux_in", 1.0) - params.get("flux_out", 1.0)) ** 2)

    # L_entropy: 2nd law — entropy change must be non-negative.
    L_entropy = float(max(0, -params.get("entropy_change", 0.1)) ** 2)

    L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
    return {
        "valid": L_total < 0.01,
        "L_energy": L_energy,
        "L_thermo": L_thermo,
        "L_causal": L_causal,
        "L_conserv": L_conserv,
        "L_entropy": L_entropy,
        "L_total": L_total
    }
# ============================================================================
# ENDPOINT FUNCTIONS
# ============================================================================
def sense_snn(snn_json: str) -> str:
    """
    /sense_snn - Process 72D SNN input through the CTM.

    Input: JSON with either a "vector_72d" list or a "dendrites" mapping
    (random input is used when neither is present). Output: JSON string with
    coherent features, certainty, best tick, and any anomaly warnings.
    """
    try:
        data = json.loads(snn_json)
        # Pick the input source in priority order.
        if "vector_72d" in data:
            input_vec = np.array(data["vector_72d"])
        elif "dendrites" in data:
            input_vec = np.array(list(data["dendrites"].values()))
        else:
            input_vec = np.random.randn(72)
        # Zero-pad short vectors up to the 72-D the CTM expects.
        shortfall = 72 - len(input_vec)
        if shortfall > 0:
            input_vec = np.pad(input_vec, (0, shortfall))
        result = ctm.forward(input_vec[:72], data.get("ticks", 25))
        # Flag low-confidence outputs for the caller.
        anomalies = []
        if result["certainty"] < 0.5:
            anomalies.append("Low overall certainty - consider retraining")
        response = {
            "status": "success",
            "coherent_features": result["predictions"],
            "certainty": result["certainty"],
            "best_tick": result["best_tick"],
            "anomalies": anomalies,
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }
        return json.dumps(response, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def reason_hypergraph(context_json: str) -> str:
    """
    /reason_hypergraph - Reason about hypergraph context, propose edges.

    Runs the flattened node features through the CTM and mines its
    synchronization matrix: strongly correlated node pairs (S_ij > 0.7)
    that are not already connected become proposed hyperedges.
    """
    try:
        data = json.loads(context_json)
        node_features = np.array(data.get("node_features", [[0] * 16] * 8))
        existing_edges = data.get("existing_edges", [])
        n_ticks = data.get("ticks", 50)

        # Flatten node features into a zero-padded 72-D CTM input.
        flat = node_features.flatten()
        input_vec = np.zeros(72)
        take = min(len(flat), 72)
        input_vec[:take] = flat[:take]

        # Extra ticks give the CTM more "thinking time" for reasoning.
        result = ctm.forward(input_vec, n_ticks)

        proposed_edges = []
        raw_sync = result["sync_matrix"]
        if raw_sync is not None:
            sync = np.array(raw_sync)
            # Only a 2-D sync matrix supports pairwise edge extraction.
            if sync.ndim >= 2:
                n_nodes = min(len(node_features), sync.shape[0])
                for i in range(n_nodes):
                    for j in range(i + 1, min(n_nodes, sync.shape[1])):
                        strength = sync[i, j]
                        if strength > 0.7:  # correlation threshold
                            already = any(
                                (e[0] == i and e[1] == j) or (e[0] == j and e[1] == i)
                                for e in existing_edges
                            )
                            if not already:
                                proposed_edges.append([i, j, float(strength)])

        return json.dumps({
            "status": "success",
            "proposed_edges": proposed_edges,
            "certainty": result["certainty"],
            "best_tick": result["best_tick"],
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def validate_physics_endpoint(physics_json: str) -> str:
    """
    /validate_physics - Validate a trajectory against the 5 physics losses.

    Thin JSON wrapper around validate_physics(); returns the loss breakdown
    with a "status" field added.
    """
    try:
        payload = json.loads(physics_json)
        report = validate_physics(
            payload.get("trajectory", [0.0]),
            payload.get("physics_params", {})
        )
        report["status"] = "success"
        return json.dumps(report, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def dream_endpoint(dream_json: str) -> str:
    """
    /dream - Offline consolidation with many ticks.

    Runs a long CTM pass over a hypergraph snapshot, then mines the sync
    matrix: strongly synchronized pairs (> 0.85) become new edges, nearly
    uncorrelated pairs (< 0.1) are flagged for pruning.
    """
    try:
        data = json.loads(dream_json)
        snapshot = data.get("hypergraph_snapshot", {})
        n_ticks = min(data.get("ticks", 100), 100)  # cap at 100 for CPU
        nodes = snapshot.get("nodes", [])
        if nodes:
            input_vec = np.array([node.get("features", [0] * 16) for node in nodes]).flatten()[:72]
            # FIX: pad short inputs — the full CTM requires exactly d_input=72.
            if len(input_vec) < 72:
                input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
        else:
            input_vec = np.random.randn(72)
        # Dream: run the CTM with many ticks.
        result = ctm.forward(input_vec, n_ticks)
        new_edges = []
        pruned_edges = []
        if result["sync_matrix"] is not None:
            sync = np.array(result["sync_matrix"])
            # FIX: only index when sync is 2-D, and bound BOTH axes; the old
            # code assumed a square matrix and could IndexError via the n=16
            # fallback when the sync matrix was smaller.
            if sync.ndim >= 2:
                limit = min(sync.shape[0], sync.shape[1])
                n = min(len(nodes), limit) if nodes else min(16, limit)
                for i in range(n):
                    for j in range(i + 1, n):
                        if sync[i, j] > 0.85:
                            new_edges.append([i, j, float(sync[i, j])])
                        elif sync[i, j] < 0.1:
                            pruned_edges.append([i, j])
        return json.dumps({
            "status": "success",
            "discovered_patterns": len(new_edges),
            "new_edges": new_edges[:10],
            "pruned_edges": pruned_edges[:10],
            "consolidation_certainty": result["certainty"],
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def calibrate_stdp_endpoint(stdp_json: str) -> str:
    """
    /calibrate_stdp - Suggest STDP weight adjustments.

    This is the CORE regulatory function:
      - receives the current 16 dendrite weights,
      - processes the flattened node features through the CTM,
      - returns suggested weight adjustments scaled by CTM certainty.
    """
    try:
        data = json.loads(stdp_json)
        current_weights = np.array(data.get("current_weights", [1.0] * 16))
        node_features = np.array(data.get("node_features", [[0] * 16] * 4))
        # Flatten features for the CTM input.
        input_vec = node_features.flatten()[:72]
        # FIX: pad to exactly 72-D. The default node_features (4x16) only
        # yields 64 values, which would crash the full CTM (d_input=72);
        # this also matches how sense_snn pads its input.
        if len(input_vec) < 72:
            input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))
        result = ctm.forward(input_vec, n_ticks=25)
        # Use the best-tick predictions as weight adjustments, centered at
        # 0.5 and scaled down by confidence for conservative updates.
        predictions = np.array(result["best_predictions"])
        confidence = result["certainty"]
        weight_changes = (predictions - 0.5) * confidence * 0.1
        new_weights = current_weights + weight_changes
        return json.dumps({
            "status": "success",
            "suggested_weights": new_weights.tolist(),
            "weight_changes": weight_changes.tolist(),
            "confidence": confidence,
            "best_tick": result["best_tick"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def regulate_endpoint(regulate_json: str) -> str:
    """
    /regulate - Full feedback loop for ART-17 regulation.

    Combines dendrite state, a slice of the latent representation, the
    physics loss, and the anomaly score into one CTM pass, and returns
    per-dendrite deltas plus an ADJUST/MAINTAIN recommendation.
    """
    try:
        data = json.loads(regulate_json)
        # Inputs from the local system.
        dendrites = np.array(data.get("dendrites", [0.0] * 16))
        latent_256 = np.array(data.get("latent_256", [0.0] * 256))
        physics_loss = data.get("physics_loss", 0.0)
        anomaly_score = data.get("anomaly_score", 0.0)
        # Combine into a 72-D input: 16 dendrites + 56 latent components.
        input_72 = np.concatenate([dendrites, latent_256[:56]])
        # FIX: pad/truncate to exactly 72-D — caller-supplied lists shorter
        # than the defaults previously produced a <72-D vector that crashes
        # the full CTM (d_input=72).
        if len(input_72) < 72:
            input_72 = np.pad(input_72, (0, 72 - len(input_72)))
        input_72 = input_72[:72]
        # Process through the CTM with extended thinking time.
        result = ctm.forward(input_72, n_ticks=50)
        predictions = np.array(result["best_predictions"])
        certainty = result["certainty"]
        # Urgency grows with physics violations and anomalies (capped at 1).
        urgency = min(1.0, physics_loss + anomaly_score)
        regulation_strength = urgency * certainty
        # Conservative per-dendrite adjustments.
        dendrite_deltas = predictions * regulation_strength * 0.05
        # Intervene when the situation is urgent OR the CTM is unsure.
        needs_intervention = urgency > 0.5 or certainty < 0.3
        return json.dumps({
            "status": "success",
            "dendrite_deltas": dendrite_deltas.tolist(),
            "regulation_strength": float(regulation_strength),
            "confidence": certainty,
            "urgency": float(urgency),
            "needs_intervention": needs_intervention,
            "recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
            "best_tick": result["best_tick"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def train_online_endpoint(train_json: str) -> str:
    """
    /train_online - Progressive online training.

    Accepts an (input, target) pair plus the current physics loss, runs one
    optimizer step on the full CTM (skipped on the NumPy fallback), and
    reports the resulting loss breakdown as JSON.
    """
    try:
        payload = json.loads(train_json)
        x = np.array(payload.get("input_72d", [0.0] * 72))
        y = np.array(payload.get("target_16d", [0.0] * 16))
        phys = payload.get("physics_loss", 0.0)
        outcome = ctm.train_step(x, y, phys)
        trained = outcome["status"] == "trained"
        return json.dumps({
            "status": outcome["status"],
            "loss": outcome.get("loss"),
            "mse_loss": outcome.get("mse_loss"),
            "physics_penalty": outcome.get("physics_penalty"),
            "best_tick": outcome.get("best_tick"),
            "forward_count": ctm.forward_count,
            "message": "Training step completed" if trained else outcome.get("reason")
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def health_check() -> str:
    """Return service health plus model/config metadata as a JSON string."""
    mode = "Full PyTorch" if ctm.is_full else "NumPy Fallback"
    payload = {
        "status": "healthy",
        "model": f"CTM Nervous System v2.0 ({mode})",
        "device": DEVICE,
        "d_model": CONFIG["d_model"],
        "iterations": CONFIG["iterations"],
        "memory_length": CONFIG["memory_length"],
        "forward_count": ctm.forward_count,
        # All Gradio-exposed API routes.
        "endpoints": [
            "/sense_snn",
            "/reason_hypergraph",
            "/validate_physics",
            "/dream",
            "/calibrate_stdp",
            "/regulate",
            "/train_online"
        ]
    }
    return json.dumps(payload, indent=2)
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# One tab per API endpoint; each Button.click carries an api_name so the
# route is also callable programmatically via the Gradio client.
with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🧬 CTM Nervous System v2.0
**Continuous Thought Machine for ART-17 Hypergraph Coherence**
Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
---
## Key Innovations
- **NLMs (Neuron-Level Models)**: Each neuron processes its own history
- **Neural Synchronization**: Representation via S = Z·Z^T
- **Adaptive Compute**: Halts when confident
- **Online Training**: Progressive learning with use
---
""")
    with gr.Tabs():
        with gr.Tab("🔌 /sense_snn"):
            gr.Markdown("Process 72D SNN input through CTM")
            snn_input = gr.Textbox(
                label="SNN JSON Input",
                value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
                lines=5
            )
            snn_output = gr.Textbox(label="Output", lines=10)
            snn_btn = gr.Button("Process", variant="primary")
            snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
        with gr.Tab("🧠 /reason_hypergraph"):
            gr.Markdown("Reason about hypergraph context, propose edges")
            reason_input = gr.Textbox(
                label="Context JSON",
                value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
                lines=5
            )
            reason_output = gr.Textbox(label="Output", lines=10)
            reason_btn = gr.Button("Reason", variant="primary")
            reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
        with gr.Tab("⚡ /validate_physics"):
            gr.Markdown("Validate trajectory against 5 physics losses")
            physics_input = gr.Textbox(
                label="Physics JSON",
                value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
                lines=5
            )
            physics_output = gr.Textbox(label="Output", lines=10)
            physics_btn = gr.Button("Validate", variant="primary")
            physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
        with gr.Tab("💤 /dream"):
            gr.Markdown("Offline consolidation - discover patterns")
            dream_input = gr.Textbox(
                label="Dream JSON",
                value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
                lines=5
            )
            dream_output = gr.Textbox(label="Output", lines=10)
            dream_btn = gr.Button("Dream", variant="primary")
            dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
        with gr.Tab("🔧 /calibrate_stdp"):
            gr.Markdown("Calibrate STDP weights (Core regulatory function)")
            stdp_input = gr.Textbox(
                label="STDP JSON",
                value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
                lines=5
            )
            stdp_output = gr.Textbox(label="Output", lines=10)
            stdp_btn = gr.Button("Calibrate", variant="primary")
            stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
        with gr.Tab("🎯 /regulate [NEW]"):
            gr.Markdown("Full feedback loop for ART-17 regulation")
            # NOTE(review): this default contains Python expressions ("[0.5]*16")
            # and is NOT valid JSON — submitting it as-is will return an error
            # from the endpoint; confirm whether a literal example was intended.
            regulate_input = gr.Textbox(
                label="Regulate JSON",
                value='{"dendrites": [0.5]*16, "latent_256": [0.1]*256, "physics_loss": 0.01, "anomaly_score": 0.05}',
                lines=5
            )
            regulate_output = gr.Textbox(label="Output", lines=10)
            regulate_btn = gr.Button("Regulate", variant="primary")
            regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
        with gr.Tab("📚 /train_online [NEW]"):
            gr.Markdown("Progressive online training with experience")
            # NOTE(review): same invalid-JSON default as the /regulate tab —
            # "[0.1]*72" will not parse; confirm intent.
            train_input = gr.Textbox(
                label="Training JSON",
                value='{"input_72d": [0.1]*72, "target_16d": [0.5]*16, "physics_loss": 0.01}',
                lines=5
            )
            train_output = gr.Textbox(label="Output", lines=10)
            train_btn = gr.Button("Train Step", variant="primary")
            train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
        with gr.Tab("❤️ Health"):
            health_output = gr.Textbox(label="Health Status", lines=15)
            health_btn = gr.Button("Check Health", variant="secondary")
            health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
    gr.Markdown("""
---
**Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
**Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
**Training**: Progressive online learning + Physics-Informed Loss
""")
if __name__ == "__main__":
    # Bind on all interfaces so the HF Space proxy can reach the app;
    # show_error surfaces handler tracebacks in the UI.
    demo.launch(show_error=True, server_name="0.0.0.0", server_port=7860)