##############################################################################################################################################
#||||- - - |8.19.2025| - - - || LIQUID STATE SPACE || - - - |1990two| - - -||||#
##############################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import List, Dict, Tuple, Optional
from scipy import linalg
from scipy.signal import cont2discrete

SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8

#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace NaN/Inf entries and clamp to a safe numeric range."""
    # nan_to_num maps -inf to min_val (the original where() mapped both infinities
    # to max_val, which flipped the sign of -inf entries).
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)


def discrete_to_continuous_time(A_discrete, dt=1.0):
    """Recover a continuous-time transition matrix from a discrete one via the matrix logarithm."""
    try:
        A_continuous = linalg.logm(A_discrete.detach().cpu().numpy()) / dt
        # logm can return a complex result (e.g. for negative real eigenvalues);
        # keep the real part so the float32 conversion below succeeds.
        A_continuous = np.real(A_continuous)
        return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
    except Exception:
        # Fallback: a small diagonal system.
        return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01


def continuous_to_discrete_time(A_continuous, B_continuous, dt=1.0):
    """Zero-order-hold discretization via scipy, with a forward-Euler fallback.

    Supports a batched A of shape (batch, N, N) with a shared B of shape (N, M).
    """
    try:
        A_np = A_continuous.detach().cpu().numpy()
        B_np = B_continuous.detach().cpu().numpy()
        if A_np.ndim == 3:  # batched A, shared B
            A_list, B_list = [], []
            for i in range(A_np.shape[0]):
                Ad, Bd, _, _, _ = cont2discrete((A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt)
                A_list.append(Ad)
                B_list.append(Bd)
            A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
        else:
            A_discrete, B_discrete, _, _, _ = cont2discrete((A_np, B_np, np.eye(A_np.shape[0]), 0), dt)
            A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
            B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)
        return A_discrete, B_discrete
    except Exception:
        # Forward-Euler fallback: A_d = I + A*dt, B_d = B*dt.
        n = A_continuous.shape[-1]
        eye = torch.eye(n, device=A_continuous.device, dtype=A_continuous.dtype)
        if A_continuous.dim() == 3:
            eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device) \
                .unsqueeze(0).expand(A_continuous.size(0), -1, -1)
        else:
            B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device)
        A_discrete = eye + A_continuous * dt
        B_discrete = B_disc * dt
        return A_discrete, B_discrete
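
# Illustrative sketch (not part of the original module): exercising the
# discretization helpers on a small stable system and round-tripping through the
# matrix logarithm. The shapes and values here are assumptions for demonstration.
def _demo_discretization():
    A = torch.tensor([[-1.0, 0.0], [0.0, -2.0]])  # stable continuous-time dynamics
    B = torch.tensor([[1.0], [0.5]])
    Ad, Bd = continuous_to_discrete_time(A, B, dt=0.1)
    # Round trip: logm(Ad)/dt should approximately recover A.
    A_back = discrete_to_continuous_time(Ad, dt=0.1)
    print(Ad.shape, Bd.shape, torch.allclose(A, A_back, atol=1e-3))
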
###########################################################################################################################################
#############################################- - - LIQUID TIME CONSTANT CONTROLLER - - -###################################################


class LiquidTimeConstantController(nn.Module):
    """Produces per-unit time constants that adapt to the current state and input."""

    def __init__(self, state_dim, input_dim, init_tau=1.0):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))
        self.tau_adaptation = nn.Sequential(
            nn.Linear(state_dim + input_dim, state_dim * 2),
            nn.LayerNorm(state_dim * 2),
            nn.Tanh(),
            nn.Linear(state_dim * 2, state_dim),
            nn.Tanh()  # output in [-1, 1] for modulation
        )
        self.adaptation_rate = nn.Parameter(torch.tensor(0.1))

    def get_time_constants(self, state, input_signal):
        base_tau = torch.exp(self.log_tau)
        base_tau = torch.clamp(base_tau, 0.01, 10.0)
        combined_input = torch.cat([state, input_signal], dim=-1)
        tau_modulation = self.tau_adaptation(combined_input)
        adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
        modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)
        return torch.clamp(modulated_tau, 0.01, 10.0)

    def get_effective_dt(self, tau, target_dt=0.1):
        # Keep the integration step well below the fastest time constant.
        min_tau_val = torch.min(tau).item()
        effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
        return effective_dt
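
# Illustrative sketch (not part of the original module): how the controller's
# effective step shrinks when fast time constants appear. Dimensions are arbitrary.
def _demo_time_constants():
    controller = LiquidTimeConstantController(state_dim=8, input_dim=4, init_tau=1.0)
    state = torch.zeros(2, 8)
    inp = torch.randn(2, 4)
    tau = controller.get_time_constants(state, inp)       # (2, 8), clamped to [0.01, 10.0]
    dt = controller.get_effective_dt(tau, target_dt=0.1)  # <= target_dt and <= 0.1 * min(tau)
    print(tau.shape, dt)
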
###########################################################################################################################################
################################################- - - LIQUID SSM CORE - - -################################################################


class LiquidSSMCore(nn.Module):
    """Continuous-time state-space cell whose dynamics depend on input-conditioned time constants."""

    def __init__(self, state_dim, input_dim, output_dim, dt=0.1, init_method='hippo'):
        super().__init__()
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dt = dt
        if init_method == 'hippo':
            self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
        else:
            self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
        self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
        self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
        self.D = nn.Parameter(torch.zeros(output_dim, input_dim))
        self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)
        self.output_scale = nn.Parameter(torch.ones(output_dim))
        self.output_bias = nn.Parameter(torch.zeros(output_dim))
        self.state_normalizer = nn.LayerNorm(state_dim)
        self.register_buffer('continuous_state', torch.zeros(1, state_dim))

    def _init_hippo_matrix(self, N):
        # HiPPO-inspired initialization (sign convention differs from HiPPO-LegS;
        # stability is enforced by the spectral rescaling below).
        A = torch.zeros(N, N)
        for i in range(N):
            for j in range(N):
                if i > j:
                    A[i, j] = math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
                elif i == j:
                    A[i, j] = -(2 * i + 1)
        A = A * 0.1
        with torch.no_grad():
            # Spectral radius uses the complex magnitude of the eigenvalues,
            # not just the real part.
            eig = torch.linalg.eigvals(A).abs().max()
            if eig > 0:
                A = A / eig * 0.9
        return A

    def reset_state(self, batch_size=1):
        device = self.A_continuous.device
        self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)

    def liquid_state_evolution(self, input_signal, num_steps=10):
        batch_size = input_signal.shape[0]
        if self.continuous_state.shape[0] != batch_size:
            self.reset_state(batch_size)
        tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
        effective_dt = self.time_controller.get_effective_dt(tau, self.dt)
        # Leak term: subtracting diag(1/tau) makes each unit decay at its own rate.
        tau_matrix = torch.diag_embed(1.0 / tau)
        liquid_A = self.A_continuous - tau_matrix
        liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)
        # Discretize once at sub-step resolution, then iterate num_steps times.
        step_dt = float(effective_dt) / num_steps
        A_discrete, B_discrete = continuous_to_discrete_time(
            liquid_A, self.B_continuous, step_dt
        )
        current_state = self.continuous_state
        if A_discrete.dim() == 3:  # batched dynamics (tau differs per sample)
            A_T = A_discrete.transpose(1, 2)
            B_T = B_discrete.transpose(1, 2)
            input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
            for _ in range(num_steps):
                state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
                current_state = state_update + input_update
                current_state = make_safe(current_state)
        else:
            A_T = A_discrete.T
            B_T = B_discrete.T
            input_update = input_signal @ B_T
            for _ in range(num_steps):
                current_state = current_state @ A_T + input_update
                current_state = make_safe(current_state)
        self.continuous_state = current_state
        return current_state, tau, effective_dt

    def compute_output(self, state, input_signal):
        normalized_state = self.state_normalizer(state)
        state_output = torch.matmul(normalized_state, self.C.T)
        direct_output = torch.matmul(input_signal, self.D.T)
        raw_output = state_output + direct_output
        output = self.output_scale * raw_output + self.output_bias
        return make_safe(output)

    def forward(self, input_signal, return_diagnostics=False):
        evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
        output = self.compute_output(evolved_state, input_signal)
        result = {
            'output': output,
            'state': evolved_state
        }
        if return_diagnostics:
            result.update({
                'time_constants': tau,
                'effective_dt': effective_dt,
                'state_norm': torch.norm(evolved_state, dim=-1),
                'adaptation_rate': self.time_controller.adaptation_rate
            })
        return result
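
# Illustrative sketch (not part of the original module): a single liquid SSM step
# on a batch. Per-sample time constants make the discretized A batched (dim() == 3).
def _demo_liquid_core():
    core = LiquidSSMCore(state_dim=16, input_dim=8, output_dim=8, dt=0.1)
    core.reset_state(batch_size=4)
    x = torch.randn(4, 8)
    result = core(x, return_diagnostics=True)
    print(result['output'].shape,   # (4, 8)
          result['state'].shape,    # (4, 16)
          result['effective_dt'])
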
###########################################################################################################################################
############################################- - - LIQUID SSM SEQUENCE LAYER - - -##########################################################


class LiquidSSMSequenceLayer(nn.Module):
    """Runs a LiquidSSMCore step by step over a sequence, with a gated output and optional residual."""

    def __init__(self, input_dim, state_dim, output_dim, seq_len=None):
        super().__init__()
        self.input_dim = input_dim
        self.state_dim = state_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        # Inputs are first projected to state_dim, so the core's input dim is state_dim.
        self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, state_dim),
            nn.LayerNorm(state_dim),
            nn.GELU()
        )
        self.output_projection = nn.Sequential(
            nn.Linear(output_dim, output_dim * 2),
            nn.LayerNorm(output_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim * 2, output_dim)
        )
        self.residual_weight = nn.Parameter(torch.tensor(0.1))
        self.sequence_adapter = nn.Sequential(
            nn.Linear(state_dim, state_dim),
            nn.Tanh(),
            nn.Linear(state_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, sequence, return_diagnostics=False):
        batch_size, seq_len, input_dim = sequence.shape
        self.liquid_ssm.reset_state(batch_size)
        outputs = []
        diagnostics = [] if return_diagnostics else None
        for t in range(seq_len):
            current_input = sequence[:, t, :]
            projected_input = self.input_projection(current_input)
            ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)
            # State-dependent gate in (0, 1) scales the SSM output.
            adaptation_factor = self.sequence_adapter(ssm_result['state'])
            adapted_output = ssm_result['output'] * adaptation_factor
            final_output = self.output_projection(adapted_output)
            # Residual connection only when input and output dims match.
            if final_output.shape == current_input.shape:
                residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
                final_output = final_output + residual_strength * current_input
            outputs.append(final_output)
            if return_diagnostics:
                diagnostics.append({
                    'timestep': t,
                    'adaptation_factor': adaptation_factor.mean().item(),
                    **ssm_result
                })
        output_sequence = torch.stack(outputs, dim=1)
        result = {'output': output_sequence}
        if return_diagnostics:
            result['diagnostics'] = diagnostics
        return result
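
# Illustrative sketch (not part of the original module): running the layer over a
# short sequence. With input_dim == output_dim the learned residual path is active.
def _demo_sequence_layer():
    layer = LiquidSSMSequenceLayer(input_dim=32, state_dim=16, output_dim=32)
    seq = torch.randn(2, 10, 32)  # (batch, time, features)
    out = layer(seq)['output']
    print(out.shape)  # (2, 10, 32)
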
###########################################################################################################################################
###########################################- - - LIQUID SSM LANGUAGE MODEL - - -###########################################################


class LiquidSSMLanguageModel(nn.Module):
    """Decoder-style language model built from stacked liquid SSM layers."""

    def __init__(self, vocab_size, d_model=512, state_dim=256, num_layers=6, max_seq_len=2048):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.liquid_layers = nn.ModuleList([
            LiquidSSMSequenceLayer(d_model, state_dim, d_model)
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])
        self.output_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.global_adaptation = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, return_diagnostics=False):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, :self.max_seq_len]
            seq_len = self.max_seq_len
            if labels is not None:
                labels = labels[:, :self.max_seq_len]
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
        token_emb = self.token_embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.position_embedding(pos_ids)
        x = token_emb + pos_emb
        x = make_safe(x)
        layer_diagnostics = [] if return_diagnostics else None
        for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
            # Pre-norm residual block with a sequence-level adaptation gate.
            residual = x
            x = layer_norm(x)
            layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
            x = layer_result['output']
            adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
            x = x * adaptation
            x = residual + x
            x = make_safe(x)
            if return_diagnostics:
                layer_diagnostics.append({
                    'layer': layer_idx,
                    'adaptation': adaptation.mean().item(),
                    **layer_result
                })
        x = self.output_norm(x)
        logits = self.lm_head(x)
        logits = make_safe(logits, min_val=-50, max_val=50)
        loss = None
        if labels is not None:
            # Standard next-token objective: shift logits left, labels right.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        result = {
            'logits': logits,
            'loss': loss
        }
        if return_diagnostics:
            result['layer_diagnostics'] = layer_diagnostics
        return result

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, temperature=1.0, top_p=0.95, return_diagnostics=False):
        self.eval()
        generated = input_ids.clone()
        all_diagnostics = [] if return_diagnostics else None
        for _ in range(max_length - input_ids.shape[1]):
            if generated.shape[1] > self.max_seq_len:
                break
            outputs = self(generated, return_diagnostics=return_diagnostics)
            logits = outputs['logits']
            if return_diagnostics:
                all_diagnostics.append(outputs.get('layer_diagnostics', []))
            next_token_logits = logits[:, -1, :] / max(temperature, EPS)
            next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)
            # Nucleus (top-p) filtering: drop tokens outside the smallest set
            # whose cumulative probability exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            for b in range(next_token_logits.size(0)):
                indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
                next_token_logits[b, indices_to_remove] = -float('inf')
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = torch.clamp(next_token, 0, self.vocab_size - 1)
            generated = torch.cat([generated, next_token], dim=1)
            # Early stop on EOS (token id 2); .item() is only valid for batch size 1.
            if generated.size(0) == 1 and next_token.item() == 2:
                break
        result = {'generated_ids': generated}
        if return_diagnostics:
            result['diagnostics'] = all_diagnostics
        return result
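
# Illustrative end-to-end sketch (not part of the original module). The vocab size,
# model dimensions, and the EOS id assumed by generate() are demonstration choices.
if __name__ == "__main__":
    model = LiquidSSMLanguageModel(vocab_size=1000, d_model=64, state_dim=32,
                                   num_layers=2, max_seq_len=128)
    input_ids = torch.randint(0, 1000, (1, 16))
    out = model(input_ids, labels=input_ids)  # next-token loss on the same ids
    print('loss:', out['loss'].item())
    gen = model.generate(input_ids, max_length=32, temperature=0.8, top_p=0.9)
    print('generated:', gen['generated_ids'].shape)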