# liquid_state_space / liquid_state_space_docs.py
# 1990two's picture
# Update liquid_state_space_docs.py
# 3042e98 verified
##################################################################################################################################################
#||||- - - |8.19.2025| - - - || LIQUID STATE SPACE || - - - |1990two| - - -|||| #
##################################################################################################################################################
"""
Mathematical Foundation & Conceptual Documentation
-------------------------------------------------
CORE PRINCIPLE:
Combines state space models with liquid computing principles to create adaptive
continuous-time dynamics for sequence processing. The system learns time constants
dynamically based on input characteristics, enabling efficient processing of
variable-speed temporal patterns.
MATHEMATICAL FOUNDATION:
=======================
1. STATE SPACE MODEL FUNDAMENTALS:
Continuous-time: dx/dt = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
Discrete-time: x[k+1] = A_d·x[k] + B_d·u[k]
y[k] = C·x[k] + D·u[k]
Where:
- x(t): state vector (hidden representation)
- u(t): input vector (external signals)
- y(t): output vector (observations)
- A: state transition matrix (dynamics)
- B: input matrix (how inputs affect states)
- C: output matrix (how states generate outputs)
- D: feedthrough matrix (direct input-output)
2. LIQUID DYNAMICS WITH ADAPTIVE TIME CONSTANTS:
dx/dt = -x/τ(x,u) + A·x + B·u
Where τ(x,u) are adaptive time constants:
τ(x,u) = τ_base · (1 + α·φ(x,u))
- τ_base: learnable base time constants
- α: adaptation rate parameter
- φ(x,u): neural adaptation function
Fast time constants → quick adaptation to rapid changes
Slow time constants → smooth integration of stable patterns
3. CONTINUOUS-TO-DISCRETE CONVERSION:
Using matrix exponential and zero-order hold:
A_d = exp(A·Δt)
B_d = A^(-1)·(A_d - I)·B
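For numerical stability, both matrices are obtained from a single matrix
exponential of the augmented block matrix:
exp([[A, B], [0, 0]] · Δt) = [[A_d, B_d], [0, I]]
i.e. the top block row of the exponential yields A_d and B_d; unlike the
formula above, this form does not require A to be invertible.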
4. HIPPO MATRIX INITIALIZATION:
HiPPO (High-order Polynomial Projection Operators) for optimal memory:
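A_ij = √(2i+1)·√(2j+1)   if i > j
A_ij = -(2i+1)           if i = j
A_ij = 0                 if i < j
(the implementation additionally scales the whole matrix by 0.1).
This matrix is lower triangular (not skew-symmetric), so its eigenvalues are its
negative diagonal entries and the initial dynamics are stable. It is inspired by
HiPPO-LegS, which preserves information over long sequences by projecting the
input history onto Legendre polynomials; note that the canonical HiPPO-LegS
matrix uses negative off-diagonal entries and a diagonal of -(n+1).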
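5. NUMERICAL INTEGRATION:
At each input step the discretized update is applied num_steps times with the
input held constant (zero-order hold); the first-order Euler approximation
(A_d ≈ I + A·Δt, B_d ≈ B·Δt) is used only as a fallback if discretization fails:
x(t+Δt) = A_d·x(t) + B_d·u(t)
With adaptive time stepping (the fastest time constant bounds the step):
Δt_eff = max(0.001, min(Δt_target, 0.1·min(τ)))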
CONCEPTUAL REASONING:
====================
WHY LIQUID + STATE SPACE MODELS?
- Traditional SSMs have fixed dynamics
- Real-world sequences have variable temporal scales
- Liquid dynamics enable adaptive processing speeds
- Continuous-time formulation handles irregular sampling
KEY INNOVATIONS:
1. **Adaptive Time Constants**: Learn processing speed from data
2. **HiPPO Initialization**: Optimal memory retention properties
3. **Continuous-Discrete Bridge**: Seamless time-domain conversion
4. **Multi-Scale Processing**: Handle fast and slow temporal patterns
5. **Efficient Implementation**: Linear complexity in sequence length
APPLICATIONS:
- Long-range sequence modeling (DNA, audio, text)
- Time-series with irregular sampling rates
- Speech recognition with variable speaking speeds
- Language modeling with adaptive processing
- Control systems with time-varying dynamics
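COMPLEXITY ANALYSIS:
- Time: O(N·d³) per sequence, where N=sequence length, d=state dimension:
  every timestep re-discretizes the d×d dynamics with a matrix exponential
  (O(d³)) and then applies num_steps O(d²) state updates
- Space: O(d²) for state matrices + O(N·d) for sequence states
- Training: O(N·d³·L) where L=number of layers
- Inference: Linear in sequence length for a single forward pass (vs quadratic
  for attention); note that generate() re-runs the full forward pass for every
  new token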
ADVANTAGES OVER TRANSFORMERS:
- Linear complexity vs quadratic attention
- Continuous-time formulation handles variable rates
- Built-in inductive bias for temporal dynamics
- Natural handling of infinite-length sequences
- Memory-efficient processing of long sequences
BIOLOGICAL INSPIRATION:
- Neural membrane time constants in biological circuits
- Adaptive integration windows in cortical processing
- Multiple timescale dynamics in neural networks
- Continuous-time neural differential equations
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import List, Dict, Tuple, Optional, Union, Any
from scipy import linalg
from scipy.signal import cont2discrete
# Numerical stability constants
SAFE_MIN: float = -1e6
SAFE_MAX: float = 1e6
EPS: float = 1e-8
#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
def make_safe(
    tensor: torch.Tensor,
    min_val: float = SAFE_MIN,
    max_val: float = SAFE_MAX
) -> torch.Tensor:
    """Replace non-finite values and clamp a tensor into a safe numeric range.

    NaN becomes 0, +inf becomes ``max_val`` and -inf becomes ``min_val``.
    All remaining values are then clamped to ``[min_val, max_val]``.

    Args:
        tensor: Input tensor to make numerically safe.
        min_val: Lower clamp bound; also the replacement value for -inf.
        max_val: Upper clamp bound; also the replacement value for +inf.

    Returns:
        A new tensor of the same shape whose values are finite and lie in
        ``[min_val, max_val]`` (floating-point dtypes are preserved).
    """
    # nan_to_num keeps the sign of infinities. The previous torch.where-based
    # version replaced -inf with +max_val, flipping its sign.
    cleaned = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(cleaned, min_val, max_val)
def discrete_to_continuous_time(A_discrete: torch.Tensor, dt: float = 1.0) -> torch.Tensor:
    """Convert a discrete-time transition matrix to continuous time via the matrix logarithm.

    Mathematical Details:
        If A_d = exp(A_c · dt), then A_c = log(A_d) / dt

    Args:
        A_discrete: Square discrete-time state transition matrix [state, state].
        dt: Time step that was used in the discretization.

    Returns:
        Continuous-time state matrix as a float32 tensor on ``A_discrete``'s
        device. If the logarithm cannot be computed (e.g. non-square input),
        the scaled identity ``0.01 · I`` of size ``A_discrete.shape[0]`` is
        returned instead.
    """
    device = A_discrete.device
    try:
        # scipy works on host memory, so detach and move to CPU first.
        A_continuous = linalg.logm(A_discrete.detach().cpu().numpy()) / dt
        return torch.tensor(A_continuous, dtype=torch.float32, device=device)
    except Exception:
        # Deliberate best-effort fallback. Unlike the former bare ``except:``,
        # this no longer swallows KeyboardInterrupt/SystemExit. The dtype
        # matches the success path.
        return torch.eye(A_discrete.shape[0], dtype=torch.float32, device=device) * 0.01
def continuous_to_discrete_time(
A_continuous: torch.Tensor,
B_continuous: torch.Tensor,
dt: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert continuous-time system to discrete-time using zero-order hold.
Mathematical Details:
Uses matrix exponential method for exact discretization:
[A_d B_d] = exp([A B] · dt)
[0 I ] [0 0]
Handles batched matrices by processing each batch element individually
to avoid SciPy's limitation with multi-dimensional arrays.
Args:
A_continuous: Continuous-time state matrix [batch?, state, state]
B_continuous: Continuous-time input matrix [state, input]
dt: Time step for discretization
Returns:
Tuple of (A_discrete, B_discrete) matrices
"""
try:
A_np = A_continuous.detach().cpu().numpy()
B_np = B_continuous.detach().cpu().numpy()
if A_np.ndim == 3:
# Handle batched matrices
A_list, B_list = [], []
for i in range(A_np.shape[0]):
Ad, Bd, _, _, _ = cont2discrete(
(A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt
)
A_list.append(Ad)
B_list.append(Bd)
A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
else:
# Handle single matrix
A_discrete, B_discrete, _, _, _ = cont2discrete(
(A_np, B_np, np.eye(A_np.shape[0]), 0), dt
)
A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)
return A_discrete, B_discrete
except Exception:
# Fallback to first-order Euler approximation
n = A_continuous.shape[-1]
eye = torch.eye(n, device=A_continuous.device)
if A_continuous.dim() == 3:
eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
B_disc = B_continuous.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
else:
B_disc = B_continuous
A_discrete = eye + A_continuous * dt
B_discrete = B_disc * dt
return A_discrete, B_discrete
###########################################################################################################################################
#############################################- - - LIQUID TIME CONSTANT CONTROLLER - - -###############################################
class LiquidTimeConstantController(nn.Module):
    """Learn per-unit, context-dependent time constants for the liquid state.

    Every state unit owns a learnable base time constant, stored in log space
    so it is always positive. A small MLP reads the current state and input
    and scales each base constant by a factor of (1 + α·m), with m ∈ [-1, 1].
    Short time constants let a unit track rapid changes; long ones make it
    integrate slowly.

    Mathematical Framework:
        - Base time constants:  τ_base = clamp(exp(log_τ), 0.01, 10.0)
        - Modulation:           m = MLP([x, u]), whose final layer is tanh
        - Adaptive constants:   τ(x,u) = τ_base · (1 + α·m),
                                α = clamp(adaptation_rate, 0.001, 1.0)
        - Stability constraint: τ ∈ [0.01, 10.0]
    """

    def __init__(
        self,
        state_dim: int,
        input_dim: int,
        init_tau: float = 1.0
    ) -> None:
        """Build the controller.

        Args:
            state_dim: Number of state units (one time constant each).
            input_dim: Width of the input signal fed alongside the state.
            init_tau: Starting value for every base time constant (must be > 0).
        """
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim

        hidden = 2 * state_dim
        # Log-parameterization: exp(log_tau) > 0 for any parameter value.
        self.log_tau = nn.Parameter(torch.full((state_dim,), math.log(init_tau)))
        # Maps [state, input] to one bounded modulation value per state unit.
        self.tau_adaptation = nn.Sequential(
            nn.Linear(state_dim + input_dim, hidden),
            nn.LayerNorm(hidden),
            nn.Tanh(),
            nn.Linear(hidden, state_dim),
            nn.Tanh(),  # keeps the modulation in [-1, 1]
        )
        # Learnable strength α of the modulation (clamped when used).
        self.adaptation_rate = nn.Parameter(torch.tensor(0.1))

    def get_time_constants(
        self,
        state: torch.Tensor,
        input_signal: torch.Tensor
    ) -> torch.Tensor:
        """Return τ(x, u) for the given state and input.

        Args:
            state: Current liquid state [batch_size, state_dim].
            input_signal: Current input [batch_size, input_dim].

        Returns:
            Time constants [batch_size, state_dim], clamped to [0.01, 10.0].
        """
        alpha = self.adaptation_rate.clamp(0.001, 1.0)
        tau_base = self.log_tau.exp().clamp(0.01, 10.0)
        context = torch.cat((state, input_signal), dim=-1)
        modulation = self.tau_adaptation(context)
        tau = tau_base * (1.0 + alpha * modulation)
        return tau.clamp(0.01, 10.0)

    def get_effective_dt(self, tau: torch.Tensor, target_dt: float = 0.1) -> float:
        """Pick an integration step no larger than a tenth of the fastest τ.

        Computes Δt_eff = max(0.001, min(target_dt, 0.1 · min(τ))).

        Args:
            tau: Time constants tensor (any shape; its global minimum is used).
            target_dt: Desired step when stability does not force a smaller one.

        Returns:
            The effective step as a Python float, never below 0.001.
        """
        fastest_tau = tau.min().item()
        candidate = min(float(target_dt), 0.1 * fastest_tau)
        return candidate if candidate > 0.001 else 0.001
###########################################################################################################################################
################################################- - - LIQUID SSM CORE - - -############################################################
class LiquidSSMCore(nn.Module):
    """Core Liquid State Space Model with adaptive continuous-time dynamics.

    A linear state space model whose per-unit decay rates are modulated by
    input-dependent time constants from a ``LiquidTimeConstantController``.
    Each call advances a persistent state by one input step.

    Mathematical Framework:
        - Liquid dynamics: dx/dt = -x/τ(x,u) + A·x + B·u
        - Output equation: y = scale ⊙ (C·LayerNorm(x) + D·u) + bias
        - HiPPO-inspired initialization of A (``init_method='hippo'``)
        - Zero-order-hold discretization with an adaptive step Δt_eff
    """

    def __init__(
        self,
        state_dim: int,
        input_dim: int,
        output_dim: int,
        dt: float = 0.1,
        init_method: str = 'hippo'
    ) -> None:
        """Initialize Liquid SSM core with adaptive dynamics.

        Args:
            state_dim: Dimension of hidden state vector.
            input_dim: Dimension of input vector.
            output_dim: Dimension of output vector.
            dt: Target time step; the step actually used may be smaller
                (see ``LiquidTimeConstantController.get_effective_dt``).
            init_method: 'hippo' for the structured initialization of A; any
                other value gives a small random Gaussian A.
        """
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dt = dt
        # Continuous-time state transition matrix A
        if init_method == 'hippo':
            self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
        else:
            self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
        # Input (B), output (C) and feedthrough (D) matrices
        self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
        self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
        self.D = nn.Parameter(torch.zeros(output_dim, input_dim))
        # Produces the adaptive time constants τ(x, u)
        self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)
        # Learnable elementwise output scaling and bias
        self.output_scale = nn.Parameter(torch.ones(output_dim))
        self.output_bias = nn.Parameter(torch.zeros(output_dim))
        # Normalizes the state before the output projection
        self.state_normalizer = nn.LayerNorm(state_dim)
        # Persistent state carried between calls; shape [batch, state_dim]
        self.register_buffer('continuous_state', torch.zeros(1, state_dim))

    def _init_hippo_matrix(self, N: int) -> torch.Tensor:
        """Build the HiPPO-inspired initial state matrix.

        Entries before the final 0.1 scaling:
            A_ij = √(2i+1)·√(2j+1)   if i > j
                 = -(2i+1)           if i = j
                 = 0                 if i < j

        The matrix is lower triangular, so its eigenvalues are its diagonal
        entries -0.1·(2i+1) < 0. NOTE(review): the canonical HiPPO-LegS matrix
        (HiPPO/S4 papers) uses negative off-diagonal entries and a diagonal of
        -(i+1). This variant keeps the module's original construction.

        Args:
            N: State dimension (number of basis functions).

        Returns:
            [N, N] matrix in the default floating dtype.
        """
        # Vectorized replacement for the former O(N²) Python double loop.
        # Products are formed in float64 (as math.sqrt did) and rounded once,
        # so the values match the loop version.
        idx = torch.arange(N, dtype=torch.float64)
        root = torch.sqrt(2.0 * idx + 1.0)
        A = torch.tril(torch.outer(root, root), diagonal=-1) - torch.diag(2.0 * idx + 1.0)
        return A.to(torch.get_default_dtype()) * 0.1  # Scale for training stability

    def reset_state(self, batch_size: int = 1) -> None:
        """Reset the persistent state to zeros for a new batch of sequences.

        The new state uses the same device and dtype as the model parameters.
        Previously it always used the default dtype, which broke models
        converted with ``.double()``/``.half()``.

        Args:
            batch_size: Number of parallel sequences to process.
        """
        reference = self.A_continuous
        self.continuous_state = torch.zeros(
            batch_size, self.state_dim, device=reference.device, dtype=reference.dtype
        )

    def liquid_state_evolution(
        self,
        input_signal: torch.Tensor,
        num_steps: int = 10
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Advance the persistent state by one input step of liquid dynamics.

        Implements dx/dt = -x/τ(x,u) + A·x + B·u:
            1. τ(x,u) from the time controller, using the current state.
            2. A_liquid = A - diag(1/τ), with entries clamped to [-10, 10]
               (so decay rates above 10 and large A entries are saturated).
            3. (A_d, B_d) = discretize(A_liquid, B, Δt_eff).
            4. Apply x ← A_d·x + B_d·u num_steps times with u held fixed,
               i.e. advance the state by num_steps·Δt_eff.

        The result, including its autograd history, is written back to
        ``continuous_state``. If the batch size differs from the stored
        state's, the state is first reset to zeros.

        Args:
            input_signal: External input [batch_size, input_dim].
            num_steps: Number of discrete updates applied per call.

        Returns:
            Tuple of (evolved_state [batch_size, state_dim],
            time_constants [batch_size, state_dim], effective_dt).
        """
        batch_size = input_signal.shape[0]
        # Ensure state tensor matches batch size
        if self.continuous_state.shape[0] != batch_size:
            self.reset_state(batch_size)
        # Compute adaptive time constants based on current state and input
        tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
        effective_dt = self.time_controller.get_effective_dt(tau, self.dt)
        # Liquid SSM: dx/dt = -x/τ + A·x + B·u = (A - diag(1/τ))·x + B·u
        # tau is per batch element, so liquid_A is [batch, state, state].
        tau_matrix = torch.diag_embed(1.0 / tau)  # Decay rates
        liquid_A = self.A_continuous - tau_matrix
        # Clamp entries (and replace NaN/Inf) for numerical stability
        liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)
        # Convert to discrete-time for numerical integration
        A_discrete, B_discrete = continuous_to_discrete_time(
            liquid_A, self.B_continuous, effective_dt
        )
        current_state = self.continuous_state
        if A_discrete.dim() == 3:
            # Batched: one discretized system per batch element (row-vector form x·A_dᵀ)
            A_T = A_discrete.transpose(1, 2)
            B_T = B_discrete.transpose(1, 2)
            # The input is held constant across the sub-steps, so B_d·u is computed once
            input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
            for _ in range(num_steps):
                state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
                current_state = state_update + input_update
                current_state = make_safe(current_state)
        else:
            # Single system shared by the whole batch
            A_T = A_discrete.T
            B_T = B_discrete.T
            input_update = input_signal @ B_T
            for _ in range(num_steps):
                current_state = current_state @ A_T + input_update
                current_state = make_safe(current_state)
        # Update persistent state (keeps autograd history across calls until reset)
        self.continuous_state = current_state
        return current_state, tau, effective_dt

    def compute_output(
        self,
        state: torch.Tensor,
        input_signal: torch.Tensor
    ) -> torch.Tensor:
        """Compute y = scale ⊙ (C·LayerNorm(x) + D·u) + bias.

        Args:
            state: Current state vector [batch_size, state_dim].
            input_signal: Current input [batch_size, input_dim].

        Returns:
            Output vector [batch_size, output_dim], passed through make_safe.
        """
        # Normalize state for training stability
        normalized_state = self.state_normalizer(state)
        state_output = torch.matmul(normalized_state, self.C.T)  # C·x
        direct_output = torch.matmul(input_signal, self.D.T)  # D·u
        raw_output = state_output + direct_output
        # Apply learnable output scaling and bias
        output = self.output_scale * raw_output + self.output_bias
        return make_safe(output)

    def forward(
        self,
        input_signal: torch.Tensor,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, float]]:
        """Run one input step: evolve the state, then compute the output.

        Args:
            input_signal: Input vector [batch_size, input_dim].
            return_diagnostics: If True, also return 'time_constants',
                'effective_dt', 'state_norm' and the raw (unclamped)
                'adaptation_rate' parameter.

        Returns:
            Dict with 'output' [batch_size, output_dim] and 'state'
            [batch_size, state_dim], plus diagnostics when requested.
        """
        evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
        output = self.compute_output(evolved_state, input_signal)
        result = {
            'output': output,
            'state': evolved_state
        }
        if return_diagnostics:
            result.update({
                'time_constants': tau,
                'effective_dt': effective_dt,
                'state_norm': torch.norm(evolved_state, dim=-1),
                'adaptation_rate': self.time_controller.adaptation_rate
            })
        return result
###########################################################################################################################################
############################################- - - LIQUID SSM SEQUENCE LAYER - - -######################################################
class LiquidSSMSequenceLayer(nn.Module):
    """Run a Liquid SSM core over a whole sequence, one timestep at a time.

    Per timestep:
        input → projection to state_dim → Liquid SSM core → scalar gate from the
        SSM state → output MLP → (residual add of the raw input, only when the
        output shape equals the input shape).

    The core's state is reset at the start of every forward call and carried
    across the timesteps of that call.
    """

    def __init__(
        self,
        input_dim: int,
        state_dim: int,
        output_dim: int,
        seq_len: Optional[int] = None
    ) -> None:
        """Create the projections, the SSM core and the gating network.

        Args:
            input_dim: Width of the per-timestep input features.
            state_dim: Width of the liquid state (and of the projected input).
            output_dim: Width of the per-timestep output features.
            seq_len: Stored on the module only; forward() takes the length from
                its input.
        """
        super().__init__()
        self.input_dim, self.state_dim, self.output_dim = input_dim, state_dim, output_dim
        self.seq_len = seq_len

        # Inputs are projected to state_dim before reaching the core, so the core's
        # input width is state_dim too (its time controller concatenates state and input).
        self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)

        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, state_dim),
            nn.LayerNorm(state_dim),
            nn.GELU(),
        )

        hidden = output_dim * 2
        self.output_projection = nn.Sequential(
            nn.Linear(output_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, output_dim),
        )

        # Residual strength; clamped to [0, 1] when applied.
        self.residual_weight = nn.Parameter(torch.tensor(0.1))

        # Maps the SSM state to a scalar gate in (0, 1).
        self.sequence_adapter = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        sequence: torch.Tensor,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Process a batch of sequences step by step.

        Args:
            sequence: Input tensor [batch_size, seq_len, input_dim].
            return_diagnostics: If True, also collect one dict per timestep.

        Returns:
            {'output': [batch_size, seq_len, output_dim]} plus, when requested,
            'diagnostics': a list of per-timestep dicts containing 'timestep',
            'adaptation_factor' (mean gate value) and the core's result entries.
        """
        batch_size, _, _ = sequence.shape
        self.liquid_ssm.reset_state(batch_size)

        step_outputs: List[torch.Tensor] = []
        step_diagnostics: Optional[List[Dict]] = [] if return_diagnostics else None

        for t, current_input in enumerate(sequence.unbind(dim=1)):
            ssm_result = self.liquid_ssm(
                self.input_projection(current_input),
                return_diagnostics=return_diagnostics,
            )
            gate = self.sequence_adapter(ssm_result['state'])
            final_output = self.output_projection(ssm_result['output'] * gate)

            # Residual only when output and raw input have identical shapes.
            if final_output.shape == current_input.shape:
                final_output = final_output + torch.clamp(self.residual_weight, 0.0, 1.0) * current_input

            step_outputs.append(final_output)

            if step_diagnostics is not None:
                step_diagnostics.append({
                    'timestep': t,
                    'adaptation_factor': gate.mean().item(),
                    **ssm_result,
                })

        result = {'output': torch.stack(step_outputs, dim=1)}
        if return_diagnostics:
            result['diagnostics'] = step_diagnostics
        return result
###########################################################################################################################################
##############################################- - - LIQUID SSM LANGUAGE MODEL - - -####################################################
class LiquidSSMLanguageModel(nn.Module):
    """Causal language model built from stacked Liquid SSM sequence layers.

    Architecture:
        Token + learned position embeddings
        → num_layers × [LayerNorm → LiquidSSMSequenceLayer → global gate → residual add]
        → LayerNorm → linear LM head

    The global gate (``global_adaptation``) is one network shared by all
    layers. It maps the sequence-mean of a layer's output to a scalar in
    (0, 1), which then multiplies that output.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        state_dim: int = 256,
        num_layers: int = 6,
        max_seq_len: int = 2048
    ) -> None:
        """Initialize Liquid SSM Language Model.

        Args:
            vocab_size: Size of vocabulary.
            d_model: Model dimension (embedding/hidden size).
            state_dim: Liquid state dimension.
            num_layers: Number of Liquid SSM layers.
            max_seq_len: Maximum sequence length; longer inputs are truncated
                and this also sizes the position-embedding table.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        # Stack of Liquid SSM layers
        self.liquid_layers = nn.ModuleList([
            LiquidSSMSequenceLayer(d_model, state_dim, d_model)
            for _ in range(num_layers)
        ])
        # Pre-layer normalization, one per layer
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])
        # Output head for language modeling
        self.output_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        # Global adaptation gate shared by all layers
        self.global_adaptation = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Xavier-uniform init for every Linear (zero bias); N(0, 0.02) for embeddings."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_diagnostics: bool = False
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Forward pass through Liquid SSM Language Model.

        Inputs longer than ``max_seq_len`` are truncated, and token ids are
        clamped into ``[0, vocab_size - 1]``.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            labels: Optional targets [batch_size, seq_len]. They are shifted
                internally (position t predicts labels[t+1]); -100 is ignored.
            return_diagnostics: Whether to return per-layer diagnostics.

        Returns:
            Dict with 'logits' [batch_size, seq_len, vocab_size] (clamped to
            [-50, 50]), 'loss' (None without labels) and, when requested,
            'layer_diagnostics'.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Clamp sequence length to maximum supported
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, :self.max_seq_len]
            seq_len = self.max_seq_len
            if labels is not None:
                labels = labels[:, :self.max_seq_len]
        # Ensure valid token IDs
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
        # Compute embeddings
        token_emb = self.token_embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.position_embedding(pos_ids)
        x = token_emb + pos_emb
        x = make_safe(x)
        layer_diagnostics = [] if return_diagnostics else None
        for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
            residual = x
            # Pre-layer normalization
            x = layer_norm(x)
            layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
            x = layer_result['output']
            # NOTE(review): the gate averages over ALL timesteps (dim=1), so every
            # position is scaled using information from future tokens — a
            # (low-bandwidth) causality leak for next-token training.
            adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
            x = x * adaptation
            x = residual + x
            x = make_safe(x)
            if return_diagnostics:
                layer_diagnostics.append({
                    'layer': layer_idx,
                    'adaptation': adaptation.mean().item(),
                    **layer_result
                })
        # Final normalization and output projection
        x = self.output_norm(x)
        logits = self.lm_head(x)
        logits = make_safe(logits, min_val=-50, max_val=50)
        # Next-token cross-entropy: position t is scored against labels[t+1]
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        result = {
            'logits': logits,
            'loss': loss
        }
        if return_diagnostics:
            result['layer_diagnostics'] = layer_diagnostics
        return result

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.95,
        return_diagnostics: bool = False,
        eos_token_id: Optional[int] = 2
    ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
        """Generate tokens autoregressively with temperature + nucleus sampling.

        Each step re-runs the full forward pass over the whole prefix, because
        the SSM state is reset per forward call. Cost therefore grows
        quadratically with the generated length. The module is put in eval
        mode while sampling, and its previous train/eval mode is restored
        afterwards.

        Args:
            input_ids: Prompt token IDs [batch_size, prompt_len].
            max_length: Maximum total sequence length (prompt included).
            temperature: Sampling temperature (higher = more random).
            top_p: Nucleus sampling probability threshold.
            return_diagnostics: Whether to return per-step layer diagnostics.
            eos_token_id: End-of-sequence id (default 2, the previously
                hard-coded value). Rows that emit it are padded with it
                afterwards, and generation stops once every row has emitted
                it. ``None`` disables early stopping.

        Returns:
            Dict with 'generated_ids' [batch_size, total_len] and, when
            requested, 'diagnostics' (one entry per generation step).
        """
        was_training = self.training
        self.eval()
        try:
            generated = input_ids.clone()
            all_diagnostics = [] if return_diagnostics else None
            # Rows that have already produced EOS (replaces the old .item() check,
            # which raised for batch_size > 1).
            finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
            for _ in range(max_length - input_ids.shape[1]):
                # Stop if sequence exceeds maximum length
                if generated.shape[1] > self.max_seq_len:
                    break
                outputs = self(generated, return_diagnostics=return_diagnostics)
                logits = outputs['logits']
                if return_diagnostics:
                    all_diagnostics.append(outputs.get('layer_diagnostics', []))
                next_token_logits = logits[:, -1, :] / max(temperature, EPS)
                next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)
                # Nucleus (top-p) filtering: keep the smallest prefix of sorted tokens
                # whose probability mass exceeds top_p (always keeping the top token).
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_to_remove = cumulative_probs > top_p
                sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
                sorted_to_remove[..., 0] = False
                # Map the mask back to vocabulary order in one vectorized scatter.
                to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
                next_token_logits = next_token_logits.masked_fill(to_remove, float('-inf'))
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                next_token = torch.clamp(next_token, 0, self.vocab_size - 1)
                if eos_token_id is not None:
                    # Already-finished rows keep emitting EOS as padding.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                generated = torch.cat([generated, next_token], dim=1)
                if eos_token_id is not None:
                    finished |= next_token.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break
        finally:
            self.train(was_training)
        result = {'generated_ids': generated}
        if return_diagnostics:
            result['diagnostics'] = all_diagnostics
        return result
###########################################################################################################################################
##############################################- - - LIQUID SSM DEMO + TESTING - - -####################################################
def test_liquid_ssm() -> bool:
    """Smoke-test the language model end to end and print a short report.

    The following stages are run:
        - a forward pass with labels and diagnostics,
        - a short generation from a random prompt, and
        - wall-clock timings of no-grad forward passes at a few lengths.

    Returns:
        True once every stage has run without raising.
    """
    import time  # local: only this demo needs timing

    print("Testing Liquid State Space Model - Continuous-Time Adaptive Sequence Processing")
    print("=" * 90)
    vocab_size = 1000
    d_model = 256
    state_dim = 128
    num_layers = 4
    model = LiquidSSMLanguageModel(
        vocab_size=vocab_size,
        d_model=d_model,
        state_dim=state_dim,
        num_layers=num_layers,
        max_seq_len=512
    )
    print("Created Liquid SSM Language Model:")
    print(f" - Vocabulary size: {vocab_size}")
    print(f" - Model dimension: {d_model}")
    print(f" - State dimension: {state_dim}")
    print(f" - Number of layers: {num_layers}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" - Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")

    # Forward pass on random data
    batch_size = 4
    seq_len = 32
    test_input = torch.randint(0, vocab_size, (batch_size, seq_len))
    test_labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    print(f"\nTesting with batch_size={batch_size}, seq_len={seq_len}")
    print("\nExecuting forward pass...")
    outputs = model(test_input, labels=test_labels, return_diagnostics=True)
    print("Forward pass results:")
    print(f" - Output logits shape: {outputs['logits'].shape}")
    print(f" - Loss: {outputs['loss']:.4f}")

    # Liquid dynamics summary for the first (up to) three layers
    print("\nLiquid dynamics analysis:")
    diagnostics = outputs['layer_diagnostics']
    for layer_idx in range(min(3, len(diagnostics))):
        layer_diag = diagnostics[layer_idx]
        print(f" Layer {layer_idx + 1}:")
        print(f" - Global adaptation: {layer_diag['adaptation']:.3f}")
        if 'diagnostics' in layer_diag:
            time_constants = [d['time_constants'].mean().item() for d in layer_diag['diagnostics'][:3]]
            print(f" - Avg time constants: {[f'{tc:.3f}' for tc in time_constants]}")

    # Generation
    print("\nTesting text generation...")
    prompt = torch.randint(0, vocab_size, (1, 8))
    generation_result = model.generate(
        prompt,
        max_length=20,
        temperature=1.0,
        return_diagnostics=True
    )
    generated_ids = generation_result['generated_ids']
    print(f" - Generated sequence length: {generated_ids.shape[1]}")
    print(f" - Prompt length: {prompt.shape[1]}")
    print(f" - New tokens generated: {generated_ids.shape[1] - prompt.shape[1]}")

    # Throughput at a few sequence lengths; perf_counter is the monotonic,
    # high-resolution clock intended for measuring durations.
    print("\nEfficiency analysis:")
    for test_len in (64, 128, 256):
        test_seq = torch.randint(0, vocab_size, (1, test_len))
        start_time = time.perf_counter()
        with torch.no_grad():
            model(test_seq)
        processing_time = time.perf_counter() - start_time
        # Guard against a zero-length interval on coarse clocks.
        tokens_per_second = test_len / processing_time if processing_time > 0 else float('inf')
        print(f" - Length {test_len}: {processing_time:.3f}s ({tokens_per_second:.0f} tokens/s)")

    print("\nLiquid SSM test completed!")
    print("✓ Continuous-time adaptive dynamics")
    print("✓ Learnable time constants based on content")
    print("✓ Efficient sequence processing")
    print("✓ State space model foundation with liquid adaptation")
    print("✓ Potential transformer alternative with continuous dynamics")
    return True
def adaptive_dynamics_demo() -> None:
    """Print the time constants a fresh LiquidSSMCore assigns to a few toy inputs.

    The model is untrained. Each 8-dimensional pattern is fed once, from a
    freshly reset state. For each pattern the first four time constants, the
    adaptation-rate parameter and the effective step are printed.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("ADAPTIVE DYNAMICS DEMONSTRATION")
    print(banner)

    core = LiquidSSMCore(state_dim=16, input_dim=8, output_dim=8)
    core.eval()

    # Toy input vectors (batch of one) with different "shapes".
    smooth = torch.sin(torch.linspace(0, 2 * math.pi, 8)).unsqueeze(0)
    spiky = torch.tensor([0, 1, 0, -1, 0, 1, 0, -1], dtype=torch.float).unsqueeze(0)
    constant = torch.ones(1, 8) * 0.5
    noise = torch.randn(1, 8)
    patterns = {"Smooth": smooth, "Spiky": spiky, "Constant": constant, "Random": noise}

    print("Testing adaptive time constants with different input patterns:")
    for name, signal in patterns.items():
        core.reset_state(1)
        with torch.no_grad():
            diag = core(signal, return_diagnostics=True)
        taus = diag['time_constants'].squeeze().tolist()
        rate = diag['adaptation_rate'].item()
        shown = [f'{tc:.3f}' for tc in taus[:4]]
        print(f"\n{name} pattern:")
        print(f" Time constants: {shown}...")
        print(f" Adaptation rate: {rate:.4f}")
        print(f" Effective dt: {diag['effective_dt']:.4f}")

    print("\n Adaptive dynamics show how liquid SSM adjusts to different input characteristics")
    print(" Smooth inputs → larger time constants, Spiky inputs → smaller time constants")
if __name__ == "__main__":
    # Script entry point: full language-model smoke test, then the
    # single-core adaptive-dynamics demo (return values are ignored).
    for _demo in (test_liquid_ssm, adaptive_dynamics_demo):
        _demo()