import torch
import torch.nn as nn
from typing import Optional, Any, cast
from ..utils.sleep import EpisodicBuffer, SleepReplayScheduler

class NeuroscienceEnhancer(nn.Module):
    """
    Neuroscience-inspired enhancement layer for SEM V6.

    Integrates five biological learning mechanisms:

    1. **STDP (Spike-Timing Dependent Plasticity)**: Temporal causality learning
       where synaptic strength changes based on relative timing of pre- and
       post-synaptic spikes (Bi & Poo, 1998).

    2. **Sleep Consolidation**: Offline memory replay during designated "sleep"
       phases to strengthen important patterns (Wilson & McNaughton, 1994).

    3. **Neuromodulator Dynamics**: Three neuromodulators (ACh, NE, 5-HT)
       dynamically adjust learning rates based on context:
       - ACh (Acetylcholine): Enhances plasticity during learning
       - NE (Norepinephrine): Increases exploration during arousal
       - 5-HT (Serotonin): Stabilizes weights during consolidation
       (Sara, 2009; Hasselmo, 2006)

    4. **Lateral Inhibition**: Winner-take-all competition where highly active
       neurons suppress neighbors to enforce sparsity.

    5. **Predictive Coding**: Error-driven learning where prediction errors
       propagate backward to update representations (Rao & Ballard, 1999).

    These mechanisms enhance Module C (ChebyKAN propagator) with biological
    realism while maintaining the frozen architecture constraint.

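    Example:
        A minimal usage sketch (illustrative; assumes a CUDA device is
        available, per the GPU requirement below):

        >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
        >>> u = torch.randn(4, 16384, device='cuda')
        >>> enhanced, _ = enhancer(u)
        >>> enhancer.step()  # advance the sleep/wake scheduler once per training step
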
    References:
        - Bi, G., & Poo, M. (1998). Synaptic modifications in cultured
          hippocampal neurons: dependence on spike timing, synaptic strength,
          and postsynaptic cell type. Journal of Neuroscience, 18(24), 10464-10472.
        - Wilson, M. A., & McNaughton, B. L. (1994). Reactivation of hippocampal
          ensemble memories during sleep. Science, 265(5172), 676-679.
        - Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
          cortex: a functional interpretation of some extra-classical
          receptive-field effects. Nature Neuroscience, 2(1), 79-87.
        - Sara, S. J. (2009). The locus coeruleus and noradrenergic modulation
          of cognition. Nature Reviews Neuroscience, 10(3), 211-223.
        - Hasselmo, M. E. (2006). The role of acetylcholine in learning and
          memory. Current Opinion in Neurobiology, 16(6), 710-715.
    """

    def __init__(
        self,
        manifold_dim: int = 16384,
        sparsity: float = 0.05,
        device: str = "cuda",
        enable_stdp: bool = True,
        enable_sleep: bool = True,
        enable_neuromodulation: bool = True,
        enable_lateral_inhibition: bool = True,
        enable_predictive_coding: bool = True,
        awake_steps: int = 1000,
        sleep_steps: int = 100,
        sleep_replay_batch: int = 32,
        buffer_max_size: int = 1000,
    ):
        """
        Initialize NeuroscienceEnhancer.

        Args:
            manifold_dim: Dimensionality of the hypergraph manifold (default: 16384)
            sparsity: Target sparsity for lateral inhibition (default: 0.05, i.e., 5%)
            device: Device for tensor operations ('cuda' or 'cpu')
            enable_stdp: Enable STDP learning mechanism
            enable_sleep: Enable sleep consolidation system
            enable_neuromodulation: Enable neuromodulator dynamics (ACh, NE, 5-HT)
            enable_lateral_inhibition: Enable lateral inhibition (k-WTA)
            enable_predictive_coding: Enable predictive coding error computation
            awake_steps: Number of training steps per awake phase (default: 1000)
            sleep_steps: Number of replay steps per sleep phase (default: 100)
            sleep_replay_batch: Batch size for sleep replay (default: 32)
            buffer_max_size: Maximum size of episodic replay buffer (default: 1000)
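
        Example:
            An illustrative configuration with a shorter sleep/wake cycle
            than the defaults:

            >>> enhancer = NeuroscienceEnhancer(
            ...     manifold_dim=16384,
            ...     awake_steps=500,
            ...     sleep_steps=50,
            ... )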
        """
        super().__init__()

        # GPU requirement check (per CLAUDE.md)
        assert torch.cuda.is_available(), "GPU required for Module E (per CLAUDE.md)"

        self.manifold_dim = manifold_dim
        self.sparsity = sparsity
        self.k = int(manifold_dim * sparsity)  # Number of winners for k-WTA
        self.device = torch.device(device)

        # Feature flags
        self.enable_stdp = enable_stdp
        self.enable_sleep = enable_sleep
        self.enable_neuromodulation = enable_neuromodulation
        self.enable_lateral_inhibition = enable_lateral_inhibition
        self.enable_predictive_coding = enable_predictive_coding

        # Sleep/wake cycle scheduler (subtask-2-3)
        self.sleep_scheduler: Optional[SleepReplayScheduler]
        self.episodic_buffer: Optional[EpisodicBuffer]

        if self.enable_sleep:
            self.sleep_scheduler = SleepReplayScheduler(
                awake_steps=awake_steps,
                sleep_steps=sleep_steps,
                sleep_replay_batch=sleep_replay_batch
            )
            # Episodic replay buffer (subtask-2-1)
            self.episodic_buffer = EpisodicBuffer(
                max_size=buffer_max_size,
                device=str(self.device)
            )
        else:
            self.sleep_scheduler = None
            self.episodic_buffer = None

        # Neuromodulator levels (learnable parameters)
        # Initialized to biologically plausible baseline values
        if self.enable_neuromodulation:
            self.ach = nn.Parameter(torch.tensor(1.0, device=self.device))  # Acetylcholine
            self.ne = nn.Parameter(torch.tensor(0.5, device=self.device))   # Norepinephrine
            self.serotonin = nn.Parameter(torch.tensor(0.3, device=self.device))  # Serotonin

        # Placeholder for future components (to be implemented in subsequent subtasks)
        # - STDP learner (subtask-1-2) - to be integrated
        # - Predictive coding error module (subtask-3-3) - to be implemented

    def set_sleep_mode(self, is_sleeping: bool) -> None:
        """
        Set sleep/wake mode for the enhancer.

        Args:
            is_sleeping: True for sleep mode (offline consolidation),
                        False for awake mode (online learning)

        Note:
            When using the sleep scheduler, prefer using step() for automatic
            sleep/wake transitions instead of manually setting sleep mode.
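
        Example:
            Illustrative manual override:

            >>> enhancer.set_sleep_mode(True)   # force offline consolidation
            >>> enhancer.set_sleep_mode(False)  # return to online learning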
        """
        # Manual override (bypasses scheduler if enabled)
        if self.enable_sleep and self.sleep_scheduler is not None:
            # Synchronize scheduler state with manual override
            self.sleep_scheduler.awake = not is_sleeping

    def step(self) -> None:
        """
        Advance one step in sleep/wake cycle.

        Automatically transitions between awake and sleep modes based on
        the configured scheduler. Should be called once per training step.

        Example:
            >>> enhancer = NeuroscienceEnhancer(enable_sleep=True)
            >>> for step in range(10000):
            ...     if enhancer.is_awake():
            ...         # Online training
            ...         loss = train_step(data)
            ...         enhancer.add_episode(episode)
            ...     else:
            ...         # Sleep consolidation
            ...         replay_batch = enhancer.sample_episodes(batch_size=32)
            ...         consolidate(replay_batch)
            ...     enhancer.step()
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            self.sleep_scheduler.step()

    def is_awake(self) -> bool:
        """
        Check if currently in awake (online learning) mode.

        Returns:
            True if awake, False if sleeping
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_awake()
        return True  # Default to awake if sleep disabled

    def is_sleeping(self) -> bool:
        """
        Check if currently in sleep (offline consolidation) mode.

        Returns:
            True if sleeping, False if awake
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.is_sleeping()
        return False  # Default to not sleeping if sleep disabled

    def add_episode(self, episode: dict[str, Any]) -> None:
        """
        Add episode to replay buffer during awake phase.

        Args:
            episode: Episode dictionary containing 'sdr', 'reward', 'timestamp'

        Example:
            >>> episode = {
            ...     'sdr': torch.randn(16384, device='cuda'),
            ...     'reward': 1.5,
            ...     'timestamp': 100.0
            ... }
            >>> enhancer.add_episode(episode)
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            self.episodic_buffer.add(episode)

    def sample_episodes(self, batch_size: int, **kwargs: Any) -> list[dict[str, Any]]:
        """
        Sample episodes from replay buffer for sleep consolidation.

        Args:
            batch_size: Number of episodes to sample
            **kwargs: Additional arguments passed to buffer.sample()
                     (e.g., prioritize=True, reverse_temporal=True)

        Returns:
            List of episode dictionaries

        Example:
            >>> # Sample for reverse temporal replay during sleep
            >>> batch = enhancer.sample_episodes(
            ...     batch_size=32,
            ...     reverse_temporal=True
            ... )
        """
        if self.enable_sleep and self.episodic_buffer is not None:
            return self.episodic_buffer.sample(batch_size, **kwargs)
        return []

    def get_phase_progress(self) -> float:
        """
        Get progress through current sleep/wake phase.

        Returns:
            Progress as fraction [0, 1] (0.0 = phase start, 1.0 = phase end)
        """
        if self.enable_sleep and self.sleep_scheduler is not None:
            return self.sleep_scheduler.get_phase_progress()
        return 0.0

    def get_neuromodulator_states(self) -> tuple[float, float, float]:
        """
        Get current neuromodulator levels.

        Returns:
            Tuple of (ACh, NE, Serotonin) levels
        """
        if self.enable_neuromodulation:
            return (
                self.ach.item(),
                self.ne.item(),
                self.serotonin.item()
            )
        else:
            return (1.0, 0.0, 0.0)  # Identity values: leave the effective learning rate unmodulated

    def compute_effective_learning_rate(self, base_lr: float) -> float:
        """
        Compute effective learning rate modulated by neuromodulators.

        Formula (per spec):
        lr_effective = base_lr * ach * (1 + ne) * (1 - 0.5 * serotonin)

        Args:
            base_lr: Base learning rate from optimizer

        Returns:
            Effective learning rate after neuromodulator modulation
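
        Example:
            A worked example with illustrative values, using the baseline
            initializations (ACh=1.0, NE=0.5, 5-HT=0.3):

            >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384)
            >>> # 1e-3 * 1.0 * (1 + 0.5) * (1 - 0.5 * 0.3) = 1.275e-3
            >>> lr = enhancer.compute_effective_learning_rate(1e-3)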
        """
        if not self.enable_neuromodulation:
            return base_lr

        # Clamp neuromodulators to safe range [0, 2] to prevent instability
        ach_clamped = torch.clamp(self.ach, 0.0, 2.0)
        ne_clamped = torch.clamp(self.ne, 0.0, 2.0)
        serotonin_clamped = torch.clamp(self.serotonin, 0.0, 2.0)

        lr_effective = base_lr * ach_clamped * (1 + ne_clamped) * (1 - 0.5 * serotonin_clamped)

        return cast(float, lr_effective.item())

    def compute_predictive_coding_error(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        return_magnitude: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute predictive coding error for error-driven learning.

        Implements Rao & Ballard (1999) predictive coding framework where
        prediction errors drive learning through hierarchical error propagation.
        Error neurons compute the difference between top-down predictions and
        bottom-up sensory input, and these errors are used to update
        representations at each level of the hierarchy.

        Mathematical formulation:
            error = target - prediction
            error_magnitude = ||error||_2 (L2 norm)

        In predictive coding, weights are updated proportional to error magnitude:
            Δw ∝ error_magnitude * gradient

        This creates a Bayesian inference framework where:
        - Higher-level predictions influence lower-level representations
        - Prediction errors are minimized through gradient descent
        - Hierarchical structure emerges naturally

        Args:
            prediction: Model's prediction (batch, manifold_dim)
                       This represents the top-down prediction from higher
                       hierarchical levels
            target: Ground truth target (batch, manifold_dim)
                   This represents the bottom-up sensory input or
                   desired output from lower hierarchical levels
            return_magnitude: If True, also return L2 norm of error
                             (useful for monitoring convergence)

        Returns:
            Tuple of:
                - error: Prediction error tensor (batch, manifold_dim)
                         Sign indicates direction of error (target > pred: +, target < pred: -)
                - error_magnitude: L2 norm of error per sample (batch,)
                                  Only returned if return_magnitude=True, else None

        Example:
            >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, device='cuda')
            >>> prediction = torch.randn(32, 16384, device='cuda')
            >>> target = torch.randn(32, 16384, device='cuda')
            >>> error, magnitude = enhancer.compute_predictive_coding_error(
            ...     prediction, target, return_magnitude=True
            ... )
            >>> # Use error for weight updates: Δw ∝ error
            >>> # Monitor magnitude to verify error reduction over training

        Note:
            In the hierarchical predictive coding framework:
            - Level N+1 predicts activity at Level N
            - Error at Level N = actual(N) - predicted(N)
            - This error is used to:
              1. Update Level N+1's predictions (top-down)
              2. Update Level N's representations (bottom-up)
            - Iterative minimization of prediction error across hierarchy
              implements Bayesian inference

        Reference:
            Rao, R. P., & Ballard, D. H. (1999). Predictive coding in the visual
            cortex: a functional interpretation of some extra-classical
            receptive-field effects. Nature Neuroscience, 2(1), 79-87.
        """
        # Compute raw prediction error (target - prediction)
        # This represents the surprise signal that drives learning
        error = target - prediction

        # Optionally compute error magnitude for monitoring convergence
        error_magnitude = None
        if return_magnitude:
            # L2 norm per sample: ||error||_2
            # Used to verify that error decreases over training iterations
            # (acceptance criterion from spec)
            error_magnitude = torch.norm(error, p=2, dim=1)

        return error, error_magnitude

    def apply_lateral_inhibition(self, activations: torch.Tensor) -> torch.Tensor:
        """
        Apply k-Winners-Take-All lateral inhibition (vectorized).

        Binarizes the input into a sparse code: the k most active positions per
        sample are set to 1 and all others to 0, enforcing the target sparsity.

        Args:
            activations: Input activations (batch, manifold_dim)

        Returns:
            Binary sparse activations (values 0/1) with exactly k active neurons per sample
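
        Example:
            Illustrative check of the sparsity invariant (with
            manifold_dim=16384 and sparsity=0.05, k = 819):

            >>> enhancer = NeuroscienceEnhancer(manifold_dim=16384, sparsity=0.05)
            >>> x = torch.randn(8, 16384, device='cuda')
            >>> sdr = enhancer.apply_lateral_inhibition(x)
            >>> assert sdr.sum(dim=-1).eq(enhancer.k).all()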
        """
        if not self.enable_lateral_inhibition:
            return activations

        # Find top-k indices for each sample in batch
        _, top_k_indices = torch.topk(activations, self.k, dim=-1)

        # Vectorized: set top-k positions to 1 (no Python loop)
        sparse_activations = torch.zeros_like(activations)
        sparse_activations.scatter_(-1, top_k_indices, 1.0)

        return sparse_activations

    def forward(
        self,
        u: torch.Tensor,
        target: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through neuroscience enhancement layer.

        Args:
            u: Input state from Module C propagator (batch, manifold_dim)
            target: Optional target for predictive coding error computation

        Returns:
            Tuple of:
                - Enhanced state after neuroscience mechanisms
                - Prediction error (if target provided and predictive coding enabled)
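
        Example:
            Illustrative call (the target is optional):

            >>> u = torch.randn(4, 16384, device='cuda')
            >>> target = torch.randn(4, 16384, device='cuda')
            >>> enhanced, error = enhancer(u, target)
            >>> enhanced_only, _ = enhancer(u)  # error is None without a target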
        """
        enhanced_u = u
        prediction_error = None

        # Apply lateral inhibition (if enabled and awake)
        if self.enable_lateral_inhibition and not self.is_sleeping():
            enhanced_u = self.apply_lateral_inhibition(enhanced_u)

        # Compute predictive coding error (if enabled and target provided)
        if self.enable_predictive_coding and target is not None:
            prediction_error, _ = self.compute_predictive_coding_error(
                prediction=enhanced_u,
                target=target,
                return_magnitude=False
            )

        # Note: STDP updates and sleep consolidation are handled externally
        # via callbacks and training loop orchestration (to be implemented)

        return enhanced_u, prediction_error