# sem-v6-training/src/sem_v6/modules/module_e.py
import torch
import torch.nn as nn
from typing import Optional, Any, cast
from ..utils.sleep import EpisodicBuffer, SleepReplayScheduler
class NeuroscienceEnhancer(nn.Module):
"""
Neuroscience-inspired enhancement layer for SEM V6.
Integrates five biological learning mechanisms:
1. **STDP (Spike-Timing Dependent Plasticity)**: Temporal causality learning
where synaptic strength changes based on relative timing of pre- and
post-synaptic spikes (Bi & Poo, 1998).
2. **Sleep Consolidation**: Offline memory replay during designated "sleep"
phases to strengthen important patterns (Wilson & McNaughton, 1994).
3. **Neuromodulator Dynamics**: Three neuromodulators (ACh, NE, 5-HT)
dynamically adjust learning rates based on context:
- ACh (Acetylcholine): Enhances plasticity during learning
- NE (Norepinephrine): Increases exploration during arousal
- 5-HT (Serotonin): Stabilizes weights during consolidation
(Sara, 2009; Hasselmo, 2006)
4. **Lateral Inhibition**: Winner-take-all competition where highly active
neurons suppress neighbors to enforce sparsity.
    5. **Predictive Coding**: Error-driven learning in which prediction errors
       propagate up the hierarchy to update representations (Rao & Ballard, 1999).
These mechanisms enhance Module C (ChebyKAN propagator) with biological
realism while maintaining the frozen architecture constraint.
References:
- Bi, G., & Poo, M. (1998). Synaptic modifications in cultured
hippocampal neurons: dependence on spike timing, synaptic strength,
and postsynaptic cell type. Journal of Neuroscience, 18(24), 10464-10472.
- Wilson, M. A., & McNaughton, B. L. (1994). Reactivation of hippocampal
ensemble memories during sleep. Science, 265(5172), 676-679.
- Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
cortex: a functional interpretation of some extra-classical
receptive-field effects. Nature Neuroscience, 2(1), 79-87.
- Sara, S. J. (2009). The locus coeruleus and noradrenergic modulation
of cognition. Nature Reviews Neuroscience, 10(3), 211-223.
- Hasselmo, M. E. (2006). The role of acetylcholine in learning and
memory. Current Opinion in Neurobiology, 16(6), 710-715.
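    Example (illustrative; assumes a CUDA device is available):
        >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, enable_sleep=False)
        >>> u = torch.randn(8, 16384, device='cuda')
        >>> sparse_u, _ = enhancer(u)
        >>> int(sparse_u[0].sum().item())  # exactly k = int(16384 * 0.05) = 819 winners
        819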
"""
def __init__(
self,
manifold_dim: int = 16384,
sparsity: float = 0.05,
device: str = "cuda",
enable_stdp: bool = True,
enable_sleep: bool = True,
enable_neuromodulation: bool = True,
enable_lateral_inhibition: bool = True,
enable_predictive_coding: bool = True,
awake_steps: int = 1000,
sleep_steps: int = 100,
sleep_replay_batch: int = 32,
buffer_max_size: int = 1000,
):
"""
Initialize NeuroscienceEnhancer.
Args:
manifold_dim: Dimensionality of the hypergraph manifold (default: 16384)
sparsity: Target sparsity for lateral inhibition (default: 0.05, i.e., 5%)
device: Device for tensor operations ('cuda' or 'cpu')
enable_stdp: Enable STDP learning mechanism
enable_sleep: Enable sleep consolidation system
enable_neuromodulation: Enable neuromodulator dynamics (ACh, NE, 5-HT)
enable_lateral_inhibition: Enable lateral inhibition (k-WTA)
enable_predictive_coding: Enable predictive coding error computation
awake_steps: Number of training steps per awake phase (default: 1000)
sleep_steps: Number of replay steps per sleep phase (default: 100)
sleep_replay_batch: Batch size for sleep replay (default: 32)
buffer_max_size: Maximum size of episodic replay buffer (default: 1000)
"""
super().__init__()
# GPU requirement check (per CLAUDE.md)
assert torch.cuda.is_available(), "GPU required for Module E (per CLAUDE.md)"
self.manifold_dim = manifold_dim
self.sparsity = sparsity
self.k = int(manifold_dim * sparsity) # Number of winners for k-WTA
self.device = torch.device(device)
# Feature flags
self.enable_stdp = enable_stdp
self.enable_sleep = enable_sleep
self.enable_neuromodulation = enable_neuromodulation
self.enable_lateral_inhibition = enable_lateral_inhibition
self.enable_predictive_coding = enable_predictive_coding
# Sleep/wake cycle scheduler (subtask-2-3)
self.sleep_scheduler: Optional[SleepReplayScheduler]
self.episodic_buffer: Optional[EpisodicBuffer]
if self.enable_sleep:
self.sleep_scheduler = SleepReplayScheduler(
awake_steps=awake_steps,
sleep_steps=sleep_steps,
sleep_replay_batch=sleep_replay_batch
)
# Episodic replay buffer (subtask-2-1)
self.episodic_buffer = EpisodicBuffer(
max_size=buffer_max_size,
device=str(self.device)
)
else:
self.sleep_scheduler = None
self.episodic_buffer = None
# Neuromodulator levels (learnable parameters)
# Initialized to biologically plausible baseline values
if self.enable_neuromodulation:
self.ach = nn.Parameter(torch.tensor(1.0, device=self.device)) # Acetylcholine
self.ne = nn.Parameter(torch.tensor(0.5, device=self.device)) # Norepinephrine
self.serotonin = nn.Parameter(torch.tensor(0.3, device=self.device)) # Serotonin
# Placeholder for future components (to be implemented in subsequent subtasks)
# - STDP learner (subtask-1-2) - to be integrated
# - Predictive coding error module (subtask-3-3) - to be implemented
def set_sleep_mode(self, is_sleeping: bool) -> None:
"""
Set sleep/wake mode for the enhancer.
Args:
is_sleeping: True for sleep mode (offline consolidation),
False for awake mode (online learning)
Note:
When using the sleep scheduler, prefer using step() for automatic
sleep/wake transitions instead of manually setting sleep mode.
"""
# Manual override (bypasses scheduler if enabled)
if self.enable_sleep and self.sleep_scheduler is not None:
# Synchronize scheduler state with manual override
self.sleep_scheduler.awake = not is_sleeping
def step(self) -> None:
"""
Advance one step in sleep/wake cycle.
Automatically transitions between awake and sleep modes based on
the configured scheduler. Should be called once per training step.
Example:
>>> enhancer = NeuroscienceEnhancer(enable_sleep=True)
>>> for step in range(10000):
... if enhancer.is_awake():
... # Online training
... loss = train_step(data)
... enhancer.add_episode(episode)
... else:
... # Sleep consolidation
... replay_batch = enhancer.sample_episodes(batch_size=32)
... consolidate(replay_batch)
... enhancer.step()
"""
if self.enable_sleep and self.sleep_scheduler is not None:
self.sleep_scheduler.step()
def is_awake(self) -> bool:
"""
Check if currently in awake (online learning) mode.
Returns:
True if awake, False if sleeping
"""
if self.enable_sleep and self.sleep_scheduler is not None:
return self.sleep_scheduler.is_awake()
return True # Default to awake if sleep disabled
def is_sleeping(self) -> bool:
"""
Check if currently in sleep (offline consolidation) mode.
Returns:
True if sleeping, False if awake
"""
if self.enable_sleep and self.sleep_scheduler is not None:
return self.sleep_scheduler.is_sleeping()
return False # Default to not sleeping if sleep disabled
def add_episode(self, episode: dict[str, Any]) -> None:
"""
Add episode to replay buffer during awake phase.
Args:
episode: Episode dictionary containing 'sdr', 'reward', 'timestamp'
Example:
>>> episode = {
... 'sdr': torch.randn(16384, device='cuda'),
... 'reward': 1.5,
... 'timestamp': 100.0
... }
>>> enhancer.add_episode(episode)
"""
if self.enable_sleep and self.episodic_buffer is not None:
self.episodic_buffer.add(episode)
def sample_episodes(self, batch_size: int, **kwargs: Any) -> list[dict[str, Any]]:
"""
Sample episodes from replay buffer for sleep consolidation.
Args:
batch_size: Number of episodes to sample
**kwargs: Additional arguments passed to buffer.sample()
(e.g., prioritize=True, reverse_temporal=True)
Returns:
List of episode dictionaries
Example:
>>> # Sample for reverse temporal replay during sleep
>>> batch = enhancer.sample_episodes(
... batch_size=32,
... reverse_temporal=True
... )
"""
if self.enable_sleep and self.episodic_buffer is not None:
return self.episodic_buffer.sample(batch_size, **kwargs)
return []
def get_phase_progress(self) -> float:
"""
Get progress through current sleep/wake phase.
Returns:
Progress as fraction [0, 1] (0.0 = phase start, 1.0 = phase end)
"""
if self.enable_sleep and self.sleep_scheduler is not None:
return self.sleep_scheduler.get_phase_progress()
return 0.0
def get_neuromodulator_states(self) -> tuple[float, float, float]:
"""
Get current neuromodulator levels.
Returns:
Tuple of (ACh, NE, Serotonin) levels
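        Example:
            >>> # With the baseline initialization (ACh=1.0, NE=0.5, 5-HT=0.3):
            >>> ach, ne, ser = enhancer.get_neuromodulator_states()
            >>> (ach, ne, round(ser, 2))
            (1.0, 0.5, 0.3)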
"""
if self.enable_neuromodulation:
return (
self.ach.item(),
self.ne.item(),
self.serotonin.item()
)
else:
            return (1.0, 0.0, 0.0)  # Neutral levels (net lr multiplier = 1.0) when disabled
def compute_effective_learning_rate(self, base_lr: float) -> float:
"""
Compute effective learning rate modulated by neuromodulators.
Formula (per spec):
lr_effective = base_lr * ach * (1 + ne) * (1 - 0.5 * serotonin)
Args:
base_lr: Base learning rate from optimizer
Returns:
Effective learning rate after neuromodulator modulation
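        Example:
            >>> # Worked example at the baseline levels (ACh=1.0, NE=0.5, 5-HT=0.3):
            >>> # 1e-3 * 1.0 * (1 + 0.5) * (1 - 0.5 * 0.3) = 1.275e-3
            >>> round(enhancer.compute_effective_learning_rate(1e-3), 6)
            0.001275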
"""
if not self.enable_neuromodulation:
return base_lr
# Clamp neuromodulators to safe range [0, 2] to prevent instability
ach_clamped = torch.clamp(self.ach, 0.0, 2.0)
ne_clamped = torch.clamp(self.ne, 0.0, 2.0)
serotonin_clamped = torch.clamp(self.serotonin, 0.0, 2.0)
lr_effective = base_lr * ach_clamped * (1 + ne_clamped) * (1 - 0.5 * serotonin_clamped)
return cast(float, lr_effective.item())
def compute_predictive_coding_error(
self,
prediction: torch.Tensor,
target: torch.Tensor,
return_magnitude: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Compute predictive coding error for error-driven learning.
Implements Rao & Ballard (1999) predictive coding framework where
prediction errors drive learning through hierarchical error propagation.
Error neurons compute the difference between top-down predictions and
bottom-up sensory input, and these errors are used to update
representations at each level of the hierarchy.
Mathematical formulation:
error = target - prediction
error_magnitude = ||error||_2 (L2 norm)
In predictive coding, weights are updated proportional to error magnitude:
Δw ∝ error_magnitude * gradient
This creates a Bayesian inference framework where:
- Higher-level predictions influence lower-level representations
- Prediction errors are minimized through gradient descent
- Hierarchical structure emerges naturally
Args:
prediction: Model's prediction (batch, manifold_dim)
This represents the top-down prediction from higher
hierarchical levels
target: Ground truth target (batch, manifold_dim)
This represents the bottom-up sensory input or
desired output from lower hierarchical levels
return_magnitude: If True, also return L2 norm of error
(useful for monitoring convergence)
Returns:
Tuple of:
- error: Prediction error tensor (batch, manifold_dim)
Sign indicates direction of error (target > pred: +, target < pred: -)
- error_magnitude: L2 norm of error per sample (batch,)
Only returned if return_magnitude=True, else None
Example:
>>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
>>> prediction = torch.randn(32, 16384, device='cuda')
>>> target = torch.randn(32, 16384, device='cuda')
>>> error, magnitude = enhancer.compute_predictive_coding_error(
... prediction, target, return_magnitude=True
... )
>>> # Use error for weight updates: Δw ∝ error
>>> # Monitor magnitude to verify error reduction over training
Note:
In the hierarchical predictive coding framework:
- Level N+1 predicts activity at Level N
- Error at Level N = actual(N) - predicted(N)
- This error is used to:
1. Update Level N+1's predictions (top-down)
2. Update Level N's representations (bottom-up)
- Iterative minimization of prediction error across hierarchy
implements Bayesian inference
Reference:
Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
cortex: a functional interpretation of some extra-classical
receptive-field effects. Nature Neuroscience, 2(1), 79-87.
"""
# Compute raw prediction error (target - prediction)
# This represents the surprise signal that drives learning
error = target - prediction
# Optionally compute error magnitude for monitoring convergence
error_magnitude = None
if return_magnitude:
# L2 norm per sample: ||error||_2
# Used to verify that error decreases over training iterations
# (acceptance criterion from spec)
            error_magnitude = torch.linalg.vector_norm(error, ord=2, dim=1)
return error, error_magnitude
def apply_lateral_inhibition(self, activations: torch.Tensor) -> torch.Tensor:
"""
Apply k-Winners-Take-All lateral inhibition (vectorized).
        Binarizes the top-k activations to 1.0 and zeros out the rest, yielding
        a sparse binary code (SDR) at the configured target sparsity.
Args:
activations: Input activations (batch, manifold_dim)
Returns:
            Binary activations with exactly k neurons set to 1.0 per sample
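        Example:
            >>> # Illustrative: manifold_dim=10, sparsity=0.2 gives k=2 winners
            >>> enhancer = NeuroscienceEnhancer(manifold_dim=10, sparsity=0.2)
            >>> acts = torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2,
            ...                       0.0, 0.5, 0.4, 0.6, 0.7]], device='cuda')
            >>> enhancer.apply_lateral_inhibition(acts)
            tensor([[0., 1., 0., 1., 0., 0., 0., 0., 0., 0.]], device='cuda:0')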
"""
if not self.enable_lateral_inhibition:
return activations
# Find top-k indices for each sample in batch
_, top_k_indices = torch.topk(activations, self.k, dim=-1)
# Vectorized: set top-k positions to 1 (no Python loop)
sparse_activations = torch.zeros_like(activations)
sparse_activations.scatter_(-1, top_k_indices, 1.0)
return sparse_activations
def forward(
self,
u: torch.Tensor,
target: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass through neuroscience enhancement layer.
Args:
u: Input state from Module C propagator (batch, manifold_dim)
target: Optional target for predictive coding error computation
Returns:
Tuple of:
- Enhanced state after neuroscience mechanisms
- Prediction error (if target provided and predictive coding enabled)
"""
enhanced_u = u
prediction_error = None
# Apply lateral inhibition (if enabled and awake)
if self.enable_lateral_inhibition and not self.is_sleeping():
enhanced_u = self.apply_lateral_inhibition(enhanced_u)
# Compute predictive coding error (if enabled and target provided)
if self.enable_predictive_coding and target is not None:
prediction_error, _ = self.compute_predictive_coding_error(
prediction=enhanced_u,
target=target,
return_magnitude=False
)
# Note: STDP updates and sleep consolidation are handled externally
# via callbacks and training loop orchestration (to be implemented)
return enhanced_u, prediction_error
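
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the frozen SEM V6 API): the STDP learner
# referenced in the placeholders above (subtask-1-2) is not yet implemented.
# The helper below shows the standard pairwise exponential STDP window from
# Bi & Poo (1998); the parameter names `a_plus`, `a_minus`, `tau_plus`, and
# `tau_minus` are hypothetical defaults chosen for illustration only.
def stdp_weight_delta(
    delta_t: torch.Tensor,
    a_plus: float = 0.01,
    a_minus: float = 0.012,
    tau_plus: float = 20.0,
    tau_minus: float = 20.0,
) -> torch.Tensor:
    """Pairwise STDP update for spike-time differences delta_t = t_post - t_pre (ms).

    Potentiation when the presynaptic spike precedes the postsynaptic spike
    (delta_t > 0); exponentially decaying depression otherwise.
    """
    potentiation = a_plus * torch.exp(-delta_t / tau_plus)
    depression = -a_minus * torch.exp(delta_t / tau_minus)
    return torch.where(delta_t > 0, potentiation, depression)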