import torch
import torch.nn as nn
from typing import Optional, Any, cast

from ..utils.sleep import EpisodicBuffer, SleepReplayScheduler


class NeuroscienceEnhancer(nn.Module):
    """
    Neuroscience-inspired enhancement layer for SEM V6.

    Integrates five biological learning mechanisms:

    1. **STDP (Spike-Timing Dependent Plasticity)**: Temporal causality learning
       where synaptic strength changes based on relative timing of pre- and
       post-synaptic spikes (Bi & Poo, 1998).

    2. **Sleep Consolidation**: Offline memory replay during designated "sleep"
       phases to strengthen important patterns (Wilson & McNaughton, 1994).

    3. **Neuromodulator Dynamics**: Three neuromodulators (ACh, NE, 5-HT)
       dynamically adjust learning rates based on context:
       - ACh (Acetylcholine): Enhances plasticity during learning
       - NE (Norepinephrine): Increases exploration during arousal
       - 5-HT (Serotonin): Stabilizes weights during consolidation
       (Sara, 2009; Hasselmo, 2006)

    4. **Lateral Inhibition**: Winner-take-all competition where highly active
       neurons suppress neighbors to enforce sparsity.

    5. **Predictive Coding**: Error-driven learning where prediction errors
       propagate backward to update representations (Rao & Ballard, 1999).

    These mechanisms enhance Module C (ChebyKAN propagator) with biological
    realism while maintaining the frozen architecture constraint.

    References:
        - Bi, G., & Poo, M. (1998). Synaptic modifications in cultured hippocampal
          neurons: dependence on spike timing, synaptic strength, and postsynaptic
          cell type. Journal of Neuroscience, 18(24), 10464-10472.
        - Wilson, M. A., & McNaughton, B. L. (1994). Reactivation of hippocampal
          ensemble memories during sleep. Science, 265(5172), 676-679.
        - Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
          cortex: a functional interpretation of some extra-classical
          receptive-field effects. Nature Neuroscience, 2(1), 79-87.
        - Sara, S. J. (2009). The locus coeruleus and noradrenergic modulation of
          cognition. Nature Reviews Neuroscience, 10(3), 211-223.
        - Hasselmo, M. E. (2006). The role of acetylcholine in learning and memory.
          Current Opinion in Neurobiology, 16(6), 710-715.
    """

    def __init__(
        self,
        manifold_dim: int = 16384,
        sparsity: float = 0.05,
        device: str = "cuda",
        enable_stdp: bool = True,
        enable_sleep: bool = True,
        enable_neuromodulation: bool = True,
        enable_lateral_inhibition: bool = True,
        enable_predictive_coding: bool = True,
        awake_steps: int = 1000,
        sleep_steps: int = 100,
        sleep_replay_batch: int = 32,
        buffer_max_size: int = 1000,
    ):
        """
        Initialize NeuroscienceEnhancer.

        Args:
            manifold_dim: Dimensionality of the hypergraph manifold (default: 16384)
            sparsity: Target sparsity for lateral inhibition (default: 0.05, i.e., 5%)
            device: Device for tensor operations ('cuda' or 'cpu')
            enable_stdp: Enable STDP learning mechanism
            enable_sleep: Enable sleep consolidation system
            enable_neuromodulation: Enable neuromodulator dynamics (ACh, NE, 5-HT)
            enable_lateral_inhibition: Enable lateral inhibition (k-WTA)
            enable_predictive_coding: Enable predictive coding error computation
            awake_steps: Number of training steps per awake phase (default: 1000)
            sleep_steps: Number of replay steps per sleep phase (default: 100)
            sleep_replay_batch: Batch size for sleep replay (default: 32)
            buffer_max_size: Maximum size of episodic replay buffer (default: 1000)
        """
        super().__init__()

        # GPU requirement check (per CLAUDE.md)
        assert torch.cuda.is_available(), "GPU required for Module E (per CLAUDE.md)"

        self.manifold_dim = manifold_dim
        self.sparsity = sparsity
        self.k = int(manifold_dim * sparsity)  # Number of winners for k-WTA
        self.device = torch.device(device)

        # Feature flags
        self.enable_stdp = enable_stdp
        self.enable_sleep = enable_sleep
        self.enable_neuromodulation = enable_neuromodulation
        self.enable_lateral_inhibition = enable_lateral_inhibition
        self.enable_predictive_coding = enable_predictive_coding

        # Sleep/wake cycle scheduler (subtask-2-3)
        self.sleep_scheduler: Optional[SleepReplayScheduler]
        self.episodic_buffer: Optional[EpisodicBuffer]
        if self.enable_sleep:
            self.sleep_scheduler = SleepReplayScheduler(
                awake_steps=awake_steps,
                sleep_steps=sleep_steps,
                sleep_replay_batch=sleep_replay_batch
            )
            # Episodic replay buffer (subtask-2-1)
            self.episodic_buffer = EpisodicBuffer(
                max_size=buffer_max_size,
                device=str(self.device)
            )
        else:
            self.sleep_scheduler = None
            self.episodic_buffer = None

        # Neuromodulator levels (learnable parameters)
        # Initialized to biologically plausible baseline values
        if self.enable_neuromodulation:
            self.ach = nn.Parameter(torch.tensor(1.0, device=self.device))        # Acetylcholine
            self.ne = nn.Parameter(torch.tensor(0.5, device=self.device))         # Norepinephrine
            self.serotonin = nn.Parameter(torch.tensor(0.3, device=self.device))  # Serotonin

        # Placeholder for future components (to be implemented in subsequent subtasks)
        # - STDP learner (subtask-1-2) - to be integrated
        # - Predictive coding error module (subtask-3-3) - to be implemented
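    # Illustrative configuration note (not part of the original API surface):
    # with the default manifold_dim=16384 and sparsity=0.05 above, lateral
    # inhibition keeps k = int(16384 * 0.05) = 819 winners per sample, i.e.
    # roughly 5% of the manifold stays active after k-WTA.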
    def set_sleep_mode(self, is_sleeping: bool) -> None:
        """
        Set sleep/wake mode for the enhancer.

        Args:
            is_sleeping: True for sleep mode (offline consolidation),
                False for awake mode (online learning)

        Note:
            When using the sleep scheduler, prefer using step() for automatic
            sleep/wake transitions instead of manually setting sleep mode.
        """
        # Manual override (bypasses scheduler if enabled)
        if self.enable_sleep and self.sleep_scheduler is not None:
            # Synchronize scheduler state with manual override
            self.sleep_scheduler.awake = not is_sleeping
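    # Illustrative manual-override sketch (hypothetical values; the
    # scheduler-driven step() below is the preferred path):
    #   enhancer.set_sleep_mode(True)                   # force consolidation now
    #   batch = enhancer.sample_episodes(batch_size=32)
    #   ... consolidate on batch ...
    #   enhancer.set_sleep_mode(False)                  # resume online learning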
    def step(self) -> None:
        """
        Advance one step in the sleep/wake cycle.

        Automatically transitions between awake and sleep modes based on the
        configured scheduler. Should be called once per training step.

        Example:
            >>> enhancer = NeuroscienceEnhancer(enable_sleep=True)
            >>> for step in range(10000):
            ...     if enhancer.is_awake():
            ...         # Online training
            ...         loss = train_step(data)
            ...         enhancer.add_episode(episode)
            ...     else:
            ...         # Sleep consolidation
            ...         replay_batch = enhancer.sample_episodes(batch_size=32)
            ...         consolidate(replay_batch)
            ...     enhancer.step()
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            self.sleep_scheduler.step()

    def is_awake(self) -> bool:
        """
        Check if currently in awake (online learning) mode.

        Returns:
            True if awake, False if sleeping
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_awake()
        return True  # Default to awake if sleep disabled

    def is_sleeping(self) -> bool:
        """
        Check if currently in sleep (offline consolidation) mode.

        Returns:
            True if sleeping, False if awake
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_sleeping()
        return False  # Default to not sleeping if sleep disabled

    def add_episode(self, episode: dict[str, Any]) -> None:
        """
        Add episode to replay buffer during awake phase.

        Args:
            episode: Episode dictionary containing 'sdr', 'reward', 'timestamp'

        Example:
            >>> episode = {
            ...     'sdr': torch.randn(16384, device='cuda'),
            ...     'reward': 1.5,
            ...     'timestamp': 100.0
            ... }
            >>> enhancer.add_episode(episode)
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            self.episodic_buffer.add(episode)

    def sample_episodes(self, batch_size: int, **kwargs: Any) -> list[dict[str, Any]]:
        """
        Sample episodes from replay buffer for sleep consolidation.

        Args:
            batch_size: Number of episodes to sample
            **kwargs: Additional arguments passed to buffer.sample()
                (e.g., prioritize=True, reverse_temporal=True)

        Returns:
            List of episode dictionaries

        Example:
            >>> # Sample for reverse temporal replay during sleep
            >>> batch = enhancer.sample_episodes(
            ...     batch_size=32,
            ...     reverse_temporal=True
            ... )
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            return self.episodic_buffer.sample(batch_size, **kwargs)
        return []

    def get_phase_progress(self) -> float:
        """
        Get progress through current sleep/wake phase.

        Returns:
            Progress as fraction [0, 1] (0.0 = phase start, 1.0 = phase end)
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.get_phase_progress()
        return 0.0

    def get_neuromodulator_states(self) -> tuple[float, float, float]:
        """
        Get current neuromodulator levels.

        Returns:
            Tuple of (ACh, NE, Serotonin) levels
        """
        if self.enable_neuromodulation:
            return (
                self.ach.item(),
                self.ne.item(),
                self.serotonin.item()
            )
        else:
            return (1.0, 0.0, 0.0)  # Defaults when neuromodulation disabled

    def compute_effective_learning_rate(self, base_lr: float) -> float:
        """
        Compute effective learning rate modulated by neuromodulators.

        Formula (per spec):
            lr_effective = base_lr * ach * (1 + ne) * (1 - 0.5 * serotonin)

        Args:
            base_lr: Base learning rate from optimizer

        Returns:
            Effective learning rate after neuromodulator modulation
        """
        if not self.enable_neuromodulation:
            return base_lr

        # Clamp neuromodulators to safe range [0, 2] to prevent instability
        ach_clamped = torch.clamp(self.ach, 0.0, 2.0)
        ne_clamped = torch.clamp(self.ne, 0.0, 2.0)
        serotonin_clamped = torch.clamp(self.serotonin, 0.0, 2.0)

        lr_effective = base_lr * ach_clamped * (1 + ne_clamped) * (1 - 0.5 * serotonin_clamped)
        return cast(float, lr_effective.item())
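    # Worked example of the modulation formula above (illustrative numbers,
    # using the baseline initial values ach=1.0, ne=0.5, serotonin=0.3):
    #   lr_effective = base_lr * 1.0 * (1 + 0.5) * (1 - 0.5 * 0.3)
    #                = base_lr * 1.5 * 0.85
    #                = 1.275 * base_lr
    # i.e. a base_lr of 1e-3 becomes ~1.275e-3 at the baseline modulator levels.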
    def compute_predictive_coding_error(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        return_magnitude: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute predictive coding error for error-driven learning.

        Implements the Rao & Ballard (1999) predictive coding framework, where
        prediction errors drive learning through hierarchical error propagation.
        Error neurons compute the difference between top-down predictions and
        bottom-up sensory input, and these errors are used to update
        representations at each level of the hierarchy.

        Mathematical formulation:
            error = target - prediction
            error_magnitude = ||error||_2  (L2 norm)

        In predictive coding, weights are updated proportional to error magnitude:
            Δw ∝ error_magnitude * gradient

        This creates a Bayesian inference framework where:
        - Higher-level predictions influence lower-level representations
        - Prediction errors are minimized through gradient descent
        - Hierarchical structure emerges naturally

        Args:
            prediction: Model's prediction (batch, manifold_dim)
                This represents the top-down prediction from higher
                hierarchical levels
            target: Ground truth target (batch, manifold_dim)
                This represents the bottom-up sensory input or desired output
                from lower hierarchical levels
            return_magnitude: If True, also return L2 norm of error
                (useful for monitoring convergence)

        Returns:
            Tuple of:
            - error: Prediction error tensor (batch, manifold_dim)
              Sign indicates direction of error (target > pred: +, target < pred: -)
            - error_magnitude: L2 norm of error per sample (batch,)
              Only returned if return_magnitude=True, else None

        Example:
            >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
            >>> prediction = torch.randn(32, 16384, device='cuda')
            >>> target = torch.randn(32, 16384, device='cuda')
            >>> error, magnitude = enhancer.compute_predictive_coding_error(
            ...     prediction, target, return_magnitude=True
            ... )
            >>> # Use error for weight updates: Δw ∝ error
            >>> # Monitor magnitude to verify error reduction over training

        Note:
            In the hierarchical predictive coding framework:
            - Level N+1 predicts activity at Level N
            - Error at Level N = actual(N) - predicted(N)
            - This error is used to:
                1. Update Level N+1's predictions (top-down)
                2. Update Level N's representations (bottom-up)
            - Iterative minimization of prediction error across the hierarchy
              implements Bayesian inference

        Reference:
            Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
            cortex: a functional interpretation of some extra-classical
            receptive-field effects. Nature Neuroscience, 2(1), 79-87.
        """
        # Compute raw prediction error (target - prediction)
        # This represents the surprise signal that drives learning
        error = target - prediction

        # Optionally compute error magnitude for monitoring convergence
        error_magnitude = None
        if return_magnitude:
            # L2 norm per sample: ||error||_2
            # Used to verify that error decreases over training iterations
            # (acceptance criterion from spec)
            error_magnitude = torch.norm(error, p=2, dim=1)

        return error, error_magnitude

    def apply_lateral_inhibition(self, activations: torch.Tensor) -> torch.Tensor:
        """
        Apply k-Winners-Take-All lateral inhibition (vectorized).

        Keeps only the top-k neurons per sample and sets them to 1.0, zeroing
        out the rest, so the output is a binary sparse code at the target
        sparsity.

        Args:
            activations: Input activations (batch, manifold_dim)

        Returns:
            Sparse binary activations with exactly k active neurons per sample
        """
        if not self.enable_lateral_inhibition:
            return activations

        # Find top-k indices for each sample in batch
        _, top_k_indices = torch.topk(activations, self.k, dim=-1)

        # Vectorized: set top-k positions to 1 (no Python loop)
        sparse_activations = torch.zeros_like(activations)
        sparse_activations.scatter_(-1, top_k_indices, 1.0)

        return sparse_activations
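    # Illustrative k-WTA sketch (hypothetical tensors, default configuration):
    #   acts = torch.randn(8, 16384, device="cuda")
    #   sdr = enhancer.apply_lateral_inhibition(acts)
    #   sdr.sum(dim=-1)   # -> tensor of 819s: exactly k binary winners per sample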
    def forward(
        self,
        u: torch.Tensor,
        target: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through neuroscience enhancement layer.

        Args:
            u: Input state from Module C propagator (batch, manifold_dim)
            target: Optional target for predictive coding error computation

        Returns:
            Tuple of:
            - Enhanced state after neuroscience mechanisms
            - Prediction error (if target provided and predictive coding enabled)
        """
        enhanced_u = u
        prediction_error = None

        # Apply lateral inhibition (if enabled and awake)
        if self.enable_lateral_inhibition and not self.is_sleeping():
            enhanced_u = self.apply_lateral_inhibition(enhanced_u)

        # Compute predictive coding error (if enabled and target provided)
        if self.enable_predictive_coding and target is not None:
            prediction_error, _ = self.compute_predictive_coding_error(
                prediction=enhanced_u,
                target=target,
                return_magnitude=False
            )

        # Note: STDP updates and sleep consolidation are handled externally
        # via callbacks and training loop orchestration (to be implemented)
        return enhanced_u, prediction_error
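

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API).
    # Assumes a CUDA device is available, matching the assert in __init__, and
    # exercises only methods defined above. Run it with `python -m` against this
    # module's package path so the relative import at the top resolves.
    enhancer = NeuroscienceEnhancer(manifold_dim=16384, sparsity=0.05, device="cuda")

    u = torch.randn(4, 16384, device="cuda")
    target = torch.randn(4, 16384, device="cuda")

    # Forward pass: k-WTA sparsification plus predictive coding error.
    enhanced, error = enhancer(u, target=target)
    print("active neurons per sample:", int(enhanced[0].sum().item()))  # == enhancer.k
    print("prediction error norm:", error.norm().item())

    # Drive the sleep/wake cycle and the episodic buffer for a few steps.
    for t in range(5):
        if enhancer.is_awake():
            enhancer.add_episode({"sdr": enhanced[0], "reward": 0.0, "timestamp": float(t)})
        else:
            replay = enhancer.sample_episodes(batch_size=2)
            print("replayed", len(replay), "episodes")
        enhancer.step()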