| import torch |
| import torch.nn as nn |
| import sys |
| import os |
| import logging |
| from typing import Optional, Callable |
| from torchdiffeq import odeint, odeint_adjoint |
|
|
| |
# Make the local ChebyKan_cuda_op package importable.
sys.path.insert(
    0, os.path.join(os.path.dirname(__file__), "../../..", "ChebyKan_cuda_op")
)

try:
    from cuChebyKan.layer import ChebyKANLayer as OfficialChebyKANLayer

    _USE_CUDA_KAN = True
except ImportError:
    OfficialChebyKANLayer = None
    _USE_CUDA_KAN = False
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class ChebyKANLayer(nn.Module): |
| """ |
| Wrapper for official ChebyKAN implementation. |
| |
| Uses Da1sypetals/ChebyKan-cuda-op (CUDA-accelerated) |
| Reference: https://github.com/Da1sypetals/ChebyKan-cuda-op |
| Original: https://github.com/SynodicMonth/ChebyKAN |
| |
| KAN Paper: Liu et al. (2024) - Kolmogorov-Arnold Networks |
| """ |
|
|
| def __init__( |
| self, |
| in_features: int, |
| out_features: int, |
| degree: int = 5, |
| use_tanh: bool = True, |
| ): |
| super().__init__() |
|
|
| |
| assert _USE_CUDA_KAN, ( |
| "CUDA ChebyKAN is required. Install cuChebyKan from ChebyKan_cuda_op/" |
| ) |
| self.kan = OfficialChebyKANLayer( |
| input_dim=in_features, output_dim=out_features, degree=degree |
| ) |
|
|
        # Guard cheby_coeffs gradients so a stray NaN/Inf cannot poison the optimizer.
        if hasattr(self.kan, "cheby_coeffs"):
|
|
| def _sanitize_grad(grad: torch.Tensor) -> torch.Tensor: |
| if torch.isfinite(grad).all(): |
| return grad |
| return torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0) |
|
|
| self.kan.cheby_coeffs.register_hook(_sanitize_grad) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Forward pass through ChebyKAN layer. |
| |
| Args: |
| x: Input (batch, in_features) |
| |
| Returns: |
| y: Output (batch, out_features) |
| """ |
| return self.kan(x) |
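
# Usage sketch for the wrapper (illustrative, not part of the module API;
# assumes the cuChebyKan CUDA op is installed and a GPU is available, since
# there is no CPU fallback):
#
#     layer = ChebyKANLayer(in_features=256, out_features=256, degree=5).cuda()
#     y = layer(torch.randn(8, 256, device="cuda"))   # -> (8, 256)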
|
|
|
|
| class SolitonPropagator(nn.Module): |
| """ |
| Neural ODE-based soliton wave propagator. |
| |
| Models inference as continuous dynamics: |
| ∂U/∂t = -LU + ChebyKAN(U) |
| |
| Where: |
| - L: Hypergraph Laplacian (diffusion) |
| - ChebyKAN(U): Learnable reaction term (soliton non-linearity) |
| |
| Latent Space Operation: |
| For memory-efficient training, this propagator can operate in a low-dimensional |
| latent space instead of the full SDR dimension. Based on: |
| |
| - Rubanova et al. (2019): "Latent ODEs for Irregularly-Sampled Time Series" |
| https://arxiv.org/abs/1907.03907 |
| |
    Memory savings (O(d²) vs O(N²)):
    - Full SDR (2048): 2048² × (degree+1) ≈ 25M params at degree 5, ~4GB VRAM
    - Latent (256): 256² × (degree+1) ≈ 393K params at degree 5, <1GB VRAM
    - Reduction: ~64x fewer parameters
| |
| Use the `for_latent_space()` factory method for latent space operation. |
| |
| Reference: DNLSE (Discrete Nonlinear Schrödinger Equation) |
| """ |
|
|
| def __init__( |
| self, |
| manifold_dim: int = 16384, |
| kan_degree: int = 5, |
| solver: str = "rk4", |
| rtol: float = 1e-3, |
| atol: float = 1e-4, |
| step_size: float = 0.01, |
| ): |
| """ |
| Initialize the soliton propagator. |
| |
| Args: |
| manifold_dim: Dimension of the state space to propagate. For latent space |
| operation, pass latent_dim (e.g., 256) instead of sdr_dim (e.g., 2048). |
                The Johnson-Lindenstrauss lemma suggests latent_dim >= 4*ln(N)/ε²
                (N = number of points, not the ambient dimension) for preserving
                pairwise distances within (1±ε); 256 is the recommended default
                for 2048-dim SDRs rather than a tight JL bound.
| kan_degree: Degree of Chebyshev polynomials in ChebyKAN (default 5). |
| Higher degrees capture more complex dynamics but increase compute. |
| solver: ODE solver method. Options: |
| - 'rk4': Fixed-step Runge-Kutta 4 (default, stable for training) |
| - 'euler': Simple Euler (fast but less accurate) |
| - 'dopri5': Adaptive Dormand-Prince 5(4) (can cause dt underflow) |
                - 'implicit_adams': Implicit Adams-Moulton (better suited to stiff systems)
| rtol: Relative tolerance for adaptive solvers (default 1e-3). |
| atol: Absolute tolerance for adaptive solvers (default 1e-4). |
| step_size: Step size for fixed-step solvers like rk4 (default 0.01). |
| |
| Memory Scaling: |
            ChebyKAN parameters scale as O(manifold_dim² × (kan_degree + 1)).
            - manifold_dim=2048, degree=5: ~25M params
            - manifold_dim=256, degree=5: ~393K params
| |
| Example: |
| >>> # Full SDR operation (high memory) |
| >>> propagator = SolitonPropagator(manifold_dim=2048) |
| >>> |
| >>> # Latent space operation (memory-efficient) |
| >>> propagator = SolitonPropagator(manifold_dim=256) |
| >>> # Or use the factory method: |
| >>> propagator = SolitonPropagator.for_latent_space(latent_dim=256) |
| |
| References: |
| - Rubanova et al. (2019): Latent ODEs for Irregularly-Sampled Time Series |
| - Liu et al. (2024): Kolmogorov-Arnold Networks (KAN paper) |
| - torchdiffeq FAQ: https://github.com/rtqichen/torchdiffeq/blob/master/FAQ.md |
| """ |
| super().__init__() |
| self.manifold_dim = manifold_dim |
| self.solver = solver |
| self.rtol = rtol |
| self.atol = atol |
| self.step_size = step_size |
|
|
| self.kan_reaction = ChebyKANLayer( |
| in_features=manifold_dim, out_features=manifold_dim, degree=kan_degree |
| ) |
|
|
| |
| |
        # Laplacian is injected later via set_laplacian(); until then use a uniform leak.
        self.register_buffer("laplacian", None)
        self._laplacian_scale = 0.5
|
|
| |
| |
        # Optional Module E integration hooks (sleep mode, STDP, neuromodulators).
        self._sleep_mode: bool = False
| self._stdp_callback: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = ( |
| None |
| ) |
| self._neuromodulator_state: Optional[dict[str, float]] = None |
|
|
| @classmethod |
| def for_latent_space( |
| cls, |
| latent_dim: int = 256, |
| kan_degree: int = 5, |
| solver: str = "rk4", |
| rtol: float = 1e-3, |
| atol: float = 1e-4, |
| step_size: float = 0.01, |
| ) -> "SolitonPropagator": |
| """ |
| Create propagator optimized for latent space dynamics. |
| |
| This factory method creates a memory-efficient propagator that operates |
| in a low-dimensional latent space rather than the full SDR dimension. |
| The latent space approach is based on Latent ODEs (Rubanova et al. 2019), |
| which demonstrate that ODE dynamics can be learned effectively in |
| compressed representations. |
| |
| Memory savings vs full SDR: |
        - SDR (2048): 2048² × (degree+1) ≈ 25M params at degree 5, ~4GB VRAM
        - Latent (256): 256² × (degree+1) ≈ 393K params at degree 5, <1GB VRAM
| - Reduction: ~64x fewer parameters |
| |
        Mathematical justification:
            The Johnson-Lindenstrauss lemma guarantees that N points in a
            high-dimensional space can be projected to d >= 4*ln(N)/ε²
            dimensions while preserving pairwise distances within (1±ε).
            The bound depends on the number of points N, not on the ambient
            dimension, so 256 is a practical default motivated by this style
            of argument rather than a value derived from the 2048-dim SDR.
| |
            Additionally, if the encoder/decoder pair is (approximately) a
            diffeomorphism onto its image, smooth dynamics in the original
            space correspond to smooth dynamics in the latent space, which is
            what the latent ODE formulation relies on.
| |
| Args: |
            latent_dim: Dimension of latent space (default 256, recommended for 2048-dim SDRs).
| Common choices: |
| - 128: Aggressive compression, may lose fine structure |
| - 256: Recommended balance (default) |
| - 512: Higher fidelity, 4x more params than 256 |
| kan_degree: Chebyshev polynomial degree for ChebyKAN (default 5). |
| solver: ODE solver method (default 'rk4' for stability). |
| rtol: Relative tolerance for adaptive solvers. |
| atol: Absolute tolerance for adaptive solvers. |
| step_size: Step size for fixed-step solvers. |
| |
| Returns: |
| SolitonPropagator configured for latent space operation. |
| |
| Example: |
| >>> # Create latent space propagator |
| >>> propagator = SolitonPropagator.for_latent_space(latent_dim=256) |
| >>> propagator = propagator.cuda() |
| >>> |
| >>> # Forward pass with latent vectors |
| >>> z = torch.randn(batch_size, 256, device='cuda') |
| >>> z_next = propagator(z) # z_next is (batch_size, 256) |
| >>> |
| >>> # Use with VAE encoder/decoder for full pipeline: |
| >>> # x (2048) -> encoder -> z (256) -> propagator -> z' (256) -> decoder -> x' (2048) |
| |
| References: |
| - Rubanova et al. (2019): Latent ODEs for Irregularly-Sampled Time Series |
| https://arxiv.org/abs/1907.03907 |
| - Johnson & Lindenstrauss (1984): Extensions of Lipschitz mappings |
| """ |
| return cls( |
| manifold_dim=latent_dim, |
| kan_degree=kan_degree, |
| solver=solver, |
| rtol=rtol, |
| atol=atol, |
| step_size=step_size, |
| ) |
|
|
| def set_laplacian(self, L: torch.Tensor) -> None: |
| """Update Laplacian from hypergraph structure.""" |
| self.laplacian = L |
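
    # Example construction of a normalized hypergraph Laplacian in the style of
    # Zhou et al. (2006) (a sketch; `H` is a hypothetical dense incidence matrix
    # of shape (num_nodes, num_edges) with unit edge weights):
    #
    #     dv = H.sum(dim=1).clamp_min(1e-12)        # vertex degrees
    #     de = H.sum(dim=0).clamp_min(1e-12)        # hyperedge degrees
    #     Dv_inv_sqrt = torch.diag(dv.pow(-0.5))
    #     De_inv = torch.diag(de.reciprocal())
    #     L = torch.eye(H.shape[0]) - Dv_inv_sqrt @ H @ De_inv @ H.T @ Dv_inv_sqrt
    #     propagator.set_laplacian(L)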
|
|
| |
|
|
| def set_sleep_mode(self, is_sleeping: bool) -> None: |
| """ |
| Set sleep/wake mode for neuroscience enhancements (Module E integration). |
| |
| During sleep mode, Module E can perform offline memory consolidation |
| with episodic replay and STDP-driven weight updates. |
| |
| Args: |
| is_sleeping: True for sleep mode (offline consolidation), |
| False for awake mode (online learning) |
| |
| Note: |
| This is an optional integration hook that doesn't affect Module C's |
| core functionality unless Module E is explicitly integrated. |
| """ |
| self._sleep_mode = is_sleeping |
|
|
| def is_sleeping(self) -> bool: |
| """ |
| Check if currently in sleep mode (Module E integration). |
| |
| Returns: |
| True if in sleep mode, False if in awake mode |
| """ |
| return self._sleep_mode |
|
|
| def register_stdp_callback( |
| self, callback: Optional[Callable[[torch.Tensor, torch.Tensor], None]] |
| ) -> None: |
| """ |
| Register STDP callback for Module E integration. |
| |
| The callback will be invoked after each forward pass with (u_initial, u_final) |
| to enable spike-timing dependent plasticity learning. Module E can use this |
| to apply STDP weight updates based on temporal causality. |
| |
| Args: |
| callback: Function with signature (u_initial, u_final) -> None, |
| or None to unregister |
| |
| Example: |
| >>> def stdp_update(u_initial, u_final): |
| ... # Apply STDP based on temporal evolution |
| ... delta_w = compute_stdp(u_initial, u_final) |
| ... # Update weights... |
| >>> propagator.register_stdp_callback(stdp_update) |
| |
| Note: |
| This is an optional integration hook that doesn't affect Module C's |
| core functionality unless Module E is explicitly integrated. |
| """ |
| self._stdp_callback = callback |
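
    # Concrete callback sketch (the asymmetric outer-product rule below is just
    # one simple STDP-like choice for unbatched states, not something this
    # module prescribes; `weight_matrix` is a hypothetical external store):
    #
    #     def stdp_update(u_initial, u_final):
    #         delta_w = torch.outer(u_final, u_initial) - torch.outer(u_initial, u_final)
    #         weight_matrix.add_(1e-3 * delta_w)
    #
    #     propagator.register_stdp_callback(stdp_update)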
|
|
| def set_neuromodulator_state(self, ach: float, ne: float, serotonin: float) -> None: |
| """ |
| Inject neuromodulator state from Module E. |
| |
| Neuromodulators (ACh, NE, Serotonin) can influence learning dynamics |
| by modulating plasticity, exploration, and consolidation. |
| |
| Args: |
| ach: Acetylcholine level [0, 2] - enhances plasticity during learning |
| ne: Norepinephrine level [0, 2] - increases exploration during arousal |
| serotonin: Serotonin level [0, 2] - stabilizes weights during consolidation |
| |
| Note: |
| This is an optional integration hook that doesn't affect Module C's |
| core functionality unless Module E is explicitly integrated. |
| |
| References: |
| - Sara, S. J. (2009). The locus coeruleus and noradrenergic modulation |
| of cognition. Nature Reviews Neuroscience, 10(3), 211-223. |
| - Hasselmo, M. E. (2006). The role of acetylcholine in learning and |
| memory. Current Opinion in Neurobiology, 16(6), 710-715. |
| """ |
| self._neuromodulator_state = {"ach": ach, "ne": ne, "serotonin": serotonin} |
|
|
| def get_neuromodulator_state(self) -> Optional[dict[str, float]]: |
| """ |
| Get current neuromodulator state (Module E integration). |
| |
| Returns: |
| Dictionary with keys 'ach', 'ne', 'serotonin', or None if not set |
| """ |
| return self._neuromodulator_state |
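
    # Example consumer (hypothetical; one way Module E might map the state onto
    # a plasticity gain, not implemented here):
    #
    #     state = propagator.get_neuromodulator_state() or {"ach": 1.0, "ne": 1.0, "serotonin": 1.0}
    #     plasticity_gain = 0.5 + 0.5 * state["ach"]   # ACh in [0, 2] -> gain in [0.5, 1.5]
    #     effective_lr = base_lr * plasticity_gain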
|
|
    def ode_func(self, t: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
| """ |
| ODE right-hand side: dU/dt = f(U). |
| |
| Args: |
| t: Time (unused, autonomous system) |
            u: State vector (manifold_dim,) or batched (batch, manifold_dim)
| |
| Returns: |
| du_dt: Time derivative, same shape as u |
| |
| Note: |
| Includes regularization to prevent numerical instability: |
| - Smooth compression via tanh to keep Chebyshev basis stable |
| - Clamps output derivative to [-100, 100] for stable integration |
| See: https://github.com/rtqichen/torchdiffeq/issues/27 |
| """ |
| |
        # Sanitize and squash the state into (-1, 1) to keep the Chebyshev basis stable.
        u = torch.nan_to_num(u, nan=0.0, posinf=1.0, neginf=-1.0)
        u_stable = torch.tanh(u * 0.7)
|
|
| if not torch.isfinite(u_stable).all(): |
| logger.error( |
| "Non-finite u_stable in ode_func " |
| f"(u_min={u_stable.min().item():.6f}, u_max={u_stable.max().item():.6f})" |
| ) |
|
|
| |
        # Diffusion term: -L u, or a uniform leak when no Laplacian has been set.
        if self.laplacian is None:
            diffusion = -self._laplacian_scale * u_stable
        elif u_stable.dim() == 1:
            diffusion = -self.laplacian @ u_stable
        else:
            diffusion = -u_stable @ self.laplacian.T  # batched states are row vectors
|
|
| |
        # Reaction term: the ChebyKAN layer expects a leading batch dimension.
        if u_stable.dim() == 1:
| reaction = self.kan_reaction(u_stable.unsqueeze(0)).squeeze(0) |
| else: |
| reaction = self.kan_reaction(u_stable) |
|
|
| reaction = torch.nan_to_num(reaction, nan=0.0, posinf=1.0, neginf=-1.0) |
| if not torch.isfinite(reaction).all(): |
| logger.error( |
| "Non-finite reaction in ode_func " |
| f"(reaction_min={reaction.min().item():.6f}, reaction_max={reaction.max().item():.6f})" |
| ) |
| diffusion = torch.nan_to_num(diffusion, nan=0.0, posinf=1.0, neginf=-1.0) |
| if not torch.isfinite(diffusion).all(): |
| logger.error( |
| "Non-finite diffusion in ode_func " |
| f"(diffusion_min={diffusion.min().item():.6f}, diffusion_max={diffusion.max().item():.6f})" |
| ) |
|
|
| |
| du_dt = diffusion + reaction |
| du_dt = torch.nan_to_num(du_dt, nan=0.0, posinf=1.0, neginf=-1.0) |
| if not torch.isfinite(du_dt).all(): |
| logger.error( |
| "Non-finite du_dt in ode_func " |
| f"(du_dt_min={du_dt.min().item():.6f}, du_dt_max={du_dt.max().item():.6f})" |
| ) |
| return torch.clamp(du_dt, -100.0, 100.0) |
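
    # Relation to the fixed-step solvers used in forward(): a single explicit
    # Euler step of these dynamics is
    #
    #     u_next = u + step_size * self.ode_func(t, u)
    #
    # and rk4 combines four such evaluations per step. Because each evaluation
    # is clamped to [-100, 100], one solver step can move the state by at most
    # 100 * step_size per dimension.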
|
|
| def forward( |
| self, |
| u0: torch.Tensor, |
| t_span: tuple = (0.0, 0.1), |
| return_trajectory: bool = False, |
| ) -> torch.Tensor: |
| """ |
| Propagate initial state to final time. |
| |
        Args:
            u0: Initial state (manifold_dim,) or batched (batch, manifold_dim)
            t_span: (t_start, t_end) integration interval
            return_trajectory: If True, return the states at all evaluated times

        Returns:
            u_final: Final state, same shape as u0,
            OR trajectory (num_eval_times, *u0.shape) if return_trajectory is True
| """ |
| |
| |
| |
|
|
| |
        # Endpoints only; a float64 time grid reduces dt-underflow risk in adaptive solvers.
        t_eval = torch.tensor(
            [t_span[0], t_span[1]], device=u0.device, dtype=torch.float64
        )
|
|
| |
| is_adaptive = self.solver in ("dopri5", "bosh3", "adaptive_heun", "dopri8") |
|
|
        if is_adaptive:
            # Adaptive solvers use rtol/atol; float64 internals avoid step-size underflow.
            options = {"dtype": torch.float64}
| trajectory = odeint( |
| self.ode_func, |
| u0, |
| t_eval, |
| method=self.solver, |
| rtol=self.rtol, |
| atol=self.atol, |
| options=options, |
| ) |
        else:
            # Fixed-step solvers (rk4, euler, implicit_adams, ...) only use step_size.
            options = {"step_size": self.step_size}
| trajectory = odeint( |
| self.ode_func, u0, t_eval, method=self.solver, options=options |
| ) |
|
|
| u_final = trajectory[-1] |
|
|
| |
        # Optional Module E hook: report (initial, final) states for STDP updates.
        if self._stdp_callback is not None:
            self._stdp_callback(u0, u_final)
|
|
| if return_trajectory: |
| return trajectory |
| else: |
| return u_final |
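

if __name__ == "__main__":
    # Minimal smoke test (a sketch): requires the cuChebyKan CUDA op and a GPU;
    # shapes and the (0.0, 0.1) horizon follow the defaults documented above.
    if torch.cuda.is_available():
        propagator = SolitonPropagator.for_latent_space(latent_dim=256).cuda()
        z0 = torch.randn(4, 256, device="cuda")
        z1 = propagator(z0, t_span=(0.0, 0.1))
        print(f"propagated {tuple(z0.shape)} -> {tuple(z1.shape)}")
    else:
        logger.warning("CUDA not available; skipping SolitonPropagator smoke test.")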
|
|