import torch
import torch.nn as nn
from typing import Optional, Any, cast

from ..utils.sleep import EpisodicBuffer, SleepReplayScheduler


class NeuroscienceEnhancer(nn.Module):
| """ |
| Neuroscience-inspired enhancement layer for SEM V6. |
| |
| Integrates five biological learning mechanisms: |
| |
| 1. **STDP (Spike-Timing Dependent Plasticity)**: Temporal causality learning |
| where synaptic strength changes based on relative timing of pre- and |
| post-synaptic spikes (Bi & Poo, 1998). |
| |
| 2. **Sleep Consolidation**: Offline memory replay during designated "sleep" |
| phases to strengthen important patterns (Wilson & McNaughton, 1994). |
| |
| 3. **Neuromodulator Dynamics**: Three neuromodulators (ACh, NE, 5-HT) |
| dynamically adjust learning rates based on context: |
| - ACh (Acetylcholine): Enhances plasticity during learning |
| - NE (Norepinephrine): Increases exploration during arousal |
| - 5-HT (Serotonin): Stabilizes weights during consolidation |
| (Sara, 2009; Hasselmo, 2006) |
| |
| 4. **Lateral Inhibition**: Winner-take-all competition where highly active |
| neurons suppress neighbors to enforce sparsity. |
| |
| 5. **Predictive Coding**: Error-driven learning where prediction errors |
| propagate backward to update representations (Rao & Ballard, 1999). |
| |
| These mechanisms enhance Module C (ChebyKAN propagator) with biological |
| realism while maintaining the frozen architecture constraint. |
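
    Example (an illustrative sketch; assumes a CUDA device and the default
    constructor arguments):

        >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
        >>> u = torch.randn(8, 16384, device='cuda')
        >>> enhanced_u, _ = enhancer(u)  # k-WTA sparsified state
        >>> enhancer.step()              # advance the sleep/wake cycle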

    References:
        - Bi, G., & Poo, M. (1998). Synaptic modifications in cultured
          hippocampal neurons: dependence on spike timing, synaptic strength,
          and postsynaptic cell type. Journal of Neuroscience, 18(24),
          10464-10472.
        - Wilson, M. A., & McNaughton, B. L. (1994). Reactivation of
          hippocampal ensemble memories during sleep. Science, 265(5172),
          676-679.
        - Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the
          visual cortex: a functional interpretation of some extra-classical
          receptive-field effects. Nature Neuroscience, 2(1), 79-87.
        - Sara, S. J. (2009). The locus coeruleus and noradrenergic
          modulation of cognition. Nature Reviews Neuroscience, 10(3),
          211-223.
        - Hasselmo, M. E. (2006). The role of acetylcholine in learning and
          memory. Current Opinion in Neurobiology, 16(6), 710-715.
    """

    def __init__(
        self,
        manifold_dim: int = 16384,
        sparsity: float = 0.05,
        device: str = "cuda",
        enable_stdp: bool = True,
        enable_sleep: bool = True,
        enable_neuromodulation: bool = True,
        enable_lateral_inhibition: bool = True,
        enable_predictive_coding: bool = True,
        awake_steps: int = 1000,
        sleep_steps: int = 100,
        sleep_replay_batch: int = 32,
        buffer_max_size: int = 1000,
    ):
| """ |
| Initialize NeuroscienceEnhancer. |
| |
| Args: |
| manifold_dim: Dimensionality of the hypergraph manifold (default: 16384) |
| sparsity: Target sparsity for lateral inhibition (default: 0.05, i.e., 5%) |
| device: Device for tensor operations ('cuda' or 'cpu') |
| enable_stdp: Enable STDP learning mechanism |
| enable_sleep: Enable sleep consolidation system |
| enable_neuromodulation: Enable neuromodulator dynamics (ACh, NE, 5-HT) |
| enable_lateral_inhibition: Enable lateral inhibition (k-WTA) |
| enable_predictive_coding: Enable predictive coding error computation |
| awake_steps: Number of training steps per awake phase (default: 1000) |
| sleep_steps: Number of replay steps per sleep phase (default: 100) |
| sleep_replay_batch: Batch size for sleep replay (default: 32) |
| buffer_max_size: Maximum size of episodic replay buffer (default: 1000) |
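
        Example (an illustrative sketch of a small CPU configuration with a
        shortened sleep/wake schedule):
            >>> enhancer = NeuroscienceEnhancer(
            ...     manifold_dim=2048,
            ...     sparsity=0.02,
            ...     device='cpu',
            ...     awake_steps=500,
            ...     sleep_steps=50,
            ... )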
| """ |
| super().__init__() |
|
|
| |
| assert torch.cuda.is_available(), "GPU required for Module E (per CLAUDE.md)" |
|
|
| self.manifold_dim = manifold_dim |
| self.sparsity = sparsity |
| self.k = int(manifold_dim * sparsity) |
| self.device = torch.device(device) |
|

        # Feature flags for the five mechanisms
        self.enable_stdp = enable_stdp
        self.enable_sleep = enable_sleep
        self.enable_neuromodulation = enable_neuromodulation
        self.enable_lateral_inhibition = enable_lateral_inhibition
        self.enable_predictive_coding = enable_predictive_coding

        # Sleep/wake scheduling and episodic replay storage
        self.sleep_scheduler: Optional[SleepReplayScheduler]
        self.episodic_buffer: Optional[EpisodicBuffer]
        if self.enable_sleep:
            self.sleep_scheduler = SleepReplayScheduler(
                awake_steps=awake_steps,
                sleep_steps=sleep_steps,
                sleep_replay_batch=sleep_replay_batch,
            )
            self.episodic_buffer = EpisodicBuffer(
                max_size=buffer_max_size,
                device=str(self.device),
            )
        else:
            self.sleep_scheduler = None
            self.episodic_buffer = None

        # Neuromodulator levels as learnable scalars; initial values set
        # ACh high (plasticity), NE moderate (exploration), and 5-HT low
        # (stability), matching the roles in the class docstring
        if self.enable_neuromodulation:
            self.ach = nn.Parameter(torch.tensor(1.0, device=self.device))
            self.ne = nn.Parameter(torch.tensor(0.5, device=self.device))
            self.serotonin = nn.Parameter(torch.tensor(0.3, device=self.device))

    def set_sleep_mode(self, is_sleeping: bool) -> None:
        """
        Set sleep/wake mode for the enhancer.

        Args:
            is_sleeping: True for sleep mode (offline consolidation),
                False for awake mode (online learning)

        Note:
            When using the sleep scheduler, prefer step() for automatic
            sleep/wake transitions instead of manually setting sleep mode.
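
        Example (illustrative; forces an out-of-schedule consolidation
        phase):
            >>> enhancer.set_sleep_mode(True)
            >>> batch = enhancer.sample_episodes(batch_size=32)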
| """ |
| |
| if self.enable_sleep and self.sleep_scheduler is not None: |
| |
| self.sleep_scheduler.awake = not is_sleeping |
|

    def step(self) -> None:
        """
        Advance one step in the sleep/wake cycle.

        Automatically transitions between awake and sleep modes based on
        the configured scheduler. Should be called once per training step.

        Example:
            >>> enhancer = NeuroscienceEnhancer(enable_sleep=True)
            >>> for step in range(10000):
            ...     if enhancer.is_awake():
            ...         # Online training
            ...         loss = train_step(data)
            ...         enhancer.add_episode(episode)
            ...     else:
            ...         # Sleep consolidation
            ...         replay_batch = enhancer.sample_episodes(batch_size=32)
            ...         consolidate(replay_batch)
            ...     enhancer.step()
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            self.sleep_scheduler.step()

    def is_awake(self) -> bool:
        """
        Check if currently in awake (online learning) mode.

        Returns:
            True if awake, False if sleeping (always True when sleep is
            disabled)
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_awake()
        return True

    def is_sleeping(self) -> bool:
        """
        Check if currently in sleep (offline consolidation) mode.

        Returns:
            True if sleeping, False if awake (always False when sleep is
            disabled)
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_sleeping()
        return False

    def add_episode(self, episode: dict[str, Any]) -> None:
        """
        Add an episode to the replay buffer during the awake phase.

        No-op when sleep consolidation is disabled.

        Args:
            episode: Episode dictionary containing 'sdr', 'reward',
                and 'timestamp'

        Example:
            >>> episode = {
            ...     'sdr': torch.randn(16384, device='cuda'),
            ...     'reward': 1.5,
            ...     'timestamp': 100.0,
            ... }
            >>> enhancer.add_episode(episode)
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            self.episodic_buffer.add(episode)

    def sample_episodes(self, batch_size: int, **kwargs: Any) -> list[dict[str, Any]]:
        """
        Sample episodes from the replay buffer for sleep consolidation.

        Args:
            batch_size: Number of episodes to sample
            **kwargs: Additional arguments passed to buffer.sample()
                (e.g., prioritize=True, reverse_temporal=True)

        Returns:
            List of episode dictionaries (empty when sleep is disabled)

        Example:
            >>> # Sample for reverse temporal replay during sleep
            >>> batch = enhancer.sample_episodes(
            ...     batch_size=32,
            ...     reverse_temporal=True,
            ... )
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            return self.episodic_buffer.sample(batch_size, **kwargs)
        return []

    def get_phase_progress(self) -> float:
        """
        Get progress through the current sleep/wake phase.

        Returns:
            Progress as a fraction in [0, 1]
            (0.0 = phase start, 1.0 = phase end)
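
        Example (illustrative; ``replay_weight`` is a hypothetical consumer
        that anneals as a phase winds down):
            >>> progress = enhancer.get_phase_progress()
            >>> replay_weight = 1.0 - progress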
| """ |
| if self.enable_sleep and self.sleep_scheduler is not None: |
| return self.sleep_scheduler.get_phase_progress() |
| return 0.0 |

    def get_neuromodulator_states(self) -> tuple[float, float, float]:
        """
        Get current neuromodulator levels.

        Returns:
            Tuple of (ACh, NE, serotonin) levels; the neutral baseline
            (1.0, 0.0, 0.0) is returned when neuromodulation is disabled
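
        Example:
            >>> ach, ne, serotonin = enhancer.get_neuromodulator_states()
            >>> print(f"ACh={ach:.2f}, NE={ne:.2f}, 5-HT={serotonin:.2f}")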
| """ |
| if self.enable_neuromodulation: |
| return ( |
| self.ach.item(), |
| self.ne.item(), |
| self.serotonin.item() |
| ) |
| else: |
| return (1.0, 0.0, 0.0) |

    def compute_effective_learning_rate(self, base_lr: float) -> float:
        """
        Compute the effective learning rate modulated by neuromodulators.

        Formula (per spec):
            lr_effective = base_lr * ach * (1 + ne) * (1 - 0.5 * serotonin)

        ACh scales plasticity directly, NE amplifies it, and serotonin damps
        it toward stability.

        Args:
            base_lr: Base learning rate from the optimizer

        Returns:
            Effective learning rate after neuromodulator modulation
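
        Example (with the initial levels ACh=1.0, NE=0.5, 5-HT=0.3):
            >>> lr = enhancer.compute_effective_learning_rate(1e-3)
            >>> # 1e-3 * 1.0 * (1 + 0.5) * (1 - 0.5 * 0.3) ≈ 1.275e-3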
| """ |
| if not self.enable_neuromodulation: |
| return base_lr |
|
|
| |
| ach_clamped = torch.clamp(self.ach, 0.0, 2.0) |
| ne_clamped = torch.clamp(self.ne, 0.0, 2.0) |
| serotonin_clamped = torch.clamp(self.serotonin, 0.0, 2.0) |
|
|
| lr_effective = base_lr * ach_clamped * (1 + ne_clamped) * (1 - 0.5 * serotonin_clamped) |
|
|
| return cast(float, lr_effective.item()) |

    def compute_predictive_coding_error(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        return_magnitude: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute the predictive coding error for error-driven learning.

        Implements the Rao & Ballard (1999) predictive coding framework, in
        which prediction errors drive learning through hierarchical error
        propagation. Error neurons compute the difference between top-down
        predictions and bottom-up sensory input, and these errors update the
        representations at each level of the hierarchy.

        Mathematical formulation:
            error = target - prediction
            error_magnitude = ||error||_2  (L2 norm per sample)

        In predictive coding, weights are updated in proportion to the error
        magnitude:
            Δw ∝ error_magnitude * gradient

        This yields a Bayesian inference framework in which:
            - Higher-level predictions influence lower-level representations
            - Prediction errors are minimized through gradient descent
            - Hierarchical structure emerges naturally

        Args:
            prediction: Model's prediction (batch, manifold_dim); the
                top-down prediction from higher hierarchical levels
            target: Ground-truth target (batch, manifold_dim); the bottom-up
                sensory input or desired output from lower hierarchical
                levels
            return_magnitude: If True, also return the L2 norm of the error
                (useful for monitoring convergence)

        Returns:
            Tuple of:
                - error: Prediction error tensor (batch, manifold_dim); the
                  sign indicates the direction of the error (positive where
                  target > prediction, negative where target < prediction)
                - error_magnitude: L2 norm of the error per sample (batch,);
                  only returned if return_magnitude=True, else None

        Example:
            >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
            >>> prediction = torch.randn(32, 16384, device='cuda')
            >>> target = torch.randn(32, 16384, device='cuda')
            >>> error, magnitude = enhancer.compute_predictive_coding_error(
            ...     prediction, target, return_magnitude=True
            ... )
            >>> # Use error for weight updates: Δw ∝ error
            >>> # Monitor magnitude to verify error reduction over training

        Note:
            In the hierarchical predictive coding framework:
                - Level N+1 predicts activity at Level N
                - Error at Level N = actual(N) - predicted(N)
                - This error is used to:
                    1. Update Level N+1's predictions (top-down)
                    2. Update Level N's representations (bottom-up)
                - Iteratively minimizing prediction error across the
                  hierarchy implements Bayesian inference

        Reference:
            Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the
            visual cortex: a functional interpretation of some
            extra-classical receptive-field effects. Nature Neuroscience,
            2(1), 79-87.
        """
        # Error neurons: signed difference between the bottom-up target
        # and the top-down prediction
        error = target - prediction

        error_magnitude = None
        if return_magnitude:
            # Per-sample L2 norm over the manifold dimension
            error_magnitude = torch.norm(error, p=2, dim=1)

        return error, error_magnitude

    def apply_lateral_inhibition(self, activations: torch.Tensor) -> torch.Tensor:
        """
        Apply k-Winners-Take-All lateral inhibition (vectorized).

        Keeps only the top-k activations per sample and zeros out the rest
        to enforce sparsity, with k = max(1, int(manifold_dim * sparsity)).

        Args:
            activations: Input activations (batch, manifold_dim)

        Returns:
            Sparse activations with exactly k active neurons per sample
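
        Example (with the defaults manifold_dim=16384 and sparsity=0.05,
        so k = 819 winners survive per sample):
            >>> activations = torch.randn(4, 16384, device='cuda')
            >>> sparse = enhancer.apply_lateral_inhibition(activations)
            >>> # Each row of `sparse` keeps its 819 largest activations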
| """ |
| if not self.enable_lateral_inhibition: |
| return activations |
|
|
| |
| _, top_k_indices = torch.topk(activations, self.k, dim=-1) |
|
|
| |
| sparse_activations = torch.zeros_like(activations) |
| sparse_activations.scatter_(-1, top_k_indices, 1.0) |
|
|
| return sparse_activations |

    def forward(
        self,
        u: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through the neuroscience enhancement layer.

        Args:
            u: Input state from the Module C propagator (batch, manifold_dim)
            target: Optional target for predictive coding error computation

        Returns:
            Tuple of:
                - Enhanced state after the neuroscience mechanisms
                - Prediction error (if a target is provided and predictive
                  coding is enabled, else None)
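
        Example (assuming an ``enhancer`` constructed with the defaults):
            >>> u = torch.randn(32, 16384, device='cuda')
            >>> target = torch.randn(32, 16384, device='cuda')
            >>> enhanced_u, pred_error = enhancer(u, target)
            >>> enhanced_u.shape  # same shape as the input
            torch.Size([32, 16384])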
| """ |
| enhanced_u = u |
| prediction_error = None |
|
|
| |
| if self.enable_lateral_inhibition and not self.is_sleeping(): |
| enhanced_u = self.apply_lateral_inhibition(enhanced_u) |
|
|
| |
| if self.enable_predictive_coding and target is not None: |
| prediction_error, _ = self.compute_predictive_coding_error( |
| prediction=enhanced_u, |
| target=target, |
| return_magnitude=False |
| ) |
|
|
| |
| |
|
|
| return enhanced_u, prediction_error |
|
|