Spaces:
Paused
Paused
| """ | |
| CTM Nervous System Server v2.0 - Full PyTorch Implementation | |
| ============================================================= | |
| Continuous Thought Machine for ART-17 Hypergraph Coherence Generation | |
| PURPOSE (from skills): | |
| 1. REGULACIÓN: Calibrar pesos STDP de las 16 dendritas | |
| 2. COHERENCIA: Generar hipergrafos deterministas | |
| 3. RAZONAMIENTO: Motor de inferencia activa (internal ticks) | |
| 4. SINCRONIZACIÓN: Representación via Neural Synchronization | |
| TRAINING STRATEGY: | |
| - Progressive online learning with use | |
| - Integrates with Brain server (Qwen + VL-JEPA) for semantic grounding | |
| - Automatic checkpoint saving | |
| Based on: arXiv:2505.05522 (Continuous Thought Machines - Sakana AI) | |
| Adapted for: ART-17 Dendrite Regulation & Hypergraph Generation | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import json | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from datetime import datetime | |
| from utils.bunker_client import BunkerClient | |
# ============================================================================
# PYTORCH IMPORTS WITH FALLBACK
# ============================================================================
# Two-level degradation: without PyTorch we fall back to a NumPy-only CTM
# approximation; with PyTorch but without the project-local models package
# we also fall back (see CTM_FULL below).
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
    # Prefer CUDA when a GPU is visible; everything else runs on CPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🔧 PyTorch available. Device: {DEVICE}")
except ImportError:
    TORCH_AVAILABLE = False
    DEVICE = "cpu"
    print("⚠️ PyTorch not available. Using simplified NumPy fallback.")
# ============================================================================
# FULL CTM IMPORT (with fallback to simplified)
# ============================================================================
if TORCH_AVAILABLE:
    try:
        # Project-local implementation of the Sakana AI Continuous
        # Thought Machine (arXiv:2505.05522).
        from models.ctm import ContinuousThoughtMachine
        from models.modules import SynapseUNET, SuperLinear
        from utils.losses import image_classification_loss
        CTM_FULL = True
        print("✅ Full CTM model loaded from models/ctm.py")
    except ImportError as e:
        CTM_FULL = False
        print(f"⚠️ Could not import full CTM: {e}. Using simplified.")
else:
    CTM_FULL = False
# ============================================================================
# CONFIGURATION FOR ART-17 INTEGRATION (v3.0)
# ============================================================================
# Single source of truth for architecture sizes, training hyperparameters
# and integration endpoints; read via CONFIG[...] / CONFIG.get(...) below.
CONFIG = {
    # CTM Architecture (matching ART-17)
    "iterations": 50,            # T internal ticks (max)
    "d_model": 256,              # Latent dimension
    "d_input": 72,               # Input from SNN (72D)
    "memory_length": 16,         # History length (16 dendrites)
    "n_synch_out": 32,           # Output sync neurons
    "n_synch_action": 16,        # Action sync neurons
    "out_dims": 16,              # Output: 16 dendrite adjustments
    # v3.0 Improvements (used by the NumPy fallback forward pass)
    "adaptive_halting": True,    # Enable early stopping
    "certainty_threshold": 0.85, # Halt if certainty > threshold
    "sync_decay_alpha": 0.9,     # S_new = α*S_old + (1-α)*S_current
    "use_backbone": True,        # Use Backbone72D transformation
    # Training
    "learning_rate": 1e-4,
    "weight_decay": 1e-5,
    "checkpoint_dir": "checkpoints",
    "auto_save_every": 100,      # Save every N forward passes
    # Integration
    "brain_server_url": "https://elliotasdasdasfasas-brain.hf.space",
    # Physics validation thresholds (defaults for validate_physics)
    "physics_thresholds": {
        "P_max": 1000.0,
        "v_max": 100.0,
        "T_dew": 15.0,
        "T_amb": 25.0
    }
}
| # ============================================================================ | |
| # BACKBONE 72D (v3.0 - Transform input before CTM) | |
| # ============================================================================ | |
class Backbone72D(nn.Module if TORCH_AVAILABLE else object):
    """
    Embed the raw 72D SNN vector into the CTM latent space (d_model).

    Paper insight: the CTM performs better when raw inputs go through a
    proper learned embedding rather than being used directly.
    """

    def __init__(self, d_input=72, d_model=256):
        # Without PyTorch the base class degenerates to `object` and no
        # layers are built; callers are expected to skip this path.
        if not TORCH_AVAILABLE:
            return
        super().__init__()
        hidden = 128
        self.net = nn.Sequential(
            nn.Linear(d_input, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
            nn.LayerNorm(d_model),
        )

    def forward(self, x):
        """Map a [B, 72] batch to [B, d_model] embeddings."""
        return self.net(x)
| # ============================================================================ | |
| # FULL CTM WRAPPER FOR ART-17 | |
| # ============================================================================ | |
class CTM_ART17:
    """
    Full Continuous Thought Machine adapted for ART-17.

    Key mechanisms from paper:
    1. NLMs (Neuron-Level Models) - Each neuron processes its own history
    2. Neural Synchronization - Representation is S = Z·Z^T
    3. Adaptive Compute - Can halt early when confident

    Purpose in ART-17:
    - Regulate 16 dendrite STDP weights
    - Generate coherent hypergraph edges
    - Serve as "nervous system" for the whole system
    """

    def __init__(self, config: dict):
        """Build the full PyTorch CTM when available, else the NumPy fallback.

        Args:
            config: Hyperparameter dict (see module-level CONFIG).
        """
        self.config = config
        # Counts forward() calls; also drives the auto-save cadence in
        # train_step (saves every config["auto_save_every"] passes).
        self.forward_count = 0
        # Reserved for collected experience; not consumed in this file.
        self.training_samples = []
        # Best-effort off-site backup client for checkpoints.
        self.bunker = BunkerClient(buffer_dir=config.get("buffer_dir", "_ctm_buffer"))
        if CTM_FULL and TORCH_AVAILABLE:
            self._init_full_ctm()
        else:
            self._init_simplified_ctm()

    def _init_full_ctm(self) -> None:
        """Initialize full PyTorch CTM model."""
        self.model = ContinuousThoughtMachine(
            iterations=self.config["iterations"],
            d_model=self.config["d_model"],
            d_input=self.config["d_input"],
            heads=4,
            n_synch_out=self.config["n_synch_out"],
            n_synch_action=self.config["n_synch_action"],
            synapse_depth=2,
            memory_length=self.config["memory_length"],
            deep_nlms=True,
            memory_hidden_dims=32,
            do_layernorm_nlm=False,
            backbone_type='none',
            positional_embedding_type='none',
            out_dims=self.config["out_dims"],
            prediction_reshaper=[self.config["out_dims"]],
            dropout=0.1,
            neuron_select_type='random-pairing'
        ).to(DEVICE)
        # Dummy forward to initialize lazy modules
        with torch.no_grad():
            dummy = torch.randn(1, self.config["d_input"], device=DEVICE)
            dummy = dummy.unsqueeze(-1).unsqueeze(-1)  # [1, 72, 1, 1]
            try:
                _ = self.model(dummy)
            except Exception as e:
                # Non-fatal: real forward passes may still succeed later.
                print(f"⚠️ Lazy init failed: {e}")
        self.model.eval()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            weight_decay=self.config["weight_decay"]
        )
        self.is_full = True
        param_count = sum(p.numel() for p in self.model.parameters())
        print(f"✅ Full CTM initialized: {param_count:,} parameters")
        # Try to load existing checkpoint
        self._load_checkpoint()

    def _init_simplified_ctm(self) -> None:
        """Initialize simplified NumPy CTM (fallback)."""
        self.d_model = self.config["d_model"]
        self.memory_length = self.config["memory_length"]
        self.n_ticks = self.config["iterations"]
        # State traces: per-neuron rolling history of pre-activations.
        self.state_trace = np.zeros((self.d_model, self.memory_length))
        self.activated_state = np.random.randn(self.d_model) * 0.1
        # NLM weights (simplified: 16 groups for 16 dendrites)
        self.nlm_weights = np.random.randn(16, self.memory_length) * 0.1
        self.is_full = False
        print("✅ Simplified CTM initialized (NumPy fallback)")

    def forward(self, input_72d: np.ndarray, n_ticks: Optional[int] = None) -> Dict:
        """
        Process input through CTM.

        Args:
            input_72d: 72D input from SNN
            n_ticks: Override number of internal ticks

        Returns:
            Dict with predictions, certainty, sync matrix
        """
        # `or` means both None and 0 fall back to the configured default.
        n_ticks = n_ticks or self.config["iterations"]
        self.forward_count += 1
        if self.is_full:
            return self._forward_full(input_72d, n_ticks)
        else:
            return self._forward_simplified(input_72d, n_ticks)

    def _forward_full(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
        """Forward pass with full PyTorch CTM."""
        # Prepare tensor
        x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
        if len(x.shape) == 1:
            x = x.unsqueeze(0)  # Add batch dim
        x = x.unsqueeze(-1).unsqueeze(-1)  # [B, 72, 1, 1]
        with torch.no_grad():
            predictions, certainties, sync_out = self.model(x)
        # Extract results
        # NOTE(review): assumes certainties is [B, 2, T] with channel 1
        # holding (1 - normalized entropy) — confirm against models/ctm.py.
        final_pred = predictions[:, :, -1].cpu().numpy()[0]  # Last tick [16]
        final_cert = certainties[:, 1, -1].cpu().numpy()[0]  # 1-entropy
        # Find tick with highest certainty
        best_tick_idx = certainties[:, 1, :].argmax(dim=-1)[0].item()
        best_pred = predictions[:, :, best_tick_idx].cpu().numpy()[0]
        # Sync matrix for hypergraph edge proposals
        sync_matrix = sync_out.cpu().numpy()[0] if sync_out is not None else None
        return {
            "predictions": final_pred.tolist(),
            "best_predictions": best_pred.tolist(),
            "certainty": float(final_cert),
            "best_tick": int(best_tick_idx),
            "ticks_used": n_ticks,
            "sync_matrix": sync_matrix.tolist() if sync_matrix is not None else None,
            "model": "ContinuousThoughtMachine (Full PyTorch)"
        }

    def _forward_simplified(self, input_72d: np.ndarray, n_ticks: int) -> Dict:
        """
        Forward pass with simplified NumPy CTM (v3.0).

        v3.0 Features:
        1. Backbone transformation (72D -> 256D)
        2. Sync Decay (S = α*S_prev + (1-α)*S_current)
        3. Adaptive Halting (stop if certainty > threshold)
        """
        # v3.0: Backbone transformation (simple linear projection)
        if self.config.get("use_backbone", True):
            # Learned transformation: 72D -> 256D
            # NOTE(review): the projection weights are redrawn via
            # np.random.randn on every call, so this is neither learned
            # nor deterministic — consider caching a fixed per-instance
            # matrix if reproducible hypergraphs are required.
            input_256 = np.zeros(self.d_model)
            # Simple linear projection + normalization (simulates Backbone72D)
            projected = np.tanh(input_72d[:72] * np.random.randn(72) * 0.1) if len(input_72d) >= 72 else input_72d
            input_256[:min(len(projected), self.d_model)] = projected[:min(len(projected), self.d_model)]
        else:
            input_256 = np.zeros(self.d_model)
            input_256[:min(len(input_72d), self.d_model)] = input_72d[:self.d_model]
        # v3.0: Sync Decay initialization (EMA over outer products)
        alpha = self.config.get("sync_decay_alpha", 0.9)
        sync_matrix_prev = np.zeros((self.d_model, self.d_model))
        # v3.0: Adaptive halting config
        adaptive_halting = self.config.get("adaptive_halting", True)
        certainty_threshold = self.config.get("certainty_threshold", 0.85)
        certainties = []
        all_predictions = []
        ticks_actually_used = 0
        for t in range(n_ticks):
            ticks_actually_used = t + 1
            # Synapse update (simplified global mixing of state + input)
            combined = np.concatenate([self.activated_state, input_256[:self.d_model//2]])
            pre_activation = np.tanh(combined[:self.d_model] * 0.1 + np.random.randn(self.d_model) * 0.01)
            # Update trace (memory): shift left, append newest column
            self.state_trace = np.roll(self.state_trace, -1, axis=1)
            self.state_trace[:, -1] = pre_activation
            # NLM processing (simplified: 16 groups for 16 dendrites)
            post_activation = np.zeros(self.d_model)
            group_size = self.d_model // 16
            for g in range(16):
                start = g * group_size
                end = start + group_size
                group_trace = self.state_trace[start:end, :]
                group_output = np.mean(group_trace @ self.nlm_weights[g])
                post_activation[start:end] = np.tanh(group_output)
            self.activated_state = post_activation
            # v3.0: Sync Decay - S = α*S_prev + (1-α)*Z·Z^T
            z_norm = self.activated_state / (np.linalg.norm(self.activated_state) + 1e-8)
            sync_current = np.outer(z_norm, z_norm)
            sync_matrix = alpha * sync_matrix_prev + (1 - alpha) * sync_current
            sync_matrix_prev = sync_matrix
            # Store predictions at this tick (first 16 dims = dendrites)
            all_predictions.append(self.activated_state[:16].copy())
            # Compute certainty as 1 - normalized entropy of |state|
            probs = np.abs(self.activated_state) / (np.sum(np.abs(self.activated_state)) + 1e-8)
            probs = np.clip(probs, 1e-10, 1.0)
            entropy = -np.sum(probs * np.log(probs))
            max_entropy = np.log(len(probs))
            certainty = float(1.0 - entropy / (max_entropy + 1e-8))
            certainties.append(certainty)
            # v3.0: Adaptive Halting - stop early if confident enough
            if adaptive_halting and certainty > certainty_threshold:
                break
        # Best tick selection (highest certainty seen)
        best_tick_idx = int(np.argmax(certainties))
        best_predictions = all_predictions[best_tick_idx].tolist()
        return {
            "predictions": self.activated_state[:16].tolist(),
            "best_predictions": best_predictions,
            "certainty": certainties[-1],
            "best_tick": best_tick_idx,
            "ticks_used": ticks_actually_used,  # v3.0: Actual ticks, may be < n_ticks
            "max_ticks": n_ticks,
            "halted_early": ticks_actually_used < n_ticks,  # v3.0: Flag
            "sync_matrix": sync_matrix[:16, :16].tolist(),
            "model": "SimplifiedCTM v3.0 (NumPy + AdaptiveHalt + SyncDecay)"
        }

    def train_step(self, input_72d: np.ndarray, target_16d: np.ndarray,
                   physics_loss: float = 0.0) -> Dict:
        """
        Online training step.

        Args:
            input_72d: Input from SNN
            target_16d: Target dendrite adjustments (ground truth)
            physics_loss: Current physics loss for weighting

        Returns:
            Dict with loss and gradient info
        """
        if not self.is_full or not TORCH_AVAILABLE:
            return {"status": "skip", "reason": "Training requires full PyTorch CTM"}
        self.model.train()
        # Prepare tensors
        x = torch.tensor(input_72d, dtype=torch.float32, device=DEVICE)
        x = x.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # [1, 72, 1, 1]
        y = torch.tensor(target_16d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        # Forward
        predictions, certainties, _ = self.model(x)
        # Loss: dendrite_regulation_loss
        # predictions: [B, 16, T], y: [B, 16]
        y_exp = y.unsqueeze(-1).expand(-1, -1, predictions.size(-1))  # [B, 16, T]
        mse_per_tick = F.mse_loss(predictions, y_exp, reduction='none').mean(dim=1)  # [B, T]
        # Select best tick (min loss) and most certain tick
        loss_min_idx = mse_per_tick.argmin(dim=1)  # [B]
        loss_cert_idx = certainties[:, 1, :].argmax(dim=1)  # [B]
        batch_idx = torch.arange(predictions.size(0), device=DEVICE)
        loss_min = mse_per_tick[batch_idx, loss_min_idx].mean()
        loss_cert = mse_per_tick[batch_idx, loss_cert_idx].mean()
        # Combined loss with physics penalty (averages the two tick picks)
        mse_loss = (loss_min + loss_cert) / 2
        physics_penalty = physics_loss * 0.1
        total_loss = mse_loss + physics_penalty
        # Backward with gradient clipping for stability
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.model.eval()
        # Auto-save checkpoint
        # NOTE: cadence is driven by forward_count, which is incremented in
        # forward(), not here — a pure training loop never triggers a save.
        if self.forward_count % self.config["auto_save_every"] == 0:
            self._save_checkpoint()
        return {
            "status": "trained",
            "loss": float(total_loss.item()),
            "mse_loss": float(mse_loss.item()),
            "physics_penalty": float(physics_penalty),
            "best_tick": int(loss_cert_idx[0].item())
        }

    def _save_checkpoint(self) -> None:
        """Save model checkpoint."""
        if not self.is_full:
            return
        os.makedirs(self.config["checkpoint_dir"], exist_ok=True)
        path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "forward_count": self.forward_count,
            "timestamp": datetime.now().isoformat()
        }, path)
        print(f"💾 Checkpoint saved: {path}")
        # Upload to Bunker (Async/Fail-Safe)
        self.bunker.save_file(path, remote_folder="ctm_backups")

    def _load_checkpoint(self) -> None:
        """Load model checkpoint if exists."""
        path = os.path.join(self.config["checkpoint_dir"], "ctm_art17_latest.pt")
        if os.path.exists(path):
            try:
                # NOTE(review): torch.load without weights_only=True
                # unpickles arbitrary objects — only load trusted files.
                checkpoint = torch.load(path, map_location=DEVICE)
                self.model.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.forward_count = checkpoint.get("forward_count", 0)
                print(f"✅ Checkpoint loaded: {path}")
            except Exception as e:
                print(f"⚠️ Could not load checkpoint: {e}")
# ============================================================================
# GLOBAL CTM INSTANCE
# ============================================================================
# Module-level singleton shared by every endpoint below. Constructed at
# import time (may load a checkpoint and print status as a side effect).
ctm = CTM_ART17(CONFIG)
| # ============================================================================ | |
| # PHYSICS VALIDATION (from SNN Omega-21) | |
| # ============================================================================ | |
def validate_physics(trajectory: List[float], params: Dict) -> Dict:
    """Validate a trajectory against the 5 physics losses from SNN Omega-21.

    Args:
        trajectory: Sequence of scalar states; may be empty or length 1.
        params: Optional per-call overrides: "P_max", "v_max", "T_dew",
            "T_amb", "flux_in", "flux_out", "entropy_change". Threshold keys
            absent from params fall back to CONFIG["physics_thresholds"].

    Returns:
        Dict with each individual loss (L_energy, L_thermo, L_causal,
        L_conserv, L_entropy), the sum L_total, and "valid"
        (True when L_total < 0.01).
    """
    def _threshold(name: str) -> float:
        # Lazy lookup: the previous params.get(k, CONFIG[...]) form evaluated
        # the CONFIG default on every call even when an override was given.
        return params[name] if name in params else CONFIG["physics_thresholds"][name]

    trajectory = np.asarray(trajectory, dtype=float)

    # L_energy: Energy conservation — penalize total energy above P_max
    energy = np.sum(trajectory ** 2)
    P_max = _threshold("P_max")
    L_energy = float(max(0, energy - P_max) ** 2)

    # L_thermo: Thermodynamics — dew point must stay below ambient
    T_dew = _threshold("T_dew")
    T_amb = _threshold("T_amb")
    L_thermo = float(max(0, T_dew - T_amb) ** 2)

    # L_causal: Causality — per-step velocity magnitude limited by v_max
    velocity = np.diff(trajectory) if len(trajectory) > 1 else np.array([0])
    v_max = _threshold("v_max")
    L_causal = float(np.sum(np.maximum(0, np.abs(velocity) - v_max) ** 2))

    # L_conserv: Flux conservation — inflow must match outflow
    flux_in = params.get("flux_in", 1.0)
    flux_out = params.get("flux_out", 1.0)
    L_conserv = float((flux_in - flux_out) ** 2)

    # L_entropy: 2nd Law — entropy change must be non-negative
    entropy_change = params.get("entropy_change", 0.1)
    L_entropy = float(max(0, -entropy_change) ** 2)

    # Total physics loss
    L_total = L_energy + L_thermo + L_causal + L_conserv + L_entropy
    return {
        "valid": L_total < 0.01,
        "L_energy": L_energy,
        "L_thermo": L_thermo,
        "L_causal": L_causal,
        "L_conserv": L_conserv,
        "L_entropy": L_entropy,
        "L_total": L_total
    }
| # ============================================================================ | |
| # ENDPOINT FUNCTIONS | |
| # ============================================================================ | |
def sense_snn(snn_json: str) -> str:
    """
    /sense_snn - Process 72D SNN input through CTM

    Input: JSON with dendrite values or 72D vector
    Output: Coherent features, certainty, sync matrix
    """
    try:
        payload = json.loads(snn_json)

        # Pick the input vector: an explicit 72D vector wins, then the
        # dendrite dict; otherwise fall back to random noise.
        if "vector_72d" in payload:
            vec = np.array(payload["vector_72d"])
        elif "dendrites" in payload:
            vec = np.array(list(payload["dendrites"].values()))
        else:
            vec = np.random.randn(72)

        # Zero-pad short inputs up to the 72 dimensions the CTM expects.
        if len(vec) < 72:
            vec = np.pad(vec, (0, 72 - len(vec)))

        result = ctm.forward(vec[:72], payload.get("ticks", 25))

        # Low overall certainty is surfaced as an anomaly.
        anomalies = []
        if result["certainty"] < 0.5:
            anomalies.append("Low overall certainty - consider retraining")

        response = {
            "status": "success",
            "coherent_features": result["predictions"],
            "certainty": result["certainty"],
            "best_tick": result["best_tick"],
            "anomalies": anomalies,
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }
        return json.dumps(response, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def reason_hypergraph(context_json: str) -> str:
    """
    /reason_hypergraph - Reason about hypergraph context, propose edges

    Uses CTM synchronization matrix to find strongly correlated node pairs.
    These become proposed hyperedges.
    """
    try:
        data = json.loads(context_json)
        node_features = np.array(data.get("node_features", [[0]*16]*8))
        existing_edges = data.get("existing_edges", [])
        n_ticks = data.get("ticks", 50)

        # Flatten per-node features and zero-pad/truncate to 72D.
        flat = node_features.flatten()
        keep = min(len(flat), 72)
        input_vec = np.zeros(72)
        input_vec[:keep] = flat[:keep]

        # Reasoning gets a longer tick budget than plain sensing.
        result = ctm.forward(input_vec, n_ticks)

        def _already_linked(a, b):
            # True when the pair appears in either orientation.
            return any(
                (e[0] == a and e[1] == b) or (e[0] == b and e[1] == a)
                for e in existing_edges
            )

        # Propose an edge for every strongly synchronized pair (S_ij > 0.7)
        # that is not already present in the graph.
        proposed_edges = []
        raw_sync = result["sync_matrix"]
        if raw_sync is not None:
            sync = np.array(raw_sync)
            if sync.ndim >= 2:  # a 1D sync vector carries no pair info
                n_nodes = min(len(node_features), sync.shape[0])
                for i in range(n_nodes):
                    for j in range(i + 1, n_nodes):
                        if j >= sync.shape[1]:  # stay inside the matrix
                            continue
                        strength = sync[i, j]
                        if strength > 0.7 and not _already_linked(i, j):
                            proposed_edges.append([i, j, float(strength)])

        return json.dumps({
            "status": "success",
            "proposed_edges": proposed_edges,
            "certainty": result["certainty"],
            "best_tick": result["best_tick"],
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def validate_physics_endpoint(physics_json: str) -> str:
    """
    /validate_physics - Validate trajectory against 5 physics losses
    """
    try:
        payload = json.loads(physics_json)
        report = validate_physics(
            payload.get("trajectory", [0.0]),
            payload.get("physics_params", {}),
        )
        report["status"] = "success"
        return json.dumps(report, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def dream_endpoint(dream_json: str) -> str:
    """
    /dream - Offline consolidation with many ticks

    Discovers patterns, proposes new edges, identifies edges to prune.

    Input JSON:
        hypergraph_snapshot: {"nodes": [{"features": [...]}, ...]}
        ticks: requested internal ticks (capped at 100 for CPU budget)
    """
    try:
        data = json.loads(dream_json)
        snapshot = data.get("hypergraph_snapshot", {})
        n_ticks = min(data.get("ticks", 100), 100)  # Cap at 100 for CPU

        # Extract features from snapshot; random input when empty so the
        # CTM still has something to consolidate.
        nodes = snapshot.get("nodes", [])
        if nodes:
            input_vec = np.array([n.get("features", [0]*16) for n in nodes]).flatten()[:72]
        else:
            input_vec = np.random.randn(72)

        # Dream: run CTM with many ticks
        result = ctm.forward(input_vec, n_ticks)

        # Analyze sync for patterns: very strong pairs become new edges,
        # near-zero pairs become pruning candidates.
        new_edges = []
        pruned_edges = []
        if result["sync_matrix"] is not None:
            sync = np.array(result["sync_matrix"])
            # FIX: guard against a 1D sync vector and bound n by BOTH matrix
            # dimensions. Previously sync[i, j] could raise IndexError, which
            # the outer except turned into an error response for the whole
            # endpoint (same guard /reason_hypergraph already has).
            if sync.ndim >= 2:
                n = min(len(nodes), sync.shape[0]) if nodes else 16
                n = min(n, sync.shape[0], sync.shape[1])
                for i in range(n):
                    for j in range(i + 1, n):
                        strength = float(sync[i, j])
                        if strength > 0.85:
                            new_edges.append([i, j, strength])
                        elif strength < 0.1:
                            pruned_edges.append([i, j])

        return json.dumps({
            "status": "success",
            "discovered_patterns": len(new_edges),
            "new_edges": new_edges[:10],
            "pruned_edges": pruned_edges[:10],
            "consolidation_certainty": result["certainty"],
            "ticks_used": result["ticks_used"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def calibrate_stdp_endpoint(stdp_json: str) -> str:
    """
    /calibrate_stdp - Suggest STDP weight adjustments

    This is the CORE regulatory function:
    - Receives current 16 dendrite weights
    - Processes through CTM to get sync patterns
    - Returns suggested weight adjustments
    """
    try:
        data = json.loads(stdp_json)
        current_weights = np.array(data.get("current_weights", [1.0]*16))
        node_features = np.array(data.get("node_features", [[0]*16]*4))

        # FIX: zero-pad the flattened features to 72D (consistent with
        # /sense_snn). Previously a short payload — including the default
        # 4x16 = 64 values — reached the full PyTorch CTM with the wrong
        # input dimension and the endpoint returned an error.
        input_vec = node_features.flatten()[:72]
        if len(input_vec) < 72:
            input_vec = np.pad(input_vec, (0, 72 - len(input_vec)))

        # Process through CTM
        result = ctm.forward(input_vec, n_ticks=25)

        # Use predictions as weight adjustments, scaled by certainty so
        # low-confidence outputs barely move the weights.
        predictions = np.array(result["best_predictions"])
        confidence = result["certainty"]
        weight_changes = (predictions - 0.5) * confidence * 0.1
        new_weights = current_weights + weight_changes

        return json.dumps({
            "status": "success",
            "suggested_weights": new_weights.tolist(),
            "weight_changes": weight_changes.tolist(),
            "confidence": confidence,
            "best_tick": result["best_tick"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def regulate_endpoint(regulate_json: str) -> str:
    """
    /regulate - Full feedback loop for ART-17 regulation (NEW)

    Combines all signals to provide comprehensive regulation:
    - Dendrite state
    - Latent representation
    - Physics loss
    - Anomaly score

    Returns action recommendation with confidence.
    """
    try:
        data = json.loads(regulate_json)
        # Inputs from local system
        dendrites = np.array(data.get("dendrites", [0.0]*16))
        latent_256 = np.array(data.get("latent_256", [0.0]*256))
        physics_loss = data.get("physics_loss", 0.0)
        anomaly_score = data.get("anomaly_score", 0.0)

        # FIX: normalize payload sizes so the combined input is always
        # exactly 72D (16 dendrite dims + 56 latent dims). Previously a
        # short "dendrites" or "latent_256" list produced a <72D vector,
        # which the full PyTorch CTM rejects.
        if len(dendrites) < 16:
            dendrites = np.pad(dendrites, (0, 16 - len(dendrites)))
        if len(latent_256) < 56:
            latent_256 = np.pad(latent_256, (0, 56 - len(latent_256)))
        input_72 = np.concatenate([
            dendrites[:16],   # 16D dendrite state
            latent_256[:56]   # 56D slice of the latent representation
        ])

        # Process through CTM
        result = ctm.forward(input_72, n_ticks=50)

        # Compute regulation signals
        predictions = np.array(result["best_predictions"])
        certainty = result["certainty"]

        # Urgency grows with physics violations and anomaly score;
        # regulation is scaled by how certain the CTM is.
        urgency = min(1.0, physics_loss + anomaly_score)
        regulation_strength = urgency * certainty

        # Small, bounded per-dendrite weight adjustments
        dendrite_deltas = predictions * regulation_strength * 0.05

        # Intervene when the situation is urgent or the CTM is unsure
        needs_intervention = urgency > 0.5 or certainty < 0.3

        return json.dumps({
            "status": "success",
            "dendrite_deltas": dendrite_deltas.tolist(),
            "regulation_strength": float(regulation_strength),
            "confidence": certainty,
            "urgency": float(urgency),
            "needs_intervention": needs_intervention,
            "recommended_action": "ADJUST" if needs_intervention else "MAINTAIN",
            "best_tick": result["best_tick"],
            "model": result["model"]
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def train_online_endpoint(train_json: str) -> str:
    """
    /train_online - Progressive online training (NEW)

    Allows the local system to train the CTM with experience.
    Sends input-output pairs and receives training feedback.
    """
    try:
        payload = json.loads(train_json)
        x = np.array(payload.get("input_72d", [0.0]*72))
        y = np.array(payload.get("target_16d", [0.0]*16))
        phys = payload.get("physics_loss", 0.0)

        # Delegate the actual optimization step to the CTM wrapper.
        outcome = ctm.train_step(x, y, phys)

        trained = outcome["status"] == "trained"
        return json.dumps({
            "status": outcome["status"],
            "loss": outcome.get("loss"),
            "mse_loss": outcome.get("mse_loss"),
            "physics_penalty": outcome.get("physics_penalty"),
            "best_tick": outcome.get("best_tick"),
            "forward_count": ctm.forward_count,
            "message": "Training step completed" if trained else outcome.get("reason")
        }, indent=2)
    except Exception as e:
        return json.dumps({"status": "error", "message": str(e)})
def health_check() -> str:
    """Health check with model info."""
    backend = 'Full PyTorch' if ctm.is_full else 'NumPy Fallback'
    payload = {
        "status": "healthy",
        "model": f"CTM Nervous System v2.0 ({backend})",
        "device": DEVICE,
        "d_model": CONFIG["d_model"],
        "iterations": CONFIG["iterations"],
        "memory_length": CONFIG["memory_length"],
        "forward_count": ctm.forward_count,
        "endpoints": [
            "/sense_snn",
            "/reason_hypergraph",
            "/validate_physics",
            "/dream",
            "/calibrate_stdp",
            "/regulate",      # NEW
            "/train_online"   # NEW
        ]
    }
    return json.dumps(payload, indent=2)
| # ============================================================================ | |
| # GRADIO INTERFACE | |
| # ============================================================================ | |
# Gradio UI: one tab per API endpoint; api_name exposes each handler on the
# HTTP API as well.
with gr.Blocks(title="CTM Nervous System v2.0", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧬 CTM Nervous System v2.0
    **Continuous Thought Machine for ART-17 Hypergraph Coherence**
    Based on [arXiv:2505.05522](https://arxiv.org/abs/2505.05522) - Sakana AI
    ---
    ## Key Innovations
    - **NLMs (Neuron-Level Models)**: Each neuron processes its own history
    - **Neural Synchronization**: Representation via S = Z·Z^T
    - **Adaptive Compute**: Halts when confident
    - **Online Training**: Progressive learning with use
    ---
    """)
    with gr.Tabs():
        with gr.Tab("🔌 /sense_snn"):
            gr.Markdown("Process 72D SNN input through CTM")
            snn_input = gr.Textbox(
                label="SNN JSON Input",
                value='{"dendrites": {"d1": 0.1, "d2": 0.2, "d3": 0.3}, "ticks": 25}',
                lines=5
            )
            snn_output = gr.Textbox(label="Output", lines=10)
            snn_btn = gr.Button("Process", variant="primary")
            snn_btn.click(sense_snn, inputs=snn_input, outputs=snn_output, api_name="sense_snn")
        with gr.Tab("🧠 /reason_hypergraph"):
            gr.Markdown("Reason about hypergraph context, propose edges")
            reason_input = gr.Textbox(
                label="Context JSON",
                value='{"node_features": [[0.1, 0.2], [0.3, 0.4]], "existing_edges": [], "ticks": 50}',
                lines=5
            )
            reason_output = gr.Textbox(label="Output", lines=10)
            reason_btn = gr.Button("Reason", variant="primary")
            reason_btn.click(reason_hypergraph, inputs=reason_input, outputs=reason_output, api_name="reason_hypergraph")
        with gr.Tab("⚡ /validate_physics"):
            gr.Markdown("Validate trajectory against 5 physics losses")
            physics_input = gr.Textbox(
                label="Physics JSON",
                value='{"trajectory": [0.1, 0.2, 0.3], "physics_params": {"P_max": 1000}}',
                lines=5
            )
            physics_output = gr.Textbox(label="Output", lines=10)
            physics_btn = gr.Button("Validate", variant="primary")
            physics_btn.click(validate_physics_endpoint, inputs=physics_input, outputs=physics_output, api_name="validate_physics")
        with gr.Tab("💤 /dream"):
            gr.Markdown("Offline consolidation - discover patterns")
            dream_input = gr.Textbox(
                label="Dream JSON",
                value='{"hypergraph_snapshot": {"nodes": []}, "ticks": 100}',
                lines=5
            )
            dream_output = gr.Textbox(label="Output", lines=10)
            dream_btn = gr.Button("Dream", variant="primary")
            dream_btn.click(dream_endpoint, inputs=dream_input, outputs=dream_output, api_name="dream")
        with gr.Tab("🔧 /calibrate_stdp"):
            gr.Markdown("Calibrate STDP weights (Core regulatory function)")
            stdp_input = gr.Textbox(
                label="STDP JSON",
                value='{"current_weights": [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], "node_features": [[0.1, 0.2]]}',
                lines=5
            )
            stdp_output = gr.Textbox(label="Output", lines=10)
            stdp_btn = gr.Button("Calibrate", variant="primary")
            stdp_btn.click(calibrate_stdp_endpoint, inputs=stdp_input, outputs=stdp_output, api_name="calibrate_stdp")
        with gr.Tab("🎯 /regulate [NEW]"):
            gr.Markdown("Full feedback loop for ART-17 regulation")
            regulate_input = gr.Textbox(
                label="Regulate JSON",
                # FIX: the previous default ('{"dendrites": [0.5]*16, ...}')
                # embedded Python list-repeat syntax in the JSON string, so
                # json.loads always failed on the shipped example. Serialize
                # a real payload instead.
                value=json.dumps({
                    "dendrites": [0.5] * 16,
                    "latent_256": [0.1] * 256,
                    "physics_loss": 0.01,
                    "anomaly_score": 0.05
                }),
                lines=5
            )
            regulate_output = gr.Textbox(label="Output", lines=10)
            regulate_btn = gr.Button("Regulate", variant="primary")
            regulate_btn.click(regulate_endpoint, inputs=regulate_input, outputs=regulate_output, api_name="regulate")
        with gr.Tab("📚 /train_online [NEW]"):
            gr.Markdown("Progressive online training with experience")
            train_input = gr.Textbox(
                label="Training JSON",
                # FIX: same invalid-JSON default as /regulate ("[0.1]*72").
                value=json.dumps({
                    "input_72d": [0.1] * 72,
                    "target_16d": [0.5] * 16,
                    "physics_loss": 0.01
                }),
                lines=5
            )
            train_output = gr.Textbox(label="Output", lines=10)
            train_btn = gr.Button("Train Step", variant="primary")
            train_btn.click(train_online_endpoint, inputs=train_input, outputs=train_output, api_name="train_online")
        with gr.Tab("❤️ Health"):
            health_output = gr.Textbox(label="Health Status", lines=15)
            health_btn = gr.Button("Check Health", variant="secondary")
            health_btn.click(health_check, inputs=None, outputs=health_output, api_name="health_check")
    gr.Markdown("""
    ---
    **Architecture**: CTM as Nervous System → Hypergraph as Coherent Thought
    **Integration**: Local ART-17 ↔ CTM (regulation) ↔ Brain Server (semantics)
    **Training**: Progressive online learning + Physics-Informed Loss
    """)
if __name__ == "__main__":
    # Bind on all interfaces at the standard Hugging Face Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )