1990two commited on
Commit
139362f
·
verified ·
1 Parent(s): 297eb9d

Upload 2 files

Browse files
Files changed (2) hide show
  1. liquid_state_space.py +463 -0
  2. liquid_state_space_docs.py +1107 -0
liquid_state_space.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##############################################################################################################################################
2
+ #||||- - - |6.25.2025| - - - || LIQUID STATE SPACE || - - - |1990two| - - -|||| #
3
+ ##############################################################################################################################################
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import math
9
+ from typing import List, Dict, Tuple, Optional
10
+ from scipy import linalg
11
+ from scipy.signal import cont2discrete
12
+
13
+ SAFE_MIN = -1e6
14
+ SAFE_MAX = 1e6
15
+ EPS = 1e-8
16
+
17
+ #||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ð“…¸ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
18
+
19
+ def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
20
+ zero = torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype)
21
+ maxv = torch.tensor(max_val, device=tensor.device, dtype=tensor.dtype)
22
+ tensor = torch.where(torch.isnan(tensor), zero, tensor)
23
+ tensor = torch.where(torch.isinf(tensor), maxv, tensor)
24
+ return torch.clamp(tensor, min_val, max_val)
25
+
26
+ def discrete_to_continuous_time(A_discrete, dt=1.0):
27
+ try:
28
+ A_continuous = linalg.logm(A_discrete.detach().cpu().numpy()) / dt
29
+ return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
30
+ except:
31
+ return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01
32
+
33
+ def continuous_to_discrete_time(A_continuous, B_continuous, dt=1.0):
34
+ try:
35
+ A_np = A_continuous.detach().cpu().numpy()
36
+ B_np = B_continuous.detach().cpu().numpy()
37
+
38
+ if A_np.ndim == 3:
39
+ A_list, B_list = [], []
40
+ for i in range(A_np.shape[0]):
41
+ Ad, Bd, _, _, _ = cont2discrete((A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt)
42
+ A_list.append(Ad)
43
+ B_list.append(Bd)
44
+ A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
45
+ B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
46
+ else:
47
+ A_discrete, B_discrete, _, _, _ = cont2discrete((A_np, B_np, np.eye(A_np.shape[0]), 0), dt)
48
+ A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
49
+ B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)
50
+
51
+
52
+ return A_discrete, B_discrete
53
+ except Exception:
54
+ n = A_continuous.shape[-1]
55
+ eye = torch.eye(n, device=A_continuous.device, dtype=A_continuous.dtype)
56
+ if A_continuous.dim() == 3:
57
+ eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
58
+ B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device) \
59
+ .unsqueeze(0).expand(A_continuous.size(0), -1, -1)
60
+ else:
61
+ B_disc = B_continuous.to(dtype=A_continuous.dtype, device=A_continuous.device)
62
+ A_discrete = eye + A_continuous * dt
63
+ B_discrete = B_disc * dt
64
+ return A_discrete, B_discrete
65
+
66
+ ###########################################################################################################################################
67
+ #############################################- - - LIQUID TIME CONSTANT CONTROLLER - - -###############################################
68
+
69
+ class LiquidTimeConstantController(nn.Module):
70
+ def __init__(self, state_dim, input_dim, init_tau=1.0):
71
+ super().__init__()
72
+ self.state_dim = state_dim
73
+ self.input_dim = input_dim
74
+
75
+ self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))
76
+
77
+ self.tau_adaptation = nn.Sequential(
78
+ nn.Linear(state_dim + input_dim, state_dim * 2),
79
+ nn.LayerNorm(state_dim * 2),
80
+ nn.Tanh(),
81
+ nn.Linear(state_dim * 2, state_dim),
82
+ nn.Tanh() # Output in [-1, 1] for modulation
83
+ )
84
+
85
+ self.adaptation_rate = nn.Parameter(torch.tensor(0.1))
86
+
87
+ def get_time_constants(self, state, input_signal):
88
+ base_tau = torch.exp(self.log_tau)
89
+ base_tau = torch.clamp(base_tau, 0.01, 10.0)
90
+
91
+ combined_input = torch.cat([state, input_signal], dim=-1)
92
+ tau_modulation = self.tau_adaptation(combined_input)
93
+
94
+ adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
95
+ modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)
96
+
97
+ return torch.clamp(modulated_tau, 0.01, 10.0)
98
+
99
+ def get_effective_dt(self, tau, target_dt=0.1):
100
+ min_tau_val = torch.min(tau).item()
101
+ effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
102
+ return effective_dt
103
+
104
+ ###########################################################################################################################################
105
+ ################################################- - - LIQUID SSM CORE - - -############################################################
106
+
107
+ class LiquidSSMCore(nn.Module):
108
+ def __init__(self, state_dim, input_dim, output_dim, dt=0.1, init_method='hippo'):
109
+ super().__init__()
110
+ self.state_dim = state_dim
111
+ self.input_dim = input_dim
112
+ self.output_dim = output_dim
113
+ self.dt = dt
114
+
115
+ if init_method == 'hippo':
116
+ self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
117
+ else:
118
+ self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
119
+
120
+ self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
121
+ self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
122
+ self.D = nn.Parameter(torch.zeros(output_dim, input_dim))
123
+
124
+ self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)
125
+
126
+ self.output_scale = nn.Parameter(torch.ones(output_dim))
127
+ self.output_bias = nn.Parameter(torch.zeros(output_dim))
128
+
129
+ self.state_normalizer = nn.LayerNorm(state_dim)
130
+
131
+ self.register_buffer('continuous_state', torch.zeros(1, state_dim))
132
+
133
+ def _init_hippo_matrix(self, N):
134
+ A = torch.zeros(N, N)
135
+ for i in range(N):
136
+ for j in range(N):
137
+ if i > j:
138
+ A[i, j] = math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
139
+ elif i == j:
140
+ A[i, j] = -(2 * i + 1)
141
+ A = A * 0.1
142
+ with torch.no_grad():
143
+ eig = torch.linalg.eigvals(A).real.abs().max()
144
+ if eig > 0:
145
+ A = A / eig * 0.9
146
+ return A
147
+
148
+ def reset_state(self, batch_size=1):
149
+ device = self.A_continuous.device
150
+ self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)
151
+
152
+ def liquid_state_evolution(self, input_signal, num_steps=10):
153
+ batch_size = input_signal.shape[0]
154
+
155
+ if self.continuous_state.shape[0] != batch_size:
156
+ self.reset_state(batch_size)
157
+
158
+ tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
159
+ effective_dt = self.time_controller.get_effective_dt(tau, self.dt)
160
+
161
+ tau_matrix = torch.diag_embed(1.0 / tau)
162
+ liquid_A = self.A_continuous - tau_matrix
163
+
164
+ liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)
165
+
166
+ A_discrete, B_discrete = continuous_to_discrete_time(
167
+ liquid_A, self.B_continuous, effective_dt
168
+ )
169
+
170
+ step_dt = float(effective_dt) / num_steps
171
+ A_discrete, B_discrete = continuous_to_discrete_time(
172
+ liquid_A, self.B_continuous, step_dt
173
+ )
174
+ current_state = self.continuous_state
175
+
176
+ if A_discrete.dim() == 3:
177
+ A_T = A_discrete.transpose(1, 2)
178
+ B_T = B_discrete.transpose(1, 2)
179
+ input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
180
+ for _ in range(num_steps):
181
+ state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
182
+ current_state = state_update + input_update
183
+ current_state = make_safe(current_state)
184
+ else:
185
+ A_T = A_discrete.T
186
+ B_T = B_discrete.T
187
+ input_update = input_signal @ B_T
188
+ for _ in range(num_steps):
189
+ current_state = current_state @ A_T + input_update
190
+ current_state = make_safe(current_state)
191
+
192
+ current_state = make_safe(current_state)
193
+
194
+ self.continuous_state = current_state
195
+
196
+ return current_state, tau, effective_dt
197
+
198
+ def compute_output(self, state, input_signal):
199
+ normalized_state = self.state_normalizer(state)
200
+
201
+ state_output = torch.matmul(normalized_state, self.C.T)
202
+ direct_output = torch.matmul(input_signal, self.D.T)
203
+
204
+ raw_output = state_output + direct_output
205
+
206
+ output = self.output_scale * raw_output + self.output_bias
207
+
208
+ return make_safe(output)
209
+
210
+ def forward(self, input_signal, return_diagnostics=False):
211
+ evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
212
+
213
+ output = self.compute_output(evolved_state, input_signal)
214
+
215
+ result = {
216
+ 'output': output,
217
+ 'state': evolved_state
218
+ }
219
+
220
+ if return_diagnostics:
221
+ result.update({
222
+ 'time_constants': tau,
223
+ 'effective_dt': effective_dt,
224
+ 'state_norm': torch.norm(evolved_state, dim=-1),
225
+ 'adaptation_rate': self.time_controller.adaptation_rate
226
+ })
227
+
228
+ return result
229
+
230
+ ###########################################################################################################################################
231
+ ############################################- - - LIQUID SSM SEQUENCE LAYER - - -######################################################
232
+
233
+ class LiquidSSMSequenceLayer(nn.Module):
234
+ def __init__(self, input_dim, state_dim, output_dim, seq_len=None):
235
+ super().__init__()
236
+ self.input_dim = input_dim
237
+ self.state_dim = state_dim
238
+ self.output_dim = output_dim
239
+ self.seq_len = seq_len
240
+
241
+ self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)
242
+
243
+ self.input_projection = nn.Sequential(
244
+ nn.Linear(input_dim, state_dim),
245
+ nn.LayerNorm(state_dim),
246
+ nn.GELU()
247
+ )
248
+
249
+ self.output_projection = nn.Sequential(
250
+ nn.Linear(output_dim, output_dim * 2),
251
+ nn.LayerNorm(output_dim * 2),
252
+ nn.GELU(),
253
+ nn.Dropout(0.1),
254
+ nn.Linear(output_dim * 2, output_dim)
255
+ )
256
+
257
+ self.residual_weight = nn.Parameter(torch.tensor(0.1))
258
+
259
+ self.sequence_adapter = nn.Sequential(
260
+ nn.Linear(state_dim, state_dim),
261
+ nn.Tanh(),
262
+ nn.Linear(state_dim, 1),
263
+ nn.Sigmoid()
264
+ )
265
+
266
+ def forward(self, sequence, return_diagnostics=False):
267
+ batch_size, seq_len, input_dim = sequence.shape
268
+
269
+ self.liquid_ssm.reset_state(batch_size)
270
+
271
+ outputs = []
272
+ diagnostics = [] if return_diagnostics else None
273
+
274
+ for t in range(seq_len):
275
+ current_input = sequence[:, t, :]
276
+
277
+ projected_input = self.input_projection(current_input)
278
+
279
+ ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)
280
+
281
+ adaptation_factor = self.sequence_adapter(ssm_result['state'])
282
+ adapted_output = ssm_result['output'] * adaptation_factor
283
+
284
+ final_output = self.output_projection(adapted_output)
285
+
286
+ if final_output.shape == current_input.shape:
287
+ residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
288
+ final_output = final_output + residual_strength * current_input
289
+
290
+ outputs.append(final_output)
291
+
292
+ if return_diagnostics:
293
+ diagnostics.append({
294
+ 'timestep': t,
295
+ 'adaptation_factor': adaptation_factor.mean().item(),
296
+ **ssm_result
297
+ })
298
+
299
+ output_sequence = torch.stack(outputs, dim=1)
300
+
301
+ result = {'output': output_sequence}
302
+
303
+ if return_diagnostics:
304
+ result['diagnostics'] = diagnostics
305
+
306
+ return result
307
+
308
+ ###########################################################################################################################################
309
+ ###########################################- - - LIQUID SSM LANGUAGE MODEL - - -#######################################################
310
+
311
+ class LiquidSSMLanguageModel(nn.Module):
312
+ def __init__(self, vocab_size, d_model=512, state_dim=256, num_layers=6, max_seq_len=2048):
313
+ super().__init__()
314
+ self.vocab_size = vocab_size
315
+ self.d_model = d_model
316
+ self.state_dim = state_dim
317
+ self.num_layers = num_layers
318
+ self.max_seq_len = max_seq_len
319
+
320
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
321
+ self.position_embedding = nn.Embedding(max_seq_len, d_model)
322
+
323
+ self.liquid_layers = nn.ModuleList([
324
+ LiquidSSMSequenceLayer(d_model, state_dim, d_model)
325
+ for _ in range(num_layers)
326
+ ])
327
+
328
+ self.layer_norms = nn.ModuleList([
329
+ nn.LayerNorm(d_model) for _ in range(num_layers)
330
+ ])
331
+
332
+ self.output_norm = nn.LayerNorm(d_model)
333
+ self.lm_head = nn.Linear(d_model, vocab_size)
334
+
335
+ self.global_adaptation = nn.Sequential(
336
+ nn.Linear(d_model, d_model // 4),
337
+ nn.GELU(),
338
+ nn.Linear(d_model // 4, 1),
339
+ nn.Sigmoid()
340
+ )
341
+
342
+ self._init_weights()
343
+
344
+ def _init_weights(self):
345
+ for module in self.modules():
346
+ if isinstance(module, nn.Linear):
347
+ nn.init.xavier_uniform_(module.weight)
348
+ if module.bias is not None:
349
+ nn.init.zeros_(module.bias)
350
+ elif isinstance(module, nn.Embedding):
351
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
352
+
353
+ def forward(self, input_ids, labels=None, return_diagnostics=False):
354
+ batch_size, seq_len = input_ids.shape
355
+ device = input_ids.device
356
+
357
+ if seq_len > self.max_seq_len:
358
+ input_ids = input_ids[:, :self.max_seq_len]
359
+ seq_len = self.max_seq_len
360
+ if labels is not None:
361
+ labels = labels[:, :self.max_seq_len]
362
+
363
+ input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
364
+
365
+ token_emb = self.token_embedding(input_ids)
366
+ pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
367
+ pos_emb = self.position_embedding(pos_ids)
368
+
369
+ x = token_emb + pos_emb
370
+ x = make_safe(x)
371
+
372
+ layer_diagnostics = [] if return_diagnostics else None
373
+
374
+ for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
375
+ residual = x
376
+
377
+ x = layer_norm(x)
378
+
379
+ layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
380
+ x = layer_result['output']
381
+
382
+ adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
383
+ x = x * adaptation
384
+
385
+ x = residual + x
386
+ x = make_safe(x)
387
+
388
+ if return_diagnostics:
389
+ layer_diagnostics.append({
390
+ 'layer': layer_idx,
391
+ 'adaptation': adaptation.mean().item(),
392
+ **layer_result
393
+ })
394
+
395
+ x = self.output_norm(x)
396
+ logits = self.lm_head(x)
397
+ logits = make_safe(logits, min_val=-50, max_val=50)
398
+
399
+ loss = None
400
+ if labels is not None:
401
+ shift_logits = logits[..., :-1, :].contiguous()
402
+ shift_labels = labels[..., 1:].contiguous()
403
+ loss = F.cross_entropy(
404
+ shift_logits.view(-1, self.vocab_size),
405
+ shift_labels.view(-1),
406
+ ignore_index=-100
407
+ )
408
+
409
+ result = {
410
+ 'logits': logits,
411
+ 'loss': loss
412
+ }
413
+
414
+ if return_diagnostics:
415
+ result['layer_diagnostics'] = layer_diagnostics
416
+
417
+ return result
418
+
419
+ @torch.no_grad()
420
+ def generate(self, input_ids, max_length=100, temperature=1.0, top_p=0.95, return_diagnostics=False):
421
+ self.eval()
422
+ generated = input_ids.clone()
423
+ all_diagnostics = [] if return_diagnostics else None
424
+
425
+ for step in range(max_length - input_ids.shape[1]):
426
+ if generated.shape[1] > self.max_seq_len:
427
+ break
428
+
429
+ outputs = self(generated, return_diagnostics=return_diagnostics)
430
+ logits = outputs['logits']
431
+
432
+ if return_diagnostics:
433
+ all_diagnostics.append(outputs.get('layer_diagnostics', []))
434
+
435
+ next_token_logits = logits[:, -1, :] / max(temperature, EPS)
436
+ next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)
437
+
438
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
439
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
440
+
441
+ sorted_indices_to_remove = cumulative_probs > top_p
442
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
443
+ sorted_indices_to_remove[..., 0] = False
444
+
445
+ for b in range(next_token_logits.size(0)):
446
+ indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
447
+ next_token_logits[b, indices_to_remove] = -float('inf')
448
+
449
+ probs = F.softmax(next_token_logits, dim=-1)
450
+ next_token = torch.multinomial(probs, num_samples=1)
451
+ next_token = torch.clamp(next_token, 0, self.vocab_size - 1)
452
+
453
+ generated = torch.cat([generated, next_token], dim=1)
454
+
455
+ if next_token.item() == 2: # EOS token
456
+ break
457
+
458
+ result = {'generated_ids': generated}
459
+ if return_diagnostics:
460
+ result['diagnostics'] = all_diagnostics
461
+
462
+ return result
463
+
liquid_state_space_docs.py ADDED
@@ -0,0 +1,1107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##################################################################################################################################################
2
+ #||||- - - |6.25.2025| - - - || LIQUID STATE SPACE || - - - |1990two| - - -|||| #
3
+ ##################################################################################################################################################
4
+
5
+ """
6
+ Mathematical Foundation & Conceptual Documentation
7
+ -------------------------------------------------
8
+
9
+ CORE PRINCIPLE:
10
+ Combines state space models with liquid computing principles to create adaptive
11
+ continuous-time dynamics for sequence processing. The system learns time constants
12
+ dynamically based on input characteristics, enabling efficient processing of
13
+ variable-speed temporal patterns.
14
+
15
+ MATHEMATICAL FOUNDATION:
16
+ =======================
17
+
18
+ 1. STATE SPACE MODEL FUNDAMENTALS:
19
+ Continuous-time: dx/dt = Ax(t) + Bu(t)
20
+ y(t) = Cx(t) + Du(t)
21
+
22
+ Discrete-time: x[k+1] = A_d·x[k] + B_d·u[k]
23
+ y[k] = C·x[k] + D·u[k]
24
+
25
+ Where:
26
+ - x(t): state vector (hidden representation)
27
+ - u(t): input vector (external signals)
28
+ - y(t): output vector (observations)
29
+ - A: state transition matrix (dynamics)
30
+ - B: input matrix (how inputs affect states)
31
+ - C: output matrix (how states generate outputs)
32
+ - D: feedthrough matrix (direct input-output)
33
+
34
+ 2. LIQUID DYNAMICS WITH ADAPTIVE TIME CONSTANTS:
35
+ dx/dt = -x/τ(x,u) + A·x + B·u
36
+
37
+ Where Ï„(x,u) are adaptive time constants:
38
+ τ(x,u) = τ_base · (1 + α·φ(x,u))
39
+
40
+ - τ_base: learnable base time constants
41
+ - α: adaptation rate parameter
42
+ - φ(x,u): neural adaptation function
43
+
44
+ Fast time constants → quick adaptation to rapid changes
45
+ Slow time constants → smooth integration of stable patterns
46
+
47
+ 3. CONTINUOUS-TO-DISCRETE CONVERSION:
48
+ Using matrix exponential and zero-order hold:
49
+
50
+ A_d = exp(A·Δt)
51
+ B_d = A^(-1)·(A_d - I)·B
52
+
53
+ For numerical stability, we use:
54
+ [A_d B_d] = exp([A B] · Δt)
55
+ [0 I ] [0 0]
56
+
57
+ 4. HIPPO MATRIX INITIALIZATION:
58
+ HiPPO (High-order Polynomial Projection Operators) for optimal memory:
59
+
60
+ A_ij = {√(2i+1)·√(2j+1) if i > j
61
+ {-(2i+1) if i = j
62
+ {0 if i < j
63
+
64
+ This creates a skew-symmetric structure that preserves information
65
+ over long sequences by projecting onto Legendre polynomials.
66
+
67
+ 5. NUMERICAL INTEGRATION:
68
+ Multi-step Euler method for stability:
69
+ x(t+Δt) = x(t) + Δt·f(x(t),u(t))
70
+
71
+ With adaptive time stepping:
72
+ Δt_eff = min(Δt_target, 0.1·min(τ))
73
+
74
+
75
+ CONCEPTUAL REASONING:
76
+ ====================
77
+
78
+ WHY LIQUID + STATE SPACE MODELS?
79
+ - Traditional SSMs have fixed dynamics
80
+ - Real-world sequences have variable temporal scales
81
+ - Liquid dynamics enable adaptive processing speeds
82
+ - Continuous-time formulation handles irregular sampling
83
+
84
+ KEY INNOVATIONS:
85
+ 1. **Adaptive Time Constants**: Learn processing speed from data
86
+ 2. **HiPPO Initialization**: Optimal memory retention properties
87
+ 3. **Continuous-Discrete Bridge**: Seamless time-domain conversion
88
+ 4. **Multi-Scale Processing**: Handle fast and slow temporal patterns
89
+ 5. **Efficient Implementation**: Linear complexity in sequence length
90
+
91
+ APPLICATIONS:
92
+ - Long-range sequence modeling (DNA, audio, text)
93
+ - Time-series with irregular sampling rates
94
+ - Speech recognition with variable speaking speeds
95
+ - Language modeling with adaptive processing
96
+ - Control systems with time-varying dynamics
97
+
98
+ COMPLEXITY ANALYSIS:
99
+ - Time: O(N·d²) where N=sequence length, d=state dimension
100
+ - Space: O(d²) for state matrices + O(N·d) for sequence states
101
+ - Training: O(N·d²·L) where L=number of layers
102
+ - Inference: Linear in sequence length (vs quadratic for attention)
103
+
104
+ ADVANTAGES OVER TRANSFORMERS:
105
+ - Linear complexity vs quadratic attention
106
+ - Continuous-time formulation handles variable rates
107
+ - Built-in inductive bias for temporal dynamics
108
+ - Natural handling of infinite-length sequences
109
+ - Memory-efficient processing of long sequences
110
+
111
+ BIOLOGICAL INSPIRATION:
112
+ - Neural membrane time constants in biological circuits
113
+ - Adaptive integration windows in cortical processing
114
+ - Multiple timescale dynamics in neural networks
115
+ - Continuous-time neural differential equations
116
+ """
117
+
118
+ from __future__ import annotations
119
+ import torch
120
+ import torch.nn as nn
121
+ import torch.nn.functional as F
122
+ import numpy as np
123
+ import math
124
+ from typing import List, Dict, Tuple, Optional, Union, Any
125
+ from scipy import linalg
126
+ from scipy.signal import cont2discrete
127
+
128
+ # Numerical stability constants
129
+ SAFE_MIN: float = -1e6
130
+ SAFE_MAX: float = 1e6
131
+ EPS: float = 1e-8
132
+
133
+ #||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ð“…¸ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
134
+
135
+ def make_safe(
136
+ tensor: torch.Tensor,
137
+ min_val: float = SAFE_MIN,
138
+ max_val: float = SAFE_MAX
139
+ ) -> torch.Tensor:
140
+ """Clamp tensor values to safe numerical range, replacing NaN/Inf.
141
+
142
+ Args:
143
+ tensor: Input tensor to make numerically safe
144
+ min_val: Minimum allowed value
145
+ max_val: Maximum allowed value
146
+
147
+ Returns:
148
+ Numerically safe tensor with values in [min_val, max_val]
149
+ """
150
+ tensor = torch.where(torch.isnan(tensor), torch.tensor(0.0, device=tensor.device), tensor)
151
+ tensor = torch.where(torch.isinf(tensor), torch.tensor(max_val, device=tensor.device), tensor)
152
+ return torch.clamp(tensor, min_val, max_val)
153
+
154
+ def discrete_to_continuous_time(A_discrete: torch.Tensor, dt: float = 1.0) -> torch.Tensor:
155
+ """Convert discrete-time matrix to continuous-time using matrix logarithm.
156
+
157
+ Mathematical Details:
158
+ If A_d = exp(A_c · dt), then A_c = log(A_d) / dt
159
+
160
+ Args:
161
+ A_discrete: Discrete-time state transition matrix
162
+ dt: Time step used in discretization
163
+
164
+ Returns:
165
+ Continuous-time state matrix
166
+ """
167
+ try:
168
+ A_continuous = linalg.logm(A_discrete.detach().cpu().numpy()) / dt
169
+ return torch.tensor(A_continuous, dtype=torch.float32, device=A_discrete.device)
170
+ except:
171
+ # Fallback to small identity if matrix logarithm fails
172
+ return torch.eye(A_discrete.shape[0], device=A_discrete.device) * 0.01
173
+
174
+ def continuous_to_discrete_time(
175
+ A_continuous: torch.Tensor,
176
+ B_continuous: torch.Tensor,
177
+ dt: float = 1.0
178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
179
+ """Convert continuous-time system to discrete-time using zero-order hold.
180
+
181
+ Mathematical Details:
182
+ Uses matrix exponential method for exact discretization:
183
+ [A_d B_d] = exp([A B] · dt)
184
+ [0 I ] [0 0]
185
+
186
+ Handles batched matrices by processing each batch element individually
187
+ to avoid SciPy's limitation with multi-dimensional arrays.
188
+
189
+ Args:
190
+ A_continuous: Continuous-time state matrix [batch?, state, state]
191
+ B_continuous: Continuous-time input matrix [state, input]
192
+ dt: Time step for discretization
193
+
194
+ Returns:
195
+ Tuple of (A_discrete, B_discrete) matrices
196
+ """
197
+ try:
198
+ A_np = A_continuous.detach().cpu().numpy()
199
+ B_np = B_continuous.detach().cpu().numpy()
200
+
201
+ if A_np.ndim == 3:
202
+ # Handle batched matrices
203
+ A_list, B_list = [], []
204
+ for i in range(A_np.shape[0]):
205
+ Ad, Bd, _, _, _ = cont2discrete(
206
+ (A_np[i], B_np, np.eye(A_np.shape[-1]), 0), dt
207
+ )
208
+ A_list.append(Ad)
209
+ B_list.append(Bd)
210
+ A_discrete = torch.tensor(np.stack(A_list), dtype=torch.float32, device=A_continuous.device)
211
+ B_discrete = torch.tensor(np.stack(B_list), dtype=torch.float32, device=B_continuous.device)
212
+ else:
213
+ # Handle single matrix
214
+ A_discrete, B_discrete, _, _, _ = cont2discrete(
215
+ (A_np, B_np, np.eye(A_np.shape[0]), 0), dt
216
+ )
217
+ A_discrete = torch.tensor(A_discrete, dtype=torch.float32, device=A_continuous.device)
218
+ B_discrete = torch.tensor(B_discrete, dtype=torch.float32, device=B_continuous.device)
219
+
220
+ return A_discrete, B_discrete
221
+ except Exception:
222
+ # Fallback to first-order Euler approximation
223
+ n = A_continuous.shape[-1]
224
+ eye = torch.eye(n, device=A_continuous.device)
225
+ if A_continuous.dim() == 3:
226
+ eye = eye.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
227
+ B_disc = B_continuous.unsqueeze(0).expand(A_continuous.size(0), -1, -1)
228
+ else:
229
+ B_disc = B_continuous
230
+ A_discrete = eye + A_continuous * dt
231
+ B_discrete = B_disc * dt
232
+ return A_discrete, B_discrete
233
+
234
+ ###########################################################################################################################################
235
+ #############################################- - - LIQUID TIME CONSTANT CONTROLLER - - -###############################################
236
+
237
+ class LiquidTimeConstantController(nn.Module):
238
+ """Adaptive time constant controller for liquid dynamics.
239
+
240
+ Controls the temporal dynamics of the liquid state by learning context-dependent
241
+ time constants. Fast time constants enable quick adaptation to rapid changes,
242
+ while slow time constants provide stable integration of persistent patterns.
243
+
244
+ Mathematical Framework:
245
+ - Base time constants: τ_base = exp(log_τ)
246
+ - Adaptive modulation: τ(x,u) = τ_base · (1 + α·φ(x,u))
247
+ - Neural adaptation: φ(x,u) = tanh(W·[x,u] + b)
248
+ - Stability constraint: τ ∈ [0.01, 10.0]
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ state_dim: int,
254
+ input_dim: int,
255
+ init_tau: float = 1.0
256
+ ) -> None:
257
+ """Initialize adaptive time constant controller.
258
+
259
+ Args:
260
+ state_dim: Dimension of state vector
261
+ input_dim: Dimension of input vector
262
+ init_tau: Initial time constant value
263
+ """
264
+ super().__init__()
265
+ self.state_dim = state_dim
266
+ self.input_dim = input_dim
267
+
268
+ # Learnable base time constants (in log space for positivity)
269
+ self.log_tau = nn.Parameter(torch.ones(state_dim) * math.log(init_tau))
270
+
271
+ # Neural network for adaptive time constant modulation
272
+ # Takes concatenated state and input, outputs modulation factors
273
+ self.tau_adaptation = nn.Sequential(
274
+ nn.Linear(state_dim + input_dim, state_dim * 2),
275
+ nn.LayerNorm(state_dim * 2),
276
+ nn.Tanh(),
277
+ nn.Linear(state_dim * 2, state_dim),
278
+ nn.Tanh() # Output in [-1, 1] for stable modulation
279
+ )
280
+
281
+ # Meta-learning rate controlling adaptation strength
282
+ self.adaptation_rate = nn.Parameter(torch.tensor(0.1))
283
+
284
+ def get_time_constants(
285
+ self,
286
+ state: torch.Tensor,
287
+ input_signal: torch.Tensor
288
+ ) -> torch.Tensor:
289
+ """Compute context-dependent time constants.
290
+
291
+ Mathematical Details:
292
+ 1. Base time constants: τ_base = exp(log_τ)
293
+ 2. Context features: f = [state, input]
294
+ 3. Modulation: m = tanh(W·f + b)
295
+ 4. Final time constants: τ = τ_base · (1 + α·m)
296
+
297
+ Args:
298
+ state: Current liquid state [batch_size, state_dim]
299
+ input_signal: Current input [batch_size, input_dim]
300
+
301
+ Returns:
302
+ Adaptive time constants [batch_size, state_dim]
303
+ """
304
+ # Convert log time constants to positive values
305
+ base_tau = torch.exp(self.log_tau)
306
+ base_tau = torch.clamp(base_tau, 0.01, 10.0)
307
+
308
+ # Compute adaptive modulation based on current context
309
+ combined_input = torch.cat([state, input_signal], dim=-1)
310
+ tau_modulation = self.tau_adaptation(combined_input)
311
+
312
+ # Apply modulation with learnable adaptation rate
313
+ adaptation_rate = torch.clamp(self.adaptation_rate, 0.001, 1.0)
314
+ modulated_tau = base_tau * (1.0 + adaptation_rate * tau_modulation)
315
+
316
+ # Ensure time constants remain in stable range
317
+ return torch.clamp(modulated_tau, 0.01, 10.0)
318
+
319
+ def get_effective_dt(self, tau: torch.Tensor, target_dt: float = 0.1) -> float:
320
+ """Compute effective time step for numerical stability.
321
+
322
+ The effective time step is chosen to be much smaller than the fastest
323
+ time constant to ensure numerical stability of the integration.
324
+
325
+ Mathematical Constraint:
326
+ Δt_eff ≤ 0.1 · min(τ) for stability
327
+
328
+ Args:
329
+ tau: Time constants tensor [batch_size, state_dim]
330
+ target_dt: Desired time step
331
+
332
+ Returns:
333
+ Effective time step (scalar)
334
+ """
335
+ # Find minimum time constant for stability constraint
336
+ min_tau_val = torch.min(tau).item()
337
+ effective_dt = max(0.001, min(float(target_dt), min_tau_val * 0.1))
338
+ return effective_dt
339
+
340
+ ###########################################################################################################################################
341
+ ################################################- - - LIQUID SSM CORE - - -############################################################
342
+
343
+ class LiquidSSMCore(nn.Module):
344
+ """Core Liquid State Space Model with adaptive continuous-time dynamics.
345
+
346
+ Implements a state space model with liquid computing principles where
347
+ time constants adapt based on input characteristics. Combines the
348
+ representational power of SSMs with the adaptability of liquid dynamics.
349
+
350
+ Mathematical Framework:
351
+ - Liquid dynamics: dx/dt = -x/τ(x,u) + A·x + B·u
352
+ - Output equation: y = C·x + D·u
353
+ - HiPPO initialization for optimal memory properties
354
+ - Adaptive discretization for numerical integration
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ state_dim: int,
360
+ input_dim: int,
361
+ output_dim: int,
362
+ dt: float = 0.1,
363
+ init_method: str = 'hippo'
364
+ ) -> None:
365
+ """Initialize Liquid SSM core with adaptive dynamics.
366
+
367
+ Args:
368
+ state_dim: Dimension of hidden state vector
369
+ input_dim: Dimension of input vector
370
+ output_dim: Dimension of output vector
371
+ dt: Target time step for integration
372
+ init_method: Initialization method ('hippo' or 'random')
373
+ """
374
+ super().__init__()
375
+ self.state_dim = state_dim
376
+ self.input_dim = input_dim
377
+ self.output_dim = output_dim
378
+ self.dt = dt
379
+
380
+ # Initialize continuous-time state transition matrix
381
+ if init_method == 'hippo':
382
+ self.A_continuous = nn.Parameter(self._init_hippo_matrix(state_dim))
383
+ else:
384
+ self.A_continuous = nn.Parameter(torch.randn(state_dim, state_dim) * 0.1)
385
+
386
+ # Input, output, and feedthrough matrices
387
+ self.B_continuous = nn.Parameter(torch.randn(state_dim, input_dim) * 0.1)
388
+ self.C = nn.Parameter(torch.randn(output_dim, state_dim) * 0.1)
389
+ self.D = nn.Parameter(torch.zeros(output_dim, input_dim))
390
+
391
+ # Adaptive time constant controller
392
+ self.time_controller = LiquidTimeConstantController(state_dim, input_dim, init_tau=1.0)
393
+
394
+ # Learnable output scaling and bias
395
+ self.output_scale = nn.Parameter(torch.ones(output_dim))
396
+ self.output_bias = nn.Parameter(torch.zeros(output_dim))
397
+
398
+ # State normalization for training stability
399
+ self.state_normalizer = nn.LayerNorm(state_dim)
400
+
401
+ # Current continuous state (persistent memory)
402
+ self.register_buffer('continuous_state', torch.zeros(1, state_dim))
403
+
404
+ def _init_hippo_matrix(self, N: int) -> torch.Tensor:
405
+ """Initialize state matrix with HiPPO structure for optimal memory.
406
+
407
+ HiPPO (High-order Polynomial Projection Operators) creates a state
408
+ transition matrix that optimally preserves information by projecting
409
+ the input history onto a basis of Legendre polynomials.
410
+
411
+ Mathematical Details:
412
+ A_ij = {√(2i+1)·√(2j+1) if i > j (coupling strength)
413
+ {-(2i+1) if i = j (decay rate)
414
+ {0 if i < j (causality)
415
+
416
+ Args:
417
+ N: State dimension (number of basis functions)
418
+
419
+ Returns:
420
+ HiPPO matrix [N, N]
421
+ """
422
+ A = torch.zeros(N, N)
423
+ for i in range(N):
424
+ for j in range(N):
425
+ if i > j:
426
+ # Coupling between basis functions
427
+ A[i, j] = math.sqrt(2 * i + 1) * math.sqrt(2 * j + 1)
428
+ elif i == j:
429
+ # Decay rate for each basis function
430
+ A[i, j] = -(2 * i + 1)
431
+ return A * 0.1 # Scale for training stability
432
+
433
+ def reset_state(self, batch_size: int = 1) -> None:
434
+ """Reset continuous state for new sequence processing.
435
+
436
+ Args:
437
+ batch_size: Number of parallel sequences to process
438
+ """
439
+ device = self.A_continuous.device
440
+ self.continuous_state = torch.zeros(batch_size, self.state_dim, device=device)
441
+
442
+ def liquid_state_evolution(
443
+ self,
444
+ input_signal: torch.Tensor,
445
+ num_steps: int = 10
446
+ ) -> Tuple[torch.Tensor, torch.Tensor, float]:
447
+ """Evolve state using adaptive liquid dynamics with numerical integration.
448
+
449
+ Implements the core liquid evolution equation:
450
+ dx/dt = -x/τ(x,u) + A·x + B·u
451
+
452
+ Uses multi-step integration for numerical accuracy and adaptive
453
+ time stepping based on the fastest time constant.
454
+
455
+ Mathematical Process:
456
+ 1. Compute adaptive time constants: Ï„(x,u)
457
+ 2. Form liquid dynamics matrix: A_liquid = A - diag(1/Ï„)
458
+ 3. Discretize system: (A_d, B_d) = discretize(A_liquid, B, Δt)
459
+ 4. Integrate: x(k+1) = A_d·x(k) + B_d·u(k)
460
+
461
+ Args:
462
+ input_signal: External input [batch_size, input_dim]
463
+ num_steps: Number of integration steps for accuracy
464
+
465
+ Returns:
466
+ Tuple of (evolved_state, time_constants, effective_dt)
467
+ """
468
+ batch_size = input_signal.shape[0]
469
+
470
+ # Ensure state tensor matches batch size
471
+ if self.continuous_state.shape[0] != batch_size:
472
+ self.reset_state(batch_size)
473
+
474
+ # Compute adaptive time constants based on current state and input
475
+ tau = self.time_controller.get_time_constants(self.continuous_state, input_signal)
476
+ effective_dt = self.time_controller.get_effective_dt(tau, self.dt)
477
+
478
+ # Create time-varying dynamics matrix with liquid adaptation
479
+ # Standard SSM: dx/dt = A·x + B·u
480
+ # Liquid SSM: dx/dt = -x/τ + A·x + B·u = (A - diag(1/τ))·x + B·u
481
+ tau_matrix = torch.diag_embed(1.0 / tau) # Decay rates
482
+ liquid_A = self.A_continuous - tau_matrix
483
+
484
+ # Ensure numerical stability
485
+ liquid_A = make_safe(liquid_A, min_val=-10.0, max_val=10.0)
486
+
487
+ # Convert to discrete-time for numerical integration
488
+ A_discrete, B_discrete = continuous_to_discrete_time(
489
+ liquid_A, self.B_continuous, effective_dt
490
+ )
491
+
492
+ # Multi-step integration for improved accuracy
493
+ current_state = self.continuous_state
494
+
495
+ # Handle batched vs single matrix operations
496
+ if A_discrete.dim() == 3:
497
+ # Batched matrix multiplication
498
+ A_T = A_discrete.transpose(1, 2)
499
+ B_T = B_discrete.transpose(1, 2)
500
+ input_update = torch.bmm(input_signal.unsqueeze(1), B_T).squeeze(1)
501
+ for _ in range(num_steps):
502
+ state_update = torch.bmm(current_state.unsqueeze(1), A_T).squeeze(1)
503
+ current_state = state_update + input_update
504
+ current_state = make_safe(current_state)
505
+ else:
506
+ # Single matrix operations
507
+ A_T = A_discrete.T
508
+ B_T = B_discrete.T
509
+ input_update = input_signal @ B_T
510
+ for _ in range(num_steps):
511
+ current_state = current_state @ A_T + input_update
512
+ current_state = make_safe(current_state)
513
+
514
+ # Update persistent state
515
+ self.continuous_state = current_state
516
+
517
+ return current_state, tau, effective_dt
518
+
519
+ def compute_output(
520
+ self,
521
+ state: torch.Tensor,
522
+ input_signal: torch.Tensor
523
+ ) -> torch.Tensor:
524
+ """Compute output from state space model: y = C·x + D·u.
525
+
526
+ Args:
527
+ state: Current state vector [batch_size, state_dim]
528
+ input_signal: Current input [batch_size, input_dim]
529
+
530
+ Returns:
531
+ Output vector [batch_size, output_dim]
532
+ """
533
+ # Normalize state for training stability
534
+ normalized_state = self.state_normalizer(state)
535
+
536
+ # Standard SSM output equation
537
+ state_output = torch.matmul(normalized_state, self.C.T) # C·x
538
+ direct_output = torch.matmul(input_signal, self.D.T) # D·u
539
+
540
+ raw_output = state_output + direct_output
541
+
542
+ # Apply learnable output scaling and bias
543
+ output = self.output_scale * raw_output + self.output_bias
544
+
545
+ return make_safe(output)
546
+
547
+ def forward(
548
+ self,
549
+ input_signal: torch.Tensor,
550
+ return_diagnostics: bool = False
551
+ ) -> Dict[str, Union[torch.Tensor, float]]:
552
+ """Complete forward pass through Liquid SSM.
553
+
554
+ Args:
555
+ input_signal: Input vector [batch_size, input_dim]
556
+ return_diagnostics: Whether to return diagnostic information
557
+
558
+ Returns:
559
+ Dictionary containing output and optional diagnostics
560
+ """
561
+ # Evolve liquid state with adaptive dynamics
562
+ evolved_state, tau, effective_dt = self.liquid_state_evolution(input_signal)
563
+
564
+ # Compute output from current state
565
+ output = self.compute_output(evolved_state, input_signal)
566
+
567
+ result = {
568
+ 'output': output,
569
+ 'state': evolved_state
570
+ }
571
+
572
+ if return_diagnostics:
573
+ result.update({
574
+ 'time_constants': tau,
575
+ 'effective_dt': effective_dt,
576
+ 'state_norm': torch.norm(evolved_state, dim=-1),
577
+ 'adaptation_rate': self.time_controller.adaptation_rate
578
+ })
579
+
580
+ return result
581
+
582
+ ###########################################################################################################################################
583
+ ############################################- - - LIQUID SSM SEQUENCE LAYER - - -######################################################
584
+
585
+ class LiquidSSMSequenceLayer(nn.Module):
586
+ """Sequence processing layer using Liquid SSM with residual connections.
587
+
588
+ Processes variable-length sequences through Liquid SSM while maintaining
589
+ adaptive dynamics across time steps. Includes input/output projections,
590
+ residual connections, and sequence-level adaptation mechanisms.
591
+
592
+ Architecture:
593
+ Input → Projection → Liquid SSM → Sequence Adaptation → Output Projection → Residual
594
+ """
595
+
596
+ def __init__(
597
+ self,
598
+ input_dim: int,
599
+ state_dim: int,
600
+ output_dim: int,
601
+ seq_len: Optional[int] = None
602
+ ) -> None:
603
+ """Initialize Liquid SSM sequence processing layer.
604
+
605
+ Args:
606
+ input_dim: Dimension of input features
607
+ state_dim: Dimension of internal state
608
+ output_dim: Dimension of output features
609
+ seq_len: Maximum sequence length (optional)
610
+ """
611
+ super().__init__()
612
+ self.input_dim = input_dim
613
+ self.state_dim = state_dim
614
+ self.output_dim = output_dim
615
+ self.seq_len = seq_len
616
+
617
+ # Core Liquid SSM operating on projected state dimension
618
+ # Both input and state dimensions set to state_dim to ensure
619
+ # compatibility in time constant controller computations
620
+ self.liquid_ssm = LiquidSSMCore(state_dim, state_dim, output_dim)
621
+
622
+ # Input projection and preprocessing
623
+ self.input_projection = nn.Sequential(
624
+ nn.Linear(input_dim, state_dim),
625
+ nn.LayerNorm(state_dim),
626
+ nn.GELU()
627
+ )
628
+
629
+ # Output projection and postprocessing
630
+ self.output_projection = nn.Sequential(
631
+ nn.Linear(output_dim, output_dim * 2),
632
+ nn.LayerNorm(output_dim * 2),
633
+ nn.GELU(),
634
+ nn.Dropout(0.1),
635
+ nn.Linear(output_dim * 2, output_dim)
636
+ )
637
+
638
+ # Learnable residual connection strength
639
+ self.residual_weight = nn.Parameter(torch.tensor(0.1))
640
+
641
+ # Sequence-level adaptation mechanism
642
+ self.sequence_adapter = nn.Sequential(
643
+ nn.Linear(state_dim, state_dim),
644
+ nn.Tanh(),
645
+ nn.Linear(state_dim, 1),
646
+ nn.Sigmoid()
647
+ )
648
+
649
+ def forward(
650
+ self,
651
+ sequence: torch.Tensor,
652
+ return_diagnostics: bool = False
653
+ ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
654
+ """Process complete sequence through Liquid SSM.
655
+
656
+ Processes each time step sequentially while maintaining liquid state
657
+ continuity across the sequence. Applies sequence-level adaptation
658
+ and residual connections for improved gradient flow.
659
+
660
+ Args:
661
+ sequence: Input sequence [batch_size, seq_len, input_dim]
662
+ return_diagnostics: Whether to return per-timestep diagnostics
663
+
664
+ Returns:
665
+ Dictionary containing output sequence and optional diagnostics
666
+ """
667
+ batch_size, seq_len, input_dim = sequence.shape
668
+
669
+ # Reset SSM state for new sequence
670
+ self.liquid_ssm.reset_state(batch_size)
671
+
672
+ # Process sequence timestep by timestep
673
+ outputs = []
674
+ diagnostics = [] if return_diagnostics else None
675
+
676
+ for t in range(seq_len):
677
+ # Extract current timestep input
678
+ current_input = sequence[:, t, :]
679
+
680
+ # Project input to state dimension
681
+ projected_input = self.input_projection(current_input)
682
+
683
+ # Process through Liquid SSM
684
+ ssm_result = self.liquid_ssm(projected_input, return_diagnostics=return_diagnostics)
685
+
686
+ # Apply sequence-level adaptation
687
+ adaptation_factor = self.sequence_adapter(ssm_result['state'])
688
+ adapted_output = ssm_result['output'] * adaptation_factor
689
+
690
+ # Post-process output
691
+ final_output = self.output_projection(adapted_output)
692
+
693
+ # Apply residual connection if dimensions match
694
+ if final_output.shape == current_input.shape:
695
+ residual_strength = torch.clamp(self.residual_weight, 0.0, 1.0)
696
+ final_output = final_output + residual_strength * current_input
697
+
698
+ outputs.append(final_output)
699
+
700
+ if return_diagnostics:
701
+ diagnostics.append({
702
+ 'timestep': t,
703
+ 'adaptation_factor': adaptation_factor.mean().item(),
704
+ **ssm_result
705
+ })
706
+
707
+ # Stack outputs along sequence dimension
708
+ output_sequence = torch.stack(outputs, dim=1)
709
+
710
+ result = {'output': output_sequence}
711
+
712
+ if return_diagnostics:
713
+ result['diagnostics'] = diagnostics
714
+
715
+ return result
716
+
717
+ ###########################################################################################################################################
718
+ ##############################################- - - LIQUID SSM LANGUAGE MODEL - - -####################################################
719
+
720
+ class LiquidSSMLanguageModel(nn.Module):
721
+ """Complete language model using Liquid State Space Models.
722
+
723
+ Implements a transformer-alternative architecture using Liquid SSMs for
724
+ sequence processing. Provides linear complexity in sequence length while
725
+ maintaining strong representational capabilities through adaptive dynamics.
726
+
727
+ Architecture:
728
+ Embeddings → Liquid SSM Layers → Output Head
729
+
730
+ Each layer includes:
731
+ - Layer normalization
732
+ - Liquid SSM processing
733
+ - Global adaptation
734
+ - Residual connections
735
+ """
736
+
737
+ def __init__(
738
+ self,
739
+ vocab_size: int,
740
+ d_model: int = 512,
741
+ state_dim: int = 256,
742
+ num_layers: int = 6,
743
+ max_seq_len: int = 2048
744
+ ) -> None:
745
+ """Initialize Liquid SSM Language Model.
746
+
747
+ Args:
748
+ vocab_size: Size of vocabulary
749
+ d_model: Model dimension (embedding/hidden size)
750
+ state_dim: Liquid state dimension
751
+ num_layers: Number of Liquid SSM layers
752
+ max_seq_len: Maximum sequence length
753
+ """
754
+ super().__init__()
755
+ self.vocab_size = vocab_size
756
+ self.d_model = d_model
757
+ self.state_dim = state_dim
758
+ self.num_layers = num_layers
759
+ self.max_seq_len = max_seq_len
760
+
761
+ # Token and position embeddings
762
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
763
+ self.position_embedding = nn.Embedding(max_seq_len, d_model)
764
+
765
+ # Stack of Liquid SSM layers
766
+ self.liquid_layers = nn.ModuleList([
767
+ LiquidSSMSequenceLayer(d_model, state_dim, d_model)
768
+ for _ in range(num_layers)
769
+ ])
770
+
771
+ # Layer normalization for each layer
772
+ self.layer_norms = nn.ModuleList([
773
+ nn.LayerNorm(d_model) for _ in range(num_layers)
774
+ ])
775
+
776
+ # Output head for language modeling
777
+ self.output_norm = nn.LayerNorm(d_model)
778
+ self.lm_head = nn.Linear(d_model, vocab_size)
779
+
780
+ # Global adaptation mechanism
781
+ self.global_adaptation = nn.Sequential(
782
+ nn.Linear(d_model, d_model // 4),
783
+ nn.GELU(),
784
+ nn.Linear(d_model // 4, 1),
785
+ nn.Sigmoid()
786
+ )
787
+
788
+ self._init_weights()
789
+
790
+ def _init_weights(self) -> None:
791
+ for module in self.modules():
792
+ if isinstance(module, nn.Linear):
793
+ nn.init.xavier_uniform_(module.weight)
794
+ if module.bias is not None:
795
+ nn.init.zeros_(module.bias)
796
+ elif isinstance(module, nn.Embedding):
797
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
798
+
799
+ def forward(
800
+ self,
801
+ input_ids: torch.Tensor,
802
+ labels: Optional[torch.Tensor] = None,
803
+ return_diagnostics: bool = False
804
+ ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
805
+ """Forward pass through Liquid SSM Language Model.
806
+
807
+ Args:
808
+ input_ids: Token IDs [batch_size, seq_len]
809
+ labels: Target labels for loss computation [batch_size, seq_len]
810
+ return_diagnostics: Whether to return layer diagnostics
811
+
812
+ Returns:
813
+ Dictionary containing logits, loss, and optional diagnostics
814
+ """
815
+ batch_size, seq_len = input_ids.shape
816
+ device = input_ids.device
817
+
818
+ # Clamp sequence length to maximum supported
819
+ if seq_len > self.max_seq_len:
820
+ input_ids = input_ids[:, :self.max_seq_len]
821
+ seq_len = self.max_seq_len
822
+ if labels is not None:
823
+ labels = labels[:, :self.max_seq_len]
824
+
825
+ # Ensure valid token IDs
826
+ input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
827
+
828
+ # Compute embeddings
829
+ token_emb = self.token_embedding(input_ids)
830
+ pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
831
+ pos_emb = self.position_embedding(pos_ids)
832
+
833
+ x = token_emb + pos_emb
834
+ x = make_safe(x)
835
+
836
+ # Store layer diagnostics if requested
837
+ layer_diagnostics = [] if return_diagnostics else None
838
+
839
+ # Process through Liquid SSM layers
840
+ for layer_idx, (liquid_layer, layer_norm) in enumerate(zip(self.liquid_layers, self.layer_norms)):
841
+ # Store input for residual connection
842
+ residual = x
843
+
844
+ # Pre-layer normalization
845
+ x = layer_norm(x)
846
+
847
+ # Liquid SSM processing
848
+ layer_result = liquid_layer(x, return_diagnostics=return_diagnostics)
849
+ x = layer_result['output']
850
+
851
+ # Global adaptation based on sequence statistics
852
+ adaptation = self.global_adaptation(x.mean(dim=1, keepdim=True))
853
+ x = x * adaptation
854
+
855
+ # Residual connection
856
+ x = residual + x
857
+ x = make_safe(x)
858
+
859
+ if return_diagnostics:
860
+ layer_diagnostics.append({
861
+ 'layer': layer_idx,
862
+ 'adaptation': adaptation.mean().item(),
863
+ **layer_result
864
+ })
865
+
866
+ # Final normalization and output projection
867
+ x = self.output_norm(x)
868
+ logits = self.lm_head(x)
869
+ logits = make_safe(logits, min_val=-50, max_val=50)
870
+
871
+ # Compute cross-entropy loss if labels provided
872
+ loss = None
873
+ if labels is not None:
874
+ shift_logits = logits[..., :-1, :].contiguous()
875
+ shift_labels = labels[..., 1:].contiguous()
876
+ loss = F.cross_entropy(
877
+ shift_logits.view(-1, self.vocab_size),
878
+ shift_labels.view(-1),
879
+ ignore_index=-100
880
+ )
881
+
882
+ result = {
883
+ 'logits': logits,
884
+ 'loss': loss
885
+ }
886
+
887
+ if return_diagnostics:
888
+ result['layer_diagnostics'] = layer_diagnostics
889
+
890
+ return result
891
+
892
+ @torch.no_grad()
893
+ def generate(
894
+ self,
895
+ input_ids: torch.Tensor,
896
+ max_length: int = 100,
897
+ temperature: float = 1.0,
898
+ top_p: float = 0.95,
899
+ return_diagnostics: bool = False
900
+ ) -> Dict[str, Union[torch.Tensor, List[Dict]]]:
901
+ """Generate text using Liquid SSM with nucleus sampling.
902
+
903
+ Args:
904
+ input_ids: Prompt token IDs [batch_size, prompt_len]
905
+ max_length: Maximum total sequence length
906
+ temperature: Sampling temperature (higher = more random)
907
+ top_p: Nucleus sampling probability threshold
908
+ return_diagnostics: Whether to return generation diagnostics
909
+
910
+ Returns:
911
+ Dictionary containing generated IDs and optional diagnostics
912
+ """
913
+ self.eval()
914
+ generated = input_ids.clone()
915
+ all_diagnostics = [] if return_diagnostics else None
916
+
917
+ for step in range(max_length - input_ids.shape[1]):
918
+ # Stop if sequence exceeds maximum length
919
+ if generated.shape[1] > self.max_seq_len:
920
+ break
921
+
922
+ # Forward pass to get next token logits
923
+ outputs = self(generated, return_diagnostics=return_diagnostics)
924
+ logits = outputs['logits']
925
+
926
+ if return_diagnostics:
927
+ all_diagnostics.append(outputs.get('layer_diagnostics', []))
928
+
929
+ # Extract logits for next token prediction
930
+ next_token_logits = logits[:, -1, :] / max(temperature, EPS)
931
+ next_token_logits = make_safe(next_token_logits, min_val=-50, max_val=50)
932
+
933
+ # Nucleus (top-p) sampling
934
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
935
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
936
+
937
+ # Identify tokens to remove (cumulative probability > top_p)
938
+ sorted_indices_to_remove = cumulative_probs > top_p
939
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
940
+ sorted_indices_to_remove[..., 0] = False
941
+
942
+ # Remove low-probability tokens
943
+ for b in range(next_token_logits.size(0)):
944
+ indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]]
945
+ next_token_logits[b, indices_to_remove] = -float('inf')
946
+
947
+ # Sample next token
948
+ probs = F.softmax(next_token_logits, dim=-1)
949
+ next_token = torch.multinomial(probs, num_samples=1)
950
+ next_token = torch.clamp(next_token, 0, self.vocab_size - 1)
951
+
952
+ # Append to generated sequence
953
+ generated = torch.cat([generated, next_token], dim=1)
954
+
955
+ # Stop on EOS token
956
+ if next_token.item() == 2: # Assuming token ID 2 is EOS
957
+ break
958
+
959
+ result = {'generated_ids': generated}
960
+ if return_diagnostics:
961
+ result['diagnostics'] = all_diagnostics
962
+
963
+ return result
964
+
965
+ ###########################################################################################################################################
966
+ ##############################################- - - LIQUID SSM DEMO + TESTING - - -####################################################
967
+
968
+ def test_liquid_ssm() -> bool:
969
+ print("Testing Liquid State Space Model - Continuous-Time Adaptive Sequence Processing")
970
+ print("=" * 90)
971
+
972
+ # Create Liquid SSM Language Model
973
+ vocab_size = 1000
974
+ d_model = 256
975
+ state_dim = 128
976
+ num_layers = 4
977
+
978
+ model = LiquidSSMLanguageModel(
979
+ vocab_size=vocab_size,
980
+ d_model=d_model,
981
+ state_dim=state_dim,
982
+ num_layers=num_layers,
983
+ max_seq_len=512
984
+ )
985
+
986
+ print(f"Created Liquid SSM Language Model:")
987
+ print(f" - Vocabulary size: {vocab_size}")
988
+ print(f" - Model dimension: {d_model}")
989
+ print(f" - State dimension: {state_dim}")
990
+ print(f" - Number of layers: {num_layers}")
991
+
992
+ # Count parameters
993
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
994
+ print(f" - Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")
995
+
996
+ # Test with sample data
997
+ batch_size = 4
998
+ seq_len = 32
999
+ test_input = torch.randint(0, vocab_size, (batch_size, seq_len))
1000
+ test_labels = torch.randint(0, vocab_size, (batch_size, seq_len))
1001
+
1002
+ print(f"\nTesting with batch_size={batch_size}, seq_len={seq_len}")
1003
+
1004
+ # Forward pass
1005
+ print("\nExecuting forward pass...")
1006
+ outputs = model(test_input, labels=test_labels, return_diagnostics=True)
1007
+
1008
+ print("Forward pass results:")
1009
+ print(f" - Output logits shape: {outputs['logits'].shape}")
1010
+ print(f" - Loss: {outputs['loss']:.4f}")
1011
+
1012
+ # Analyze liquid dynamics
1013
+ print("\nLiquid dynamics analysis:")
1014
+ diagnostics = outputs['layer_diagnostics']
1015
+
1016
+ for layer_idx in range(min(3, len(diagnostics))):
1017
+ layer_diag = diagnostics[layer_idx]
1018
+ print(f" Layer {layer_idx + 1}:")
1019
+ print(f" - Global adaptation: {layer_diag['adaptation']:.3f}")
1020
+
1021
+ if 'diagnostics' in layer_diag:
1022
+ time_constants = [d['time_constants'].mean().item() for d in layer_diag['diagnostics'][:3]]
1023
+ print(f" - Avg time constants: {[f'{tc:.3f}' for tc in time_constants]}")
1024
+
1025
+ # Test generation
1026
+ print("\nTesting text generation...")
1027
+ prompt = torch.randint(0, vocab_size, (1, 8))
1028
+ generation_result = model.generate(
1029
+ prompt,
1030
+ max_length=20,
1031
+ temperature=1.0,
1032
+ return_diagnostics=True
1033
+ )
1034
+
1035
+ generated_ids = generation_result['generated_ids']
1036
+ print(f" - Generated sequence length: {generated_ids.shape[1]}")
1037
+ print(f" - Prompt length: {prompt.shape[1]}")
1038
+ print(f" - New tokens generated: {generated_ids.shape[1] - prompt.shape[1]}")
1039
+
1040
+ # Test efficiency comparison
1041
+ print("\nEfficiency analysis:")
1042
+
1043
+ # Test different sequence lengths
1044
+ seq_lengths = [64, 128, 256]
1045
+ for test_len in seq_lengths:
1046
+ test_seq = torch.randint(0, vocab_size, (1, test_len))
1047
+
1048
+ import time
1049
+ start_time = time.time()
1050
+ with torch.no_grad():
1051
+ test_output = model(test_seq)
1052
+ end_time = time.time()
1053
+
1054
+ processing_time = end_time - start_time
1055
+ tokens_per_second = test_len / processing_time
1056
+
1057
+ print(f" - Length {test_len}: {processing_time:.3f}s ({tokens_per_second:.0f} tokens/s)")
1058
+
1059
+ print("\nLiquid SSM test completed!")
1060
+ print("✓ Continuous-time adaptive dynamics")
1061
+ print("✓ Learnable time constants based on content")
1062
+ print("✓ Efficient sequence processing")
1063
+ print("✓ State space model foundation with liquid adaptation")
1064
+ print("✓ Potential transformer alternative with continuous dynamics")
1065
+
1066
+ return True
1067
+
1068
+ def adaptive_dynamics_demo() -> None:
1069
+ print("\n" + "="*70)
1070
+ print("ADAPTIVE DYNAMICS DEMONSTRATION")
1071
+ print("="*70)
1072
+
1073
+ # Create simple model for demonstration
1074
+ model = LiquidSSMCore(state_dim=16, input_dim=8, output_dim=8)
1075
+ model.eval()
1076
+
1077
+ # Test patterns with different temporal characteristics
1078
+ patterns = {
1079
+ "Smooth": torch.sin(torch.linspace(0, 2*math.pi, 8)).unsqueeze(0),
1080
+ "Spiky": torch.tensor([0, 1, 0, -1, 0, 1, 0, -1], dtype=torch.float).unsqueeze(0),
1081
+ "Constant": torch.ones(1, 8) * 0.5,
1082
+ "Random": torch.randn(1, 8)
1083
+ }
1084
+
1085
+ print("Testing adaptive time constants with different input patterns:")
1086
+
1087
+ for pattern_name, pattern_input in patterns.items():
1088
+ model.reset_state(1)
1089
+
1090
+ # Process pattern through liquid dynamics
1091
+ with torch.no_grad():
1092
+ result = model(pattern_input, return_diagnostics=True)
1093
+
1094
+ time_constants = result['time_constants'].squeeze().tolist()
1095
+ adaptation_rate = result['adaptation_rate'].item()
1096
+
1097
+ print(f"\n{pattern_name} pattern:")
1098
+ print(f" Time constants: {[f'{tc:.3f}' for tc in time_constants[:4]]}...")
1099
+ print(f" Adaptation rate: {adaptation_rate:.4f}")
1100
+ print(f" Effective dt: {result['effective_dt']:.4f}")
1101
+
1102
+ print("\n Adaptive dynamics show how liquid SSM adjusts to different input characteristics")
1103
+ print(" Smooth inputs → larger time constants, Spiky inputs → smaller time constants")
1104
+
1105
+ if __name__ == "__main__":
1106
+ test_liquid_ssm()
1107
+ adaptive_dynamics_demo()