BrainChipInc NickMarkovsky commited on
Commit
c8c055f
·
0 Parent(s):

Duplicate from NickMarkovsky/tenns-llm-1b

Browse files

Co-authored-by: Nick Markovsky <NickMarkovsky@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ tags:
7
+ - ssm
8
+ - causal-lm
9
+ - custom-architecture
10
+ - recurrent
11
+ pipeline_tag: text-generation
12
+ ---
13
+
14
+ # TENNs LLM 1B
15
+
16
+ A 1-billion-parameter causal language model built on gate-mode SSM (State Space Model) layers from [TENNs Core](https://huggingface.co/BrainChipInc/tenns-llm-1b/tree/main/tenns_core). Uses recurrent inference instead of attention, making it efficient for streaming and long-context generation.
17
+
18
+ ## Architecture
19
+
20
+ | Component | Details |
21
+ |-----------|---------|
22
+ | Layers | 24 × TENNsBlock (gate mode) |
23
+ | Hidden dim | 2048 |
24
+ | Inner dim | 4096 |
25
+ | Vocabulary | 32,000 (Mistral-7B tokenizer) |
26
+ | Parameters | ~1B |
27
+
28
+ Each TENNsBlock: `RMSNorm → in_proj → causal_conv(4) → SSM(gate) → out_proj → residual`
29
+
30
+ ## Quick Start (Google Colab / any environment)
31
+
32
+ ```python
33
+ !pip install transformers torch einops opt_einsum safetensors
34
+
35
+ from transformers import AutoModelForCausalLM, AutoTokenizer
36
+
37
+ tokenizer = AutoTokenizer.from_pretrained("BrainChipInc/tenns-llm-1b")
38
+ model = AutoModelForCausalLM.from_pretrained(
39
+ "BrainChipInc/tenns-llm-1b",
40
+ trust_remote_code=True,
41
+ )
42
+
43
+ output = model.generate_text("The history of artificial intelligence", tokenizer, max_new_tokens=100)
44
+ print(output)
45
+ ```
46
+
47
+ > **Do not use `pipeline()`** — this model uses a custom recurrent architecture that is not
48
+ > compatible with HuggingFace's standard text-generation pipeline.
49
+
50
+ ## Installation
51
+
52
+ ```bash
53
+ pip install transformers torch einops opt_einsum safetensors
54
+ ```
55
+
56
+ ## Usage
57
+
58
+ > **Note:** Do **not** use `pipeline()` — this model requires `model.generate_text()` instead of
59
+ > HuggingFace's standard `generate()`. The recurrent SSM architecture is not compatible with the
60
+ > attention KV-cache pipeline.
61
+
62
+ ```python
63
+ from transformers import AutoModelForCausalLM, AutoTokenizer
64
+
65
+ tokenizer = AutoTokenizer.from_pretrained("BrainChipInc/tenns-llm-1b")
66
+ model = AutoModelForCausalLM.from_pretrained(
67
+ "BrainChipInc/tenns-llm-1b",
68
+ trust_remote_code=True,
69
+ )
70
+
71
+ output = model.generate_text("The history of artificial intelligence", tokenizer, max_new_tokens=100)
72
+ print(output)
73
+ ```
74
+
75
+ ### Generation options
76
+
77
+ ```python
78
+ # Greedy decoding (default)
79
+ output = model.generate_text(prompt, tokenizer, max_new_tokens=50)
80
+
81
+ # Top-k sampling with temperature
82
+ output = model.generate_text(prompt, tokenizer, max_new_tokens=100, temperature=0.8, top_k=50)
83
+ ```
84
+
85
+ ## `trust_remote_code=True`
86
+
87
+ This model uses custom modeling code bundled in the repository
88
+ (`modeling_tenns_llm.py`, `configuration_tenns_llm.py`, `tenns_core/`).
89
+ Loading requires `trust_remote_code=True`. The bundled `tenns_core/` package
90
+ is a snapshot of the TENNs Core SSM library — no separate installation needed.
91
+
92
+ ## Training
93
+
94
+ Fine-tuned from a base TENNs gate-mode model using LoRA adapters on English instruction data.
95
+ LoRA adapters are merged into base weights at export time.
96
+
97
+ ## Limitations
98
+
99
+ - English only
100
+ - No system prompt or chat template — plain completion model
101
+ - Recurrent state resets between calls to `generate_text()`
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "tenns_llm",
3
+ "auto_map": {
4
+ "AutoConfig": "configuration_tenns_llm.TennsLLMConfig",
5
+ "AutoModelForCausalLM": "modeling_tenns_llm.TennsLLMForCausalLM"
6
+ },
7
+ "vocab_size": 32000,
8
+ "channels": 2048,
9
+ "num_blocks": 24,
10
+ "num_coeffs": 16,
11
+ "repeat": 256,
12
+ "transformers_version": "4.40.0"
13
+ }
configuration_tenns_llm.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+ # Inject the repo directory into sys.path so the bundled tenns_core/ is
7
+ # importable without a pip install, both locally and when loaded from HF hub.
8
+ _HERE = os.path.dirname(os.path.abspath(__file__))
9
+ if _HERE not in sys.path:
10
+ sys.path.insert(0, _HERE)
11
+
12
+
13
+ class TennsLLMConfig(PretrainedConfig):
14
+ model_type = "tenns_llm"
15
+
16
+ def __init__(
17
+ self,
18
+ vocab_size=32000,
19
+ channels=2048,
20
+ num_blocks=24,
21
+ num_coeffs=16,
22
+ repeat=256,
23
+ **kwargs,
24
+ ):
25
+ super().__init__(**kwargs)
26
+ self.vocab_size = vocab_size
27
+ self.channels = channels
28
+ self.num_blocks = num_blocks
29
+ self.num_coeffs = num_coeffs
30
+ self.repeat = repeat
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:695805667bf74d3bb24b8fc0c676e75c26c21191ebe91326429d4f61e43740ff
3
+ size 4957835584
modeling_tenns_llm.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import RMSNorm
9
+
10
+ from transformers import PreTrainedModel
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+
13
+ from configuration_tenns_llm import TennsLLMConfig
14
+
15
+ def _get_tenns_core_path():
16
+ """Return a directory that contains tenns_core/.
17
+
18
+ HF's from_pretrained only downloads the .py files listed in auto_map —
19
+ it does not download subdirectories like tenns_core/. We use
20
+ snapshot_download (with local cache) to ensure tenns_core/ is present.
21
+ The first call downloads it; subsequent calls are instant cache hits.
22
+ """
23
+ # Derive the repo_id from __file__ path in the HF modules cache:
24
+ # .../modules/transformers_modules/ORG/REPO_SLUG/HASH/modeling_tenns_llm.py
25
+ here = os.path.dirname(os.path.abspath(__file__))
26
+ parts = here.replace("\\", "/").split("/")
27
+ try:
28
+ idx = next(i for i, p in enumerate(parts) if p == "transformers_modules")
29
+ org_id = parts[idx + 1].replace("_hyphen_", "-")
30
+ repo_id = parts[idx + 2].replace("_hyphen_", "-")
31
+ except (StopIteration, IndexError):
32
+ return here # not in HF cache — assume tenns_core/ is next to this file
33
+
34
+ from huggingface_hub import snapshot_download
35
+ snapshot = snapshot_download(
36
+ f"{org_id}/{repo_id}",
37
+ allow_patterns=["tenns_core/**"],
38
+ )
39
+ return snapshot
40
+
41
+
42
+ _tenns_core_dir = _get_tenns_core_path()
43
+ if _tenns_core_dir not in sys.path:
44
+ sys.path.insert(0, _tenns_core_dir)
45
+
46
+ _tc = importlib.import_module("tenns_core")
47
+ _rc = importlib.import_module("tenns_core.recurrent_ops")
48
+ SSMLayer = _tc.SSMLayer
49
+ recurrent_gate = _rc.recurrent_gate
50
+
51
+
52
+ # ============================================================================
53
+ # Model Components (from tenns_llm.py)
54
+ # ============================================================================
55
+
56
+
57
+ class CausalConvDwFast(nn.Module):
58
+ """Holds depthwise causal convolution weights for TENNs blocks."""
59
+ def __init__(self, coeffs, kernel_size):
60
+ super().__init__()
61
+ self.weight = nn.Parameter(torch.rand(kernel_size, coeffs))
62
+
63
+
64
+ class PassthroughConv(nn.Module):
65
+ """Applies causal convolution via FIFO buffer for streaming inference."""
66
+ def __init__(self, causal_conv, d_inner):
67
+ super().__init__()
68
+ self.causal_conv = causal_conv
69
+ self.d_inner = d_inner
70
+ self.fifo = None
71
+
72
+ def apply_conv(self, x):
73
+ """Apply causal convolution. x: (B, T, C) -> (B, T, C)"""
74
+ B, T, C = x.shape
75
+
76
+ if self.fifo is None or self.fifo.shape[0] != B:
77
+ self.fifo = torch.zeros(B, C, 4, device=x.device, dtype=x.dtype)
78
+
79
+ conv_weight = self.causal_conv.weight.squeeze().T # (C, 4)
80
+
81
+ x_conv = []
82
+ for t in range(T):
83
+ self.fifo = self.fifo.roll(-1, dims=-1)
84
+ self.fifo[:, :, -1] = x[:, t, :]
85
+ x_t = (self.fifo * conv_weight).sum(-1)
86
+ x_conv.append(x_t)
87
+
88
+ x_conv = torch.stack(x_conv, dim=1)
89
+ x_conv = F.silu(x_conv)
90
+ return x_conv
91
+
92
+ def reset_states(self):
93
+ if self.fifo is not None:
94
+ self.fifo.zero_()
95
+
96
+
97
+ class TENNsBlock(nn.Module):
98
+ """TENNs block with gate-mode SSM for LLM inference."""
99
+ def __init__(self, channels, num_coeffs, repeat, mode='gate'):
100
+ super().__init__()
101
+ d_inner = channels * 2
102
+ self.d_inner = d_inner
103
+
104
+ self.pre_norm = RMSNorm(channels, elementwise_affine=True)
105
+ self.pre_conv = CausalConvDwFast(d_inner, 4)
106
+ self.in_proj = nn.Linear(channels, d_inner * 2, bias=True)
107
+ self.out_proj = nn.Linear(d_inner, channels, bias=True)
108
+
109
+ self.ssm_layer = SSMLayer(num_coeffs, d_inner, d_inner,
110
+ repeat=repeat, mode=mode, transposed=True)
111
+
112
+ self.ssm_layer.register_buffer('state_lora', torch.zeros(d_inner))
113
+
114
+ self.D = nn.Parameter(torch.ones(d_inner, dtype=torch.float))
115
+
116
+ self._conv_handler = None
117
+ self.state = None
118
+
119
+ def forward(self, input):
120
+ x = self.pre_norm(input)
121
+ x_and_res = self.in_proj(x)
122
+ x, res = x_and_res.split([self.d_inner, self.d_inner], -1)
123
+
124
+ if self._conv_handler is None:
125
+ self._conv_handler = PassthroughConv(self.pre_conv, self.d_inner)
126
+
127
+ x_conv = self._conv_handler.apply_conv(x)
128
+
129
+ state = self.state
130
+ if state is None:
131
+ state = self.ssm_layer.state_lora
132
+
133
+ y, self.state = recurrent_gate(
134
+ x_conv,
135
+ self.ssm_layer.A,
136
+ self.ssm_layer.B,
137
+ self.ssm_layer.C,
138
+ self.ssm_layer.log_dt,
139
+ state
140
+ )
141
+
142
+ y = y.transpose(1, 2)
143
+ y = y + self.D * x_conv
144
+ output = self.out_proj(y * F.silu(res))
145
+
146
+ return input + output
147
+
148
+ def reset_states(self):
149
+ if self._conv_handler is not None:
150
+ self._conv_handler.reset_states()
151
+ self.state = None
152
+
153
+
154
+ class TENNsLLM(nn.Module):
155
+ """TENNs-based language model for autoregressive text generation."""
156
+ def __init__(self, vocab_size=32000, channels=2048, num_blocks=24,
157
+ num_coeffs=16, repeat=256):
158
+ super().__init__()
159
+ self.channels = channels
160
+ self.embedding = nn.Embedding(vocab_size, channels)
161
+ self.backbone = nn.Sequential(
162
+ *[TENNsBlock(channels, num_coeffs, repeat, mode='gate')
163
+ for _ in range(num_blocks)]
164
+ )
165
+ self.head = nn.Sequential(
166
+ RMSNorm(channels, elementwise_affine=False),
167
+ nn.Linear(channels, vocab_size, bias=False),
168
+ )
169
+
170
+ def forward(self, tokens):
171
+ x = self.embedding(tokens)
172
+ x = self.backbone(x)
173
+ return self.head(x)
174
+
175
+ def reset_states(self):
176
+ for module in self.modules():
177
+ if isinstance(module, TENNsBlock):
178
+ module.reset_states()
179
+
180
+
181
+ # ============================================================================
182
+ # HuggingFace wrapper
183
+ # ============================================================================
184
+
185
+
186
+ class TennsLLMForCausalLM(PreTrainedModel):
187
+ """HuggingFace PreTrainedModel wrapper for TENNsLLM.
188
+
189
+ Load with:
190
+ from transformers import AutoModelForCausalLM, AutoTokenizer
191
+ model = AutoModelForCausalLM.from_pretrained(
192
+ "aliborji/tenns-llm-1b", trust_remote_code=True
193
+ )
194
+ tokenizer = AutoTokenizer.from_pretrained("aliborji/tenns-llm-1b")
195
+
196
+ Generate with:
197
+ output = model.generate_text("Hello, world!", tokenizer, max_new_tokens=50)
198
+ print(output)
199
+
200
+ Note: This model uses recurrent SSM states. Use generate_text() rather than
201
+ model.generate(), which is designed for attention-based KV-cache models.
202
+ """
203
+ config_class = TennsLLMConfig
204
+ # Weights are saved without a 'model.' prefix — flatten components directly
205
+ # onto this class so state dict keys match the safetensors file exactly.
206
+ _tied_weights_keys = []
207
+
208
+ @property
209
+ def all_tied_weights_keys(self):
210
+ return {}
211
+
212
+ def __init__(self, config: TennsLLMConfig):
213
+ super().__init__(config)
214
+ # Assign TENNsLLM components directly (not as self.model) so that
215
+ # state dict keys match the safetensors: embedding.weight, backbone.0...
216
+ _backbone = TENNsLLM(
217
+ vocab_size=config.vocab_size,
218
+ channels=config.channels,
219
+ num_blocks=config.num_blocks,
220
+ num_coeffs=config.num_coeffs,
221
+ repeat=config.repeat,
222
+ )
223
+ self.embedding = _backbone.embedding
224
+ self.backbone = _backbone.backbone
225
+ self.head = _backbone.head
226
+
227
+ def _reset_states(self):
228
+ for module in self.modules():
229
+ if isinstance(module, TENNsBlock):
230
+ module.reset_states()
231
+
232
+ def forward(self, input_ids, **kwargs):
233
+ x = self.embedding(input_ids)
234
+ x = self.backbone(x)
235
+ logits = self.head(x)
236
+ return CausalLMOutputWithPast(logits=logits)
237
+
238
+ @torch.no_grad()
239
+ def generate_text(self, prompt, tokenizer, max_new_tokens=50,
240
+ temperature=1.0, top_k=None):
241
+ """Autoregressive text generation.
242
+
243
+ Args:
244
+ prompt: Input text string
245
+ tokenizer: HuggingFace tokenizer
246
+ max_new_tokens: Maximum number of tokens to generate
247
+ temperature: Sampling temperature (lower = more deterministic)
248
+ top_k: If set, sample from top-k tokens; otherwise greedy argmax
249
+
250
+ Returns:
251
+ Generated text string (not including the prompt)
252
+ """
253
+ self.eval()
254
+ self._reset_states()
255
+
256
+ input_ids = tokenizer(prompt, return_tensors='pt',
257
+ add_special_tokens=False)['input_ids'].squeeze()
258
+ input_ids = input_ids.to(self.device)
259
+
260
+ # Ingest prompt tokens
261
+ for token in input_ids:
262
+ logits = self.forward(token.view(1, 1)).logits
263
+ probs = F.softmax(logits[0, -1], dim=-1)
264
+ next_token = torch.argmax(probs).item()
265
+
266
+ # Autoregressive generation
267
+ output_ids = []
268
+ token = next_token
269
+ for _ in range(max_new_tokens):
270
+ logits = self.forward(torch.tensor([[token]], device=self.device)).logits
271
+ next_logits = logits[0, -1]
272
+
273
+ if temperature != 1.0:
274
+ next_logits = next_logits / temperature
275
+
276
+ if top_k is not None:
277
+ v, _ = torch.topk(next_logits, top_k)
278
+ next_logits[next_logits < v[-1]] = float('-inf')
279
+
280
+ probs = F.softmax(next_logits, dim=-1)
281
+ token = (torch.multinomial(probs, 1).item() if top_k is not None
282
+ else torch.argmax(probs).item())
283
+
284
+ if token == tokenizer.eos_token_id:
285
+ break
286
+
287
+ output_ids.append(token)
288
+
289
+ return tokenizer.decode(output_ids)
tenns_core/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TENNs Core: Efficient State Space Models for Sequence Modeling
3
+
4
+ A standalone library providing various SSM (State Space Model) architectures
5
+ for deep learning on sequences. Includes S5, DWS, Neck, Full, and Gate modes
6
+ all implemented in pure PyTorch.
7
+
8
+ Quick Start - Training:
9
+ ----------------------
10
+ >>> from tenns_core import SSMLayer
11
+ >>> import torch
12
+ >>>
13
+ >>> # Create S5-mode SSM layer
14
+ >>> layer = SSMLayer(
15
+ ... num_coeffs=64,
16
+ ... in_channels=128,
17
+ ... out_channels=256,
18
+ ... mode='s5',
19
+ ... norm='layer',
20
+ ... postact='gelu'
21
+ ... )
22
+ >>>
23
+ >>> # Forward pass (training mode - FFT convolution)
24
+ >>> x = torch.randn(4, 128, 512) # (batch, channels, length)
25
+ >>> y = layer(x) # (4, 256, 512)
26
+
27
+ Quick Start - Streaming Inference:
28
+ ----------------------------------
29
+ >>> # Convert trained model to streaming inference
30
+ >>> infer_layer = layer.to_inference()
31
+ >>>
32
+ >>> # Process audio stream chunk-by-chunk
33
+ >>> for chunk in audio_stream:
34
+ >>> output = infer_layer(chunk) # State maintained automatically
35
+ >>>
36
+ >>> # Reset state between utterances
37
+ >>> infer_layer.reset_state()
38
+ """
39
+
40
+ from importlib.metadata import PackageNotFoundError, version
41
+
42
+ from .inference import SSMLayerInference
43
+ from .ssm import Kernelizer, SSMLayer
44
+
45
+ try:
46
+ __version__ = version('tenns-core')
47
+ except PackageNotFoundError:
48
+ __version__ = '0.0.0+unknown'
49
+
50
+ __all__ = ['Kernelizer', 'SSMLayer', 'SSMLayerInference']
tenns_core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
tenns_core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.65 kB). View file
 
tenns_core/__pycache__/activations.cpython-310.pyc ADDED
Binary file (4.51 kB). View file
 
tenns_core/__pycache__/activations.cpython-312.pyc ADDED
Binary file (6.34 kB). View file
 
tenns_core/__pycache__/fft_ops.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
tenns_core/__pycache__/fft_ops.cpython-312.pyc ADDED
Binary file (8.45 kB). View file
 
tenns_core/__pycache__/inference.cpython-310.pyc ADDED
Binary file (14.7 kB). View file
 
tenns_core/__pycache__/inference.cpython-312.pyc ADDED
Binary file (24.4 kB). View file
 
tenns_core/__pycache__/recurrent_ops.cpython-310.pyc ADDED
Binary file (16 kB). View file
 
tenns_core/__pycache__/recurrent_ops.cpython-312.pyc ADDED
Binary file (14.8 kB). View file
 
tenns_core/__pycache__/scan_ops.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
tenns_core/__pycache__/scan_ops.cpython-312.pyc ADDED
Binary file (20.3 kB). View file
 
tenns_core/__pycache__/ssm.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
tenns_core/__pycache__/ssm.cpython-312.pyc ADDED
Binary file (20.9 kB). View file
 
tenns_core/activations.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Activation, normalization, and dropout utilities for SSM layers.
3
+
4
+ Extracted from tenns.models.utils to provide activation layer construction.
5
+ """
6
+
7
+ from torch import nn
8
+ from torch.nn import RMSNorm
9
+
10
+
11
+ class LayerNormFeature(nn.LayerNorm):
12
+ """LayerNorm that operates on the feature dimension (dim=-2) instead of time (dim=-1)."""
13
+
14
+ def forward(self, input):
15
+ return super().forward(input.moveaxis(-1, -2)).moveaxis(-1, -2)
16
+
17
+
18
+ class RmsNormFeature(nn.Module):
19
+ """RMSNorm that operates on the feature dimension (dim=-2) instead of time (dim=-1)."""
20
+
21
+ def __init__(self, features):
22
+ super().__init__()
23
+ self.rms_norm = RMSNorm(features)
24
+
25
+ def forward(self, input):
26
+ return self.rms_norm(input.moveaxis(-1, -2)).moveaxis(-1, -2)
27
+
28
+
29
+ def get_norm(norm, num_features, ndim=2):
30
+ """Get normalization layer by name.
31
+
32
+ Args:
33
+ norm: Normalization type ('batch', 'layer', 'layer-feature', 'rms', None)
34
+ num_features: Number of features/channels
35
+ ndim: Number of dimensions (1, 2, or 3)
36
+
37
+ Returns:
38
+ Normalization layer module
39
+ """
40
+ match norm:
41
+ case 'batch':
42
+ match ndim:
43
+ case 1:
44
+ return nn.BatchNorm1d(num_features)
45
+ case 2:
46
+ return nn.BatchNorm2d(num_features)
47
+ case 3:
48
+ return nn.BatchNorm3d(num_features)
49
+ case _:
50
+ raise ValueError(f'Invalid dimensions: {ndim}')
51
+
52
+ case 'layer':
53
+ return nn.LayerNorm(num_features)
54
+
55
+ case 'layer-feature':
56
+ if num_features > 1:
57
+ return LayerNormFeature(num_features)
58
+ else:
59
+ return nn.Identity()
60
+
61
+ case 'rms':
62
+ if num_features > 1:
63
+ return RmsNormFeature(num_features)
64
+ else:
65
+ return nn.Identity()
66
+
67
+ case None:
68
+ return nn.Identity()
69
+
70
+ case _:
71
+ raise ValueError(f'Invalid normalization type: {norm}')
72
+
73
+
74
+ def get_postact(postact):
75
+ """Get activation function by name.
76
+
77
+ Args:
78
+ postact: Activation type ('relu', 'gelu', 'silu', etc., or None)
79
+
80
+ Returns:
81
+ Activation function module
82
+ """
83
+ if postact is None:
84
+ return nn.Identity()
85
+
86
+ postact_registry = {
87
+ 'relu': nn.ReLU(),
88
+ 'relu6': nn.ReLU6(),
89
+ 'lelu': nn.LeakyReLU(0.1),
90
+ 'sigmoid': nn.Sigmoid(),
91
+ 'tanh': nn.Tanh(),
92
+ 'gelu': nn.GELU(),
93
+ 'glu': nn.GLU(dim=1),
94
+ 'silu': nn.SiLU(),
95
+ }
96
+
97
+ if postact in postact_registry:
98
+ return postact_registry[postact]
99
+ else:
100
+ raise ValueError(f'Invalid activation name: {postact}')
101
+
102
+
103
+ def get_dropout(p, dropout_dim, num_features):
104
+ """Get dropout layer by dimension.
105
+
106
+ Args:
107
+ p: Dropout probability (None for no dropout)
108
+ dropout_dim: Dimension of dropout (0 for standard, 1 for 1d, etc.)
109
+ num_features: Number of features (used to determine if dropout should be applied)
110
+
111
+ Returns:
112
+ Dropout module
113
+ """
114
+ if p is None:
115
+ return nn.Identity()
116
+
117
+ dropout_registry = {
118
+ 0: nn.Dropout,
119
+ 1: nn.Dropout1d,
120
+ 2: nn.Dropout2d,
121
+ 3: nn.Dropout3d,
122
+ }
123
+
124
+ if dropout_dim in dropout_registry:
125
+ # Only apply dropout if we have enough features
126
+ if dropout_dim == 0 or num_features >= 16:
127
+ return dropout_registry[dropout_dim](p)
128
+ else:
129
+ return nn.Identity()
130
+ else:
131
+ raise ValueError(f'Invalid dropout dimension: {dropout_dim}')
132
+
133
+
134
+ def get_activations(ndim, num_features, norm=None, postact=None, p=None, dropout_dim=0):
135
+ """Build a sequential module with normalization, activation, and dropout.
136
+
137
+ Args:
138
+ ndim: Number of dimensions (1, 2, or 3)
139
+ num_features: Number of features/channels
140
+ norm: Normalization type (None, 'batch', 'layer', 'layer-feature', 'rms')
141
+ postact: Activation function type (None, 'relu', 'gelu', 'silu', etc.)
142
+ p: Dropout probability (None for no dropout)
143
+ dropout_dim: Dimension of dropout (0, 1, 2, or 3)
144
+
145
+ Returns:
146
+ Sequential module combining norm, activation, and dropout
147
+ """
148
+ if (norm is None) and (postact is None) and (p is None):
149
+ return nn.Identity()
150
+
151
+ activations = nn.Sequential()
152
+ if norm is not None:
153
+ activations.append(get_norm(norm, num_features, ndim))
154
+ if postact is not None:
155
+ activations.append(get_postact(postact))
156
+ if p is not None:
157
+ activations.append(get_dropout(p, dropout_dim, num_features))
158
+ return activations
tenns_core/fft_ops.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FFT-based convolution operations for SSM layers.
3
+
4
+ This module provides optimized FFT convolution operations used in SSM training,
5
+ combining functionality from fft_utils.py and fft_utils_opt.py.
6
+ """
7
+
8
+ import torch
9
+ from torch.amp import custom_bwd, custom_fwd
10
+
11
+
12
+ class PaddedFFTConv(torch.autograd.Function):
13
+ """Custom autograd function for padded FFT convolution with efficient gradients.
14
+
15
+ Supports both depthwise ('dw') and full ('full') convolution modes.
16
+ """
17
+
18
+ @staticmethod
19
+ @torch.compiler.disable
20
+ @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
21
+ def forward(ctx, u, k, n, mode, is_complex=False):
22
+ """
23
+ Args:
24
+ u: Input tensor
25
+ k: Kernel tensor
26
+ n: Sequence length
27
+ mode: 'dw' for depthwise or 'full' for full convolution
28
+ is_complex: Whether to use complex FFT
29
+ """
30
+ if is_complex:
31
+ uf = torch.fft.fft(u, 2 * n)
32
+ kf = torch.fft.fft(k, 2 * n)
33
+ else:
34
+ uf = torch.fft.rfft(u, 2 * n)
35
+ kf = torch.fft.rfft(k, 2 * n)
36
+
37
+ if mode == 'dw':
38
+ yf = uf * kf
39
+ elif mode == 'full':
40
+ yf = torch.einsum('bcl,dcl->bdl', uf, kf)
41
+
42
+ ctx.is_complex = is_complex
43
+ ctx.mode = mode
44
+ ctx.n = n
45
+ ctx.save_for_backward(u, k)
46
+
47
+ if is_complex:
48
+ return torch.fft.ifft(yf)[..., :n]
49
+ else:
50
+ return torch.fft.irfft(yf)[..., :n]
51
+
52
+ @staticmethod
53
+ @torch.compiler.disable
54
+ @custom_bwd(device_type='cuda')
55
+ def backward(ctx, grad_output):
56
+ is_complex = ctx.is_complex
57
+ mode = ctx.mode
58
+ n = ctx.n
59
+ u, k = ctx.saved_tensors
60
+
61
+ if is_complex:
62
+ uf = torch.fft.fft(u, 2 * n)
63
+ kf = torch.fft.fft(k, 2 * n)
64
+ grad_yf = torch.fft.fft(grad_output, 2 * n)
65
+ else:
66
+ uf = torch.fft.rfft(u, 2 * n)
67
+ kf = torch.fft.rfft(k, 2 * n)
68
+ grad_yf = torch.fft.rfft(grad_output, 2 * n)
69
+
70
+ if mode == 'dw':
71
+ grad_uf = grad_yf * torch.conj(kf)
72
+ elif mode == 'full':
73
+ grad_uf = torch.einsum('bdl,dcl->bcl', grad_yf, torch.conj(kf))
74
+
75
+ if is_complex:
76
+ grad_u = torch.fft.ifft(grad_uf, 2 * n)[..., :n]
77
+ else:
78
+ grad_u = torch.fft.irfft(grad_uf, 2 * n)[..., :n]
79
+
80
+ if mode == 'dw':
81
+ grad_kf = torch.einsum('bnl,bnl->nl', grad_yf, torch.conj(uf))
82
+ elif mode == 'full':
83
+ grad_kf = torch.einsum('bdl,bcl->dcl', grad_yf, torch.conj(uf))
84
+
85
+ if is_complex:
86
+ grad_k = torch.fft.ifft(grad_kf, 2 * n)[..., :n]
87
+ else:
88
+ grad_k = torch.fft.irfft(grad_kf, 2 * n)[..., :n]
89
+
90
+ return grad_u, grad_k, None, None, None
91
+
92
+
93
+ def _K(dtA_real, dtA_imag, length, weight=None, dim=-2, complex_proj=False, l_shift=0):
94
+ """Generate SSM convolution kernel from discretized state matrix.
95
+
96
+ Args:
97
+ dtA_real: Real part of discretized state matrix diagonal
98
+ dtA_imag: Imaginary part of discretized state matrix diagonal
99
+ length: Sequence length
100
+ weight: Optional weight matrix to apply
101
+ dim: Dimension to reduce over if weight is provided
102
+ complex_proj: Whether to use complex projection
103
+ l_shift: Shift amount for the range
104
+
105
+ Returns:
106
+ SSM convolution kernel of shape (..., length)
107
+ """
108
+ device = dtA_real.device
109
+ lrange = torch.arange(l_shift, length + l_shift, device=device)
110
+
111
+ with torch.autocast('cuda', enabled=False):
112
+ dtA_real, dtA_imag = dtA_real.float(), dtA_imag.float()
113
+ if complex_proj:
114
+ K = (torch.complex(dtA_real, dtA_imag)[..., None] * lrange).exp()
115
+ else:
116
+ K = (dtA_real[..., None] * lrange).exp() * torch.cos(dtA_imag[..., None] * lrange)
117
+
118
+ if weight is not None:
119
+ return (K * weight[..., None]).sum(dim)
120
+ else:
121
+ return K
122
+
123
+
124
+ def _full_k(dtA_real, dtA_imag, B, C, E, length):
125
+ """Generate full SSM kernel by combining B, C, and state kernel.
126
+
127
+ Used for optimizing s5/neck mode when full kernel is more efficient.
128
+ """
129
+ K = _K(dtA_real, dtA_imag, length, weight=E)
130
+ return (B[..., None] * C[..., None, None] * K[:, None, :]).sum(1)
131
+
132
+
133
+ def padded_fft_conv_opt(input, dtA_real, dtA_imag, B, C, E):
134
+ """Optimized padded FFT convolution for SSM layers.
135
+
136
+ Automatically chooses between naive and optimized contraction based on
137
+ tensor shapes to minimize computation.
138
+
139
+ Args:
140
+ input: Input tensor of shape (batch, in_channels, length)
141
+ dtA_real: Real part of discretized A matrix
142
+ dtA_imag: Imaginary part of discretized A matrix
143
+ B: Input projection matrix (None for dws/full modes)
144
+ C: Output projection matrix (None for dws/full modes)
145
+ E: State projection matrix (None for s5/neck modes)
146
+
147
+ Returns:
148
+ Output tensor of shape (batch, out_channels, length)
149
+ """
150
+ batch, chin, length = input.shape
151
+
152
+ # DWS/Full mode: no B/C matrices
153
+ if B is None:
154
+ K = _K(dtA_real, dtA_imag, length, weight=E)
155
+ if K.ndim == 3:
156
+ return PaddedFFTConv.apply(input, K, length, 'full', False)
157
+ elif K.ndim == 2:
158
+ return PaddedFFTConv.apply(input, K, length, 'dw', False)
159
+
160
+ # S5/Neck mode: has B/C matrices
161
+ chout, coeffs = C.shape
162
+
163
+ # Choose contraction order based on efficiency
164
+ # Compare cost of: (1) fusing B,C,K vs (2) separate contractions
165
+ if (1 / chin + 1 / chout) > (1 / batch + 1 / coeffs):
166
+ # Fuse full kernel and apply single convolution
167
+ kernel = _full_k(dtA_real, dtA_imag, B, C, E, length)
168
+ return PaddedFFTConv.apply(input, kernel, length, 'full', False)
169
+ else:
170
+ # Separate: project input, convolve, then project output
171
+ K = _K(dtA_real, dtA_imag, length, weight=E)
172
+ x = torch.einsum('bcl,nc->bnl', input, B)
173
+ x = PaddedFFTConv.apply(x, K, length, 'dw', False)
174
+ return torch.einsum('bnl,dn->bdl', x, C)
tenns_core/inference.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference mode for SSM layers.
3
+
4
+ Provides streaming/online inference with stateful processing for real-time applications.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from .recurrent_ops import (
11
+ discretize_dws,
12
+ discretize_full,
13
+ discretize_neck,
14
+ discretize_s5,
15
+ recurrent_gate,
16
+ recurrent_gate_single_step,
17
+ step_dws,
18
+ step_full,
19
+ step_neck,
20
+ step_s5,
21
+ )
22
+
23
+
24
+ class SSMLayerInference(nn.Module):
25
+ """Streaming inference wrapper for SSMLayer.
26
+
27
+ Provides stateful recurrent inference for real-time applications.
28
+ Maintains internal state across chunks for low-latency streaming.
29
+
30
+ Discretization (Ad, B_hat, etc.) is precomputed once at construction time
31
+ from the raw SSM parameters, so only the per-timestep step function runs
32
+ during forward passes.
33
+
34
+ Example:
35
+ >>> # After training
36
+ >>> train_layer = SSMLayer(64, 128, 256, mode='s5')
37
+ >>> # ... training ...
38
+ >>>
39
+ >>> # Convert to inference mode
40
+ >>> infer_layer = SSMLayerInference.from_training(train_layer)
41
+ >>>
42
+ >>> # Process streaming chunks (state maintained automatically)
43
+ >>> for chunk in audio_stream:
44
+ >>> output = infer_layer(chunk)
45
+ >>>
46
+ >>> # Reset state when starting new utterance
47
+ >>> infer_layer.reset_state()
48
+
49
+ Note:
50
+ - Inference mode uses sequential scan (O(T) per chunk)
51
+ - Training mode uses FFT (O(T log T) for full sequence)
52
+ - For streaming, inference mode has lower latency
53
+ - For batch processing full sequences, training mode is faster
54
+ """
55
+
56
+ def __init__(self, mode, in_channels, out_channels, **kwargs):
57
+ """Initialize inference layer.
58
+
59
+ Args:
60
+ mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
61
+ in_channels: Number of input channels
62
+ out_channels: Number of output channels
63
+ **kwargs: Mode-specific parameters (Ad, B_hat, C, dt, B, E, A, log_dt, mixer, etc.)
64
+ """
65
+ super().__init__()
66
+ self.mode = mode
67
+ self.in_channels = in_channels
68
+ self.out_channels = out_channels
69
+
70
+ if mode == 's5':
71
+ self.register_buffer('Ad', kwargs['Ad'])
72
+ self.register_buffer('B_hat', kwargs['B_hat'])
73
+ self.register_buffer('C', kwargs['C'])
74
+ elif mode == 'dws':
75
+ self.register_buffer('Ad', kwargs['Ad'])
76
+ self.register_buffer('B_hat', kwargs['B_hat'])
77
+ elif mode == 'neck':
78
+ self.register_buffer('Ad', kwargs['Ad'])
79
+ self.register_buffer('dt', kwargs['dt'])
80
+ self.register_buffer('B', kwargs['B'])
81
+ self.register_buffer('C', kwargs['C'])
82
+ self.register_buffer('E', kwargs['E'])
83
+ elif mode == 'full':
84
+ self.register_buffer('Ad', kwargs['Ad'])
85
+ self.register_buffer('B_hat', kwargs['B_hat'])
86
+ elif mode == 'gate':
87
+ # Gate mode: input-dependent discretization, store raw params
88
+ self.register_buffer('A', kwargs['A'])
89
+ self.B = kwargs['B'] # nn.Module
90
+ self.C = kwargs['C'] # nn.Module
91
+ self.log_dt = kwargs['log_dt'] # nn.Module
92
+ else:
93
+ raise ValueError(f'Unknown mode: {mode}')
94
+
95
+ # Mixer module (for DWS mode to project channels)
96
+ self.mixer = kwargs.get('mixer') or nn.Identity()
97
+
98
+ # Internal state
99
+ self.state = None
100
+
101
+ @classmethod
102
+ def from_training(cls, ssm_layer):
103
+ """Create inference layer from trained SSMLayer.
104
+
105
+ Args:
106
+ ssm_layer: Trained SSMLayer instance
107
+
108
+ Returns:
109
+ SSMLayerInference instance with precomputed discretized weights
110
+
111
+ Example:
112
+ >>> train_layer = SSMLayer(64, 128, 256, mode='s5')
113
+ >>> infer_layer = SSMLayerInference.from_training(train_layer)
114
+ """
115
+ mode = ssm_layer.mode
116
+ kwargs = {}
117
+
118
+ if mode == 's5':
119
+ Ad, B_hat = discretize_s5(
120
+ ssm_layer.A.detach().clone(),
121
+ ssm_layer.B.detach().clone(),
122
+ ssm_layer.log_dt.detach().clone(),
123
+ )
124
+ kwargs['Ad'] = Ad
125
+ kwargs['B_hat'] = B_hat
126
+ kwargs['C'] = ssm_layer.C.detach().clone()
127
+
128
+ elif mode == 'dws':
129
+ Ad, B_hat = discretize_dws(
130
+ ssm_layer.A.detach().clone(),
131
+ ssm_layer.E.detach().clone(),
132
+ ssm_layer.log_dt.detach().clone(),
133
+ )
134
+ kwargs['Ad'] = Ad
135
+ kwargs['B_hat'] = B_hat
136
+ kwargs['mixer'] = ssm_layer.mixer
137
+
138
+ elif mode == 'neck':
139
+ Ad, dt = discretize_neck(
140
+ ssm_layer.A.detach().clone(),
141
+ ssm_layer.log_dt.detach().clone(),
142
+ )
143
+ kwargs['Ad'] = Ad
144
+ kwargs['dt'] = dt
145
+ kwargs['B'] = ssm_layer.B.detach().clone()
146
+ kwargs['C'] = ssm_layer.C.detach().clone()
147
+ kwargs['E'] = ssm_layer.E.detach().clone()
148
+
149
+ elif mode == 'full':
150
+ Ad, B_hat = discretize_full(
151
+ ssm_layer.A.detach().clone(),
152
+ ssm_layer.E.detach().clone(),
153
+ ssm_layer.log_dt.detach().clone(),
154
+ )
155
+ kwargs['Ad'] = Ad
156
+ kwargs['B_hat'] = B_hat
157
+
158
+ elif mode == 'gate':
159
+ kwargs['A'] = ssm_layer.A.detach().clone()
160
+ kwargs['B'] = ssm_layer.B
161
+ kwargs['C'] = ssm_layer.C
162
+ kwargs['log_dt'] = ssm_layer.log_dt
163
+ kwargs['mixer'] = ssm_layer.mixer
164
+
165
+ else:
166
+ raise ValueError(f'Unknown mode: {mode}')
167
+
168
+ return cls(
169
+ mode=mode,
170
+ in_channels=ssm_layer.in_channels,
171
+ out_channels=ssm_layer.out_channels,
172
+ **kwargs,
173
+ )
174
+
175
+ def forward(self, input, return_state=False):
176
+ """Forward pass with stateful processing.
177
+
178
+ Args:
179
+ input: Input tensor of shape (B, C, T) or (C, T) for single sample
180
+ return_state: If True, return (output, state) tuple
181
+
182
+ Returns:
183
+ output: Output tensor of shape (B, D, T) or (D, T)
184
+ state (optional): Internal state if return_state=True
185
+
186
+ Note:
187
+ State is maintained internally across calls. Use reset_state()
188
+ to clear it.
189
+ """
190
+ # Handle input format
191
+ squeeze_batch = False
192
+ if input.dim() == 2:
193
+ input = input.unsqueeze(0) # (C, T) -> (1, C, T)
194
+ squeeze_batch = True
195
+
196
+ B_batch, _C, T = input.shape
197
+ # Transpose to (B, T, C) for step functions
198
+ input = input.transpose(1, 2)
199
+
200
+ if self.mode == 'gate':
201
+ output, self.state = recurrent_gate(
202
+ input, self.A, self.B, self.C, self.log_dt, self.state
203
+ )
204
+ else:
205
+ # Non-gate modes: loop over timesteps with precomputed discretization
206
+ outputs = []
207
+ for b in range(B_batch):
208
+ batch_outputs = []
209
+ # Use per-batch state or init
210
+ if self.state is not None and self.state.dim() > len(self._state_shape()):
211
+ x = self.state[b]
212
+ else:
213
+ x = self.state
214
+
215
+ for t in range(T):
216
+ u_t = input[b, t] # (C_in,)
217
+ y_t, x = self._step(u_t, x)
218
+ batch_outputs.append(y_t)
219
+
220
+ # Update state
221
+ if b == 0:
222
+ self.state = x.unsqueeze(0) if B_batch > 1 else x
223
+ elif B_batch > 1:
224
+ self.state = torch.cat([self.state, x.unsqueeze(0)], dim=0)
225
+
226
+ outputs.append(torch.stack(batch_outputs, dim=1)) # (D, T)
227
+
228
+ output = torch.stack(outputs, dim=0) # (B, D, T)
229
+
230
+ # Apply mixer (important for DWS mode which projects channels)
231
+ output = self.mixer(output)
232
+
233
+ if squeeze_batch:
234
+ output = output.squeeze(0)
235
+ if self.state is not None and self.state.dim() > len(self._state_shape()):
236
+ self.state = self.state.squeeze(0)
237
+
238
+ if return_state:
239
+ return output, self.state
240
+ return output
241
+
242
+ def _step(self, u, state):
243
+ """Dispatch to mode-specific step function."""
244
+ if self.mode == 's5':
245
+ return step_s5(u, self.Ad, self.B_hat, self.C, state)
246
+ elif self.mode == 'dws':
247
+ return step_dws(u, self.Ad, self.B_hat, state)
248
+ elif self.mode == 'neck':
249
+ return step_neck(u, self.Ad, self.dt, self.B, self.C, self.E, state)
250
+ elif self.mode == 'full':
251
+ return step_full(u, self.Ad, self.B_hat, state)
252
+
253
+ def _state_shape(self):
254
+ """Return expected unbatched state shape for current mode."""
255
+ if self.mode == 's5':
256
+ return self.Ad.shape # (N, 2)
257
+ elif self.mode == 'dws':
258
+ return self.Ad.shape # (C, N, 2)
259
+ elif self.mode == 'neck':
260
+ return self.Ad.shape # (R, N, 2)
261
+ elif self.mode == 'full':
262
+ return self.Ad.shape # (D, C, N, 2)
263
+ elif self.mode == 'gate':
264
+ return (self.A.shape[0],) # (N,)
265
+
266
+ def reset_state(self):
267
+ """Reset internal state.
268
+
269
+ Call this when starting a new sequence.
270
+
271
+ Example:
272
+ >>> for utterance in utterances:
273
+ >>> infer_layer.reset_state() # Clear state
274
+ >>> for chunk in utterance:
275
+ >>> output = infer_layer(chunk)
276
+ """
277
+ if self.state is not None:
278
+ self.state.zero_()
279
+ else:
280
+ self.state = None
281
+
282
+ def get_state(self):
283
+ """Get current internal state for checkpointing or branching.
284
+
285
+ Returns a clone of the state to prevent accidental mutations.
286
+ Useful for beam search, hypothesis tracking, or state snapshots.
287
+
288
+ Returns:
289
+ state: Cloned state tensor or None if no state exists
290
+
291
+ Example:
292
+ >>> # Save state for beam search
293
+ >>> saved_state = infer_layer.get_state()
294
+ >>> # Process hypothesis 1
295
+ >>> output1 = infer_layer(chunk1)
296
+ >>> # Restore and try hypothesis 2
297
+ >>> infer_layer.set_state(saved_state)
298
+ >>> output2 = infer_layer(chunk2)
299
+ """
300
+ return self.state.clone() if self.state is not None else None
301
+
302
+ def set_state(self, state):
303
+ """Restore internal state from checkpoint.
304
+
305
+ Sets the state to a clone of the provided tensor to prevent
306
+ accidental mutations. Useful for restoring checkpoints or
307
+ branching hypotheses in beam search.
308
+
309
+ Args:
310
+ state: State tensor (shape depends on mode) or None to reset
311
+
312
+ Example:
313
+ >>> # Checkpoint state before branching
314
+ >>> checkpoint = infer_layer.get_state()
315
+ >>> # ... process some data ...
316
+ >>> # Restore to checkpoint
317
+ >>> infer_layer.set_state(checkpoint)
318
+ """
319
+ self.state = state.clone() if state is not None else None
320
+
321
+ def __repr__(self):
322
+ return (
323
+ f'SSMLayerInference(mode={self.mode}, '
324
+ f'in_channels={self.in_channels}, '
325
+ f'out_channels={self.out_channels}, '
326
+ f'state={"active" if self.state is not None else "reset"})'
327
+ )
328
+
329
+
330
+ class SSMLayerExportable(nn.Module):
331
+ """Single-timestep exportable SSM layer for ONNX export (B=1, T=1).
332
+
333
+ This class processes one timestep at a time with explicit state input/output,
334
+ enabling export to ONNX by eliminating dynamic control flow and
335
+ complex number dtypes.
336
+
337
+ Discretization is precomputed at construction time, so the forward pass
338
+ only runs the step function.
339
+
340
+ Currently supports S5, DWS, Neck, Full, and Gate modes. State is represented as real tensors (..., 2)
341
+ where [..., 0] is the real part and [..., 1] is the imaginary part.
342
+
343
+ Example:
344
+ >>> # After training
345
+ >>> train_layer = SSMLayer(num_coeffs=64, in_channels=32, out_channels=32, mode='s5')
346
+ >>> # ... training ...
347
+ >>>
348
+ >>> # Convert to exportable inference mode
349
+ >>> export_layer = SSMLayerExportable.from_training(train_layer)
350
+ >>>
351
+ >>> # Export to ONNX
352
+ >>> dummy_input = torch.randn(32)
353
+ >>> torch.onnx.export(export_layer, (dummy_input, None), "model.onnx")
354
+ >>>
355
+ >>> # Use in streaming application (external loop)
356
+ >>> state = None
357
+ >>> for t in range(audio_length):
358
+ >>> output, state = export_layer(audio[t], state)
359
+
360
+ Note:
361
+ - Processes single sample (B=1), single timestep (T=1) per call
362
+ - State is automatically initialized to zeros if None
363
+ - Loop over time must be external to the model
364
+ - Complex numbers represented as (..., 2) real tensors
365
+ """
366
+
367
+ def __init__(self, mode, in_channels, out_channels, **kwargs):
368
+ """Initialize exportable SSM layer.
369
+
370
+ Args:
371
+ mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
372
+ in_channels: Number of input channels
373
+ out_channels: Number of output channels
374
+ **kwargs: Mode-specific discretized parameters
375
+ """
376
+ super().__init__()
377
+ self.mode = mode
378
+ self.in_channels = in_channels
379
+ self.out_channels = out_channels
380
+
381
+ if mode == 's5':
382
+ self.register_buffer('Ad', kwargs['Ad'])
383
+ self.register_buffer('B_hat', kwargs['B_hat'])
384
+ self.register_buffer('C', kwargs['C'])
385
+ elif mode == 'dws':
386
+ self.register_buffer('Ad', kwargs['Ad'])
387
+ self.register_buffer('B_hat', kwargs['B_hat'])
388
+ elif mode == 'neck':
389
+ self.register_buffer('Ad', kwargs['Ad'])
390
+ self.register_buffer('dt', kwargs['dt'])
391
+ self.register_buffer('B', kwargs['B'])
392
+ self.register_buffer('C', kwargs['C'])
393
+ self.register_buffer('E', kwargs['E'])
394
+ elif mode == 'full':
395
+ self.register_buffer('Ad', kwargs['Ad'])
396
+ self.register_buffer('B_hat', kwargs['B_hat'])
397
+ elif mode == 'gate':
398
+ self.register_buffer('A', kwargs['A'])
399
+ self.B = kwargs['B'] # nn.Module
400
+ self.C = kwargs['C'] # nn.Module
401
+ self.log_dt = kwargs['log_dt'] # nn.Module
402
+ else:
403
+ raise ValueError(f'Unknown mode: {mode}')
404
+
405
+ # Mixer module (for DWS mode to project channels)
406
+ self.mixer = kwargs.get('mixer') or nn.Identity()
407
+
408
+ @classmethod
409
+ def from_training(cls, ssm_layer):
410
+ """Create exportable layer from trained SSMLayer.
411
+
412
+ Args:
413
+ ssm_layer: Trained SSMLayer instance
414
+
415
+ Returns:
416
+ SSMLayerExportable instance with precomputed discretized weights
417
+
418
+ Raises:
419
+ ValueError: If ssm_layer.mode is not supported
420
+
421
+ Example:
422
+ >>> train_layer = SSMLayer(num_coeffs=64, in_channels=32, out_channels=32, mode='s5')
423
+ >>> export_layer = SSMLayerExportable.from_training(train_layer)
424
+ """
425
+ mode = ssm_layer.mode
426
+ kwargs = {}
427
+
428
+ if mode == 's5':
429
+ Ad, B_hat = discretize_s5(
430
+ ssm_layer.A.detach().clone(),
431
+ ssm_layer.B.detach().clone(),
432
+ ssm_layer.log_dt.detach().clone(),
433
+ )
434
+ kwargs['Ad'] = Ad
435
+ kwargs['B_hat'] = B_hat
436
+ kwargs['C'] = ssm_layer.C.detach().clone()
437
+
438
+ elif mode == 'dws':
439
+ Ad, B_hat = discretize_dws(
440
+ ssm_layer.A.detach().clone(),
441
+ ssm_layer.E.detach().clone(),
442
+ ssm_layer.log_dt.detach().clone(),
443
+ )
444
+ kwargs['Ad'] = Ad
445
+ kwargs['B_hat'] = B_hat
446
+ kwargs['mixer'] = ssm_layer.mixer
447
+
448
+ elif mode == 'neck':
449
+ Ad, dt = discretize_neck(
450
+ ssm_layer.A.detach().clone(),
451
+ ssm_layer.log_dt.detach().clone(),
452
+ )
453
+ kwargs['Ad'] = Ad
454
+ kwargs['dt'] = dt
455
+ kwargs['B'] = ssm_layer.B.detach().clone()
456
+ kwargs['C'] = ssm_layer.C.detach().clone()
457
+ kwargs['E'] = ssm_layer.E.detach().clone()
458
+
459
+ elif mode == 'full':
460
+ Ad, B_hat = discretize_full(
461
+ ssm_layer.A.detach().clone(),
462
+ ssm_layer.E.detach().clone(),
463
+ ssm_layer.log_dt.detach().clone(),
464
+ )
465
+ kwargs['Ad'] = Ad
466
+ kwargs['B_hat'] = B_hat
467
+
468
+ elif mode == 'gate':
469
+ kwargs['A'] = ssm_layer.A.detach().clone()
470
+ kwargs['B'] = ssm_layer.B
471
+ kwargs['C'] = ssm_layer.C
472
+ kwargs['log_dt'] = ssm_layer.log_dt
473
+
474
+ else:
475
+ raise ValueError(
476
+ f'SSMLayerExportable only supports S5, DWS, Neck, Full, and Gate modes, got {mode}'
477
+ )
478
+
479
+ return cls(
480
+ mode=mode,
481
+ in_channels=ssm_layer.in_channels,
482
+ out_channels=ssm_layer.out_channels,
483
+ **kwargs,
484
+ )
485
+
486
+ def forward(self, input, state=None):
487
+ """Forward pass for single timestep.
488
+
489
+ Args:
490
+ input: Input tensor of shape (C_in,) - single sample, single timestep
491
+ state: Optional state tensor - shape depends on mode:
492
+ - S5: (N, 2) real representation
493
+ - DWS: (C, N, 2) real representation
494
+ - Neck: (R, N, 2) real representation
495
+ - Full: (D, C, N, 2) real representation
496
+ - Gate: (N,) real-valued
497
+ If None, initializes to zeros internally
498
+
499
+ Returns:
500
+ output: Output tensor of shape (D,)
501
+ new_state: Updated state - same shape as state input
502
+
503
+ Example:
504
+ >>> export_layer = SSMLayerExportable.from_training(trained_layer)
505
+ >>> x = torch.randn(32) # Single timestep input
506
+ >>> y, state = export_layer(x, None) # First call, state=None
507
+ >>> y2, state = export_layer(x2, state) # Subsequent call with state
508
+ """
509
+ if self.mode == 's5':
510
+ output, new_state = step_s5(input, self.Ad, self.B_hat, self.C, state)
511
+ elif self.mode == 'dws':
512
+ output, new_state = step_dws(input, self.Ad, self.B_hat, state)
513
+ # Apply mixer for DWS mode (channel projection)
514
+ # Mixer expects (B, C, T) format, we have (C,) single timestep
515
+ output = (
516
+ self.mixer(output.unsqueeze(0).unsqueeze(-1)).squeeze(0).squeeze(-1)
517
+ ) # (C,) -> (1, C, 1) -> (1, D, 1) -> (D,)
518
+ elif self.mode == 'neck':
519
+ output, new_state = step_neck(input, self.Ad, self.dt, self.B, self.C, self.E, state)
520
+ elif self.mode == 'full':
521
+ output, new_state = step_full(input, self.Ad, self.B_hat, state)
522
+ elif self.mode == 'gate':
523
+ # Initialize state if None
524
+ if state is None:
525
+ N = self.A.shape[0]
526
+ state = torch.zeros(N, dtype=torch.float32, device=input.device)
527
+ output, new_state = recurrent_gate_single_step(
528
+ input, self.A, self.B, self.C, self.log_dt, state
529
+ )
530
+ else:
531
+ raise ValueError(f'Unsupported mode: {self.mode}')
532
+
533
+ return output, new_state
534
+
535
+ def __repr__(self):
536
+ return (
537
+ f'SSMLayerExportable(mode={self.mode}, '
538
+ f'in_channels={self.in_channels}, '
539
+ f'out_channels={self.out_channels})'
540
+ )
tenns_core/recurrent_ops.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Recurrent operations for streaming SSM inference.
3
+
4
+ Provides discretize_* functions (called once at init) and step_* functions
5
+ (called per timestep) for each SSM mode, enabling low-latency streaming
6
+ inference by maintaining state across chunks.
7
+
8
+ Gate mode is special: its discretization is input-dependent, so it keeps
9
+ combined recurrent_gate / recurrent_gate_single_step functions.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ # ============================================================================
16
+ # Complex arithmetic helpers for real representation (ONNX compat)
17
+ # ============================================================================
18
+
19
+
20
+ def complex_mul_real(a, b):
21
+ """Multiply two complex numbers in real representation (..., 2).
22
+
23
+ Args:
24
+ a: Complex tensor as real representation (..., 2) where [..., 0] is real, [..., 1] is imag
25
+ b: Complex tensor as real representation (..., 2)
26
+
27
+ Returns:
28
+ Complex product as real representation (..., 2)
29
+ Formula: (a_r + i*a_i) * (b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r)
30
+ """
31
+ a_real = a[..., 0]
32
+ a_imag = a[..., 1]
33
+ b_real = b[..., 0]
34
+ b_imag = b[..., 1]
35
+
36
+ result_real = a_real * b_real - a_imag * b_imag
37
+ result_imag = a_real * b_imag + a_imag * b_real
38
+
39
+ return torch.stack([result_real, result_imag], dim=-1)
40
+
41
+
42
+ # ============================================================================
43
+ # S5 mode
44
+ # ============================================================================
45
+
46
+
47
+ def discretize_s5(A, B, log_dt):
48
+ """Precompute discretized parameters for S5 mode.
49
+
50
+ Args:
51
+ A: State transition parameter of shape (N, 2) - real repr of complex
52
+ B: Input projection of shape (N, C_in)
53
+ log_dt: Time step of shape (N,)
54
+
55
+ Returns:
56
+ Ad: Discretized state transition of shape (N, 2)
57
+ B_hat: Discretized input projection of shape (N, C_in)
58
+ """
59
+ A_real = -F.softplus(A[:, 0]) # (N,)
60
+ A_imag = A[:, 1] # (N,)
61
+
62
+ dt = torch.exp(log_dt) # (N,)
63
+ scaled_real = dt * A_real
64
+ scaled_imag = dt * A_imag
65
+ exp_scaled_real = torch.exp(scaled_real)
66
+ Ad = torch.stack(
67
+ [
68
+ exp_scaled_real * torch.cos(scaled_imag),
69
+ exp_scaled_real * torch.sin(scaled_imag),
70
+ ],
71
+ dim=-1,
72
+ ) # (N, 2)
73
+
74
+ B_hat = dt[:, None] * B # (N, C_in)
75
+
76
+ return Ad, B_hat
77
+
78
+
79
+ def step_s5(u, Ad, B_hat, C, state):
80
+ """Single timestep for S5 mode using pre-discretized parameters.
81
+
82
+ Args:
83
+ u: Input tensor of shape (C_in,)
84
+ Ad: Discretized state transition of shape (N, 2)
85
+ B_hat: Discretized input projection of shape (N, C_in)
86
+ C: Output projection of shape (D, N)
87
+ state: Previous state of shape (N, 2), or None for zero init
88
+
89
+ Returns:
90
+ y: Output tensor of shape (D,)
91
+ new_state: Updated state of shape (N, 2)
92
+ """
93
+ if state is None:
94
+ N = Ad.shape[0]
95
+ state = torch.zeros((N, 2), dtype=torch.float32, device=u.device)
96
+
97
+ # State update: x = Ad * x + B_hat @ u
98
+ x_new = complex_mul_real(Ad, state) # (N, 2)
99
+ Bu = B_hat @ u # (N,)
100
+ x_new[..., 0] = x_new[..., 0] + Bu
101
+
102
+ # Output: y = C @ real(x)
103
+ y = C @ x_new[..., 0] # (D,)
104
+
105
+ return y, x_new
106
+
107
+
108
+ # ============================================================================
109
+ # DWS mode
110
+ # ============================================================================
111
+
112
+
113
+ def discretize_dws(A, E, log_dt):
114
+ """Precompute discretized parameters for DWS mode.
115
+
116
+ Args:
117
+ A: State parameter of shape (C, N, 2) - real repr of complex
118
+ E: Weight matrix of shape (C, N)
119
+ log_dt: Time step of shape (C, N)
120
+
121
+ Returns:
122
+ Ad: Discretized state transition of shape (C, N, 2)
123
+ B_hat: Discretized input projection of shape (C, N)
124
+ """
125
+ A_real = -F.softplus(A[..., 0]) # (C, N)
126
+ A_imag = A[..., 1] # (C, N)
127
+
128
+ dt = torch.exp(log_dt) # (C, N)
129
+ scaled_real = dt * A_real
130
+ scaled_imag = dt * A_imag
131
+ exp_scaled_real = torch.exp(scaled_real)
132
+ Ad = torch.stack(
133
+ [
134
+ exp_scaled_real * torch.cos(scaled_imag),
135
+ exp_scaled_real * torch.sin(scaled_imag),
136
+ ],
137
+ dim=-1,
138
+ ) # (C, N, 2)
139
+
140
+ B_hat = E * dt # (C, N)
141
+
142
+ return Ad, B_hat
143
+
144
+
145
+ def step_dws(u, Ad, B_hat, state):
146
+ """Single timestep for DWS mode using pre-discretized parameters.
147
+
148
+ Args:
149
+ u: Input tensor of shape (C,)
150
+ Ad: Discretized state transition of shape (C, N, 2)
151
+ B_hat: Discretized input projection of shape (C, N)
152
+ state: Previous state of shape (C, N, 2), or None for zero init
153
+
154
+ Returns:
155
+ y: Output tensor of shape (C,)
156
+ new_state: Updated state of shape (C, N, 2)
157
+ """
158
+ if state is None:
159
+ C, N = B_hat.shape
160
+ state = torch.zeros((C, N, 2), dtype=torch.float32, device=u.device)
161
+
162
+ # State update: x = Ad * x + B_hat * u
163
+ x_new = complex_mul_real(Ad, state) # (C, N, 2)
164
+ Bu = B_hat * u.unsqueeze(1) # (C, N)
165
+ x_new[..., 0] = x_new[..., 0] + Bu
166
+
167
+ # Output: y = sum(real(x), dim=1)
168
+ y = torch.sum(x_new[..., 0], dim=1) # (C,)
169
+
170
+ return y, x_new
171
+
172
+
173
+ # ============================================================================
174
+ # Neck mode
175
+ # ============================================================================
176
+
177
+
178
+ def discretize_neck(A, log_dt):
179
+ """Precompute discretized parameters for Neck mode.
180
+
181
+ Args:
182
+ A: State transition parameter of shape (R, N, 2) - real repr of complex
183
+ log_dt: Time step of shape (R,)
184
+
185
+ Returns:
186
+ Ad: Discretized state transition of shape (R, N, 2)
187
+ dt: Discretized time step of shape (R, 1) - needed for input scaling
188
+ """
189
+ A_real = -F.softplus(A[..., 0]) # (R, N)
190
+ A_imag = A[..., 1] # (R, N)
191
+
192
+ dt = torch.exp(log_dt).reshape(-1, 1) # (R, 1)
193
+ scaled_real = dt * A_real
194
+ scaled_imag = dt * A_imag
195
+ exp_scaled_real = torch.exp(scaled_real)
196
+ Ad = torch.stack(
197
+ [
198
+ exp_scaled_real * torch.cos(scaled_imag),
199
+ exp_scaled_real * torch.sin(scaled_imag),
200
+ ],
201
+ dim=-1,
202
+ ) # (R, N, 2)
203
+
204
+ return Ad, dt
205
+
206
+
207
+ def step_neck(u, Ad, dt, B, C, E, state):
208
+ """Single timestep for Neck mode using pre-discretized parameters.
209
+
210
+ Args:
211
+ u: Input tensor of shape (C_in,)
212
+ Ad: Discretized state transition of shape (R, N, 2)
213
+ dt: Discretized time step of shape (R, 1)
214
+ B: Input projection of shape (R, C_in)
215
+ C: Output projection of shape (D, R)
216
+ E: State mixing matrix of shape (R, N)
217
+ state: Previous state of shape (R, N, 2), or None for zero init
218
+
219
+ Returns:
220
+ y: Output tensor of shape (D,)
221
+ new_state: Updated state of shape (R, N, 2)
222
+ """
223
+ if state is None:
224
+ R, N = Ad.shape[0], Ad.shape[1]
225
+ state = torch.zeros((R, N, 2), dtype=torch.float32, device=u.device)
226
+
227
+ # Input projection: v = dt * B @ u
228
+ v = dt.squeeze(1) * (B @ u) # (R,)
229
+
230
+ # State update: x = Ad * x + v
231
+ x_new = complex_mul_real(Ad, state) # (R, N, 2)
232
+ x_new[..., 0] = x_new[..., 0] + v.unsqueeze(1)
233
+
234
+ # Output: z = real((x * E).sum(N)), y = C @ z
235
+ E_cplx = torch.stack([E, torch.zeros_like(E)], dim=-1) # (R, N, 2)
236
+ z = torch.sum(complex_mul_real(x_new, E_cplx)[..., 0], dim=1) # (R,)
237
+ y = C @ z # (D,)
238
+
239
+ return y, x_new
240
+
241
+
242
+ # ============================================================================
243
+ # Full mode
244
+ # ============================================================================
245
+
246
+
247
+ def discretize_full(A, E, log_dt):
248
+ """Precompute discretized parameters for Full mode.
249
+
250
+ Args:
251
+ A: State parameter of shape (D, C, N, 2) - real repr of complex
252
+ E: Weight matrix of shape (D, C, N)
253
+ log_dt: Time step of shape (D, N)
254
+
255
+ Returns:
256
+ Ad: Discretized state transition of shape (D, C, N, 2)
257
+ B_hat: Discretized input projection of shape (D, C, N)
258
+ """
259
+ A_real = -F.softplus(A[..., 0]) # (D, C, N)
260
+ A_imag = A[..., 1] # (D, C, N)
261
+
262
+ dt = torch.exp(log_dt) # (D, N)
263
+ dt_exp = dt[:, None, :] # (D, 1, N)
264
+ scaled_real = dt_exp * A_real
265
+ scaled_imag = dt_exp * A_imag
266
+ exp_scaled_real = torch.exp(scaled_real)
267
+ Ad = torch.stack(
268
+ [
269
+ exp_scaled_real * torch.cos(scaled_imag),
270
+ exp_scaled_real * torch.sin(scaled_imag),
271
+ ],
272
+ dim=-1,
273
+ ) # (D, C, N, 2)
274
+
275
+ B_hat = E * dt_exp # (D, C, N)
276
+
277
+ return Ad, B_hat
278
+
279
+
280
+ def step_full(u, Ad, B_hat, state):
281
+ """Single timestep for Full mode using pre-discretized parameters.
282
+
283
+ Args:
284
+ u: Input tensor of shape (C,)
285
+ Ad: Discretized state transition of shape (D, C, N, 2)
286
+ B_hat: Discretized input projection of shape (D, C, N)
287
+ state: Previous state of shape (D, C, N, 2), or None for zero init
288
+
289
+ Returns:
290
+ y: Output tensor of shape (D,)
291
+ new_state: Updated state of shape (D, C, N, 2)
292
+ """
293
+ if state is None:
294
+ D, C, N = B_hat.shape
295
+ state = torch.zeros((D, C, N, 2), dtype=torch.float32, device=u.device)
296
+
297
+ # State update: x = Ad * x + B_hat * u
298
+ x_new = complex_mul_real(Ad, state) # (D, C, N, 2)
299
+ u_broadcast = u.unsqueeze(0).unsqueeze(2) # (1, C, 1)
300
+ Bu = B_hat * u_broadcast # (D, C, N)
301
+ x_new[..., 0] = x_new[..., 0] + Bu
302
+
303
+ # Output: y = sum(real(x), dim=(1, 2))
304
+ y = torch.sum(x_new[..., 0], dim=(1, 2)) # (D,)
305
+
306
+ return y, x_new
307
+
308
+
309
+ # ============================================================================
310
+ # Gate mode (input-dependent discretization — cannot precompute)
311
+ # ============================================================================
312
+
313
+
314
+ def recurrent_gate_single_step(u, A, B_proj, C_proj, log_dt_proj, state):
315
+ """
316
+ Gate-style SSM single timestep for ONNX export.
317
+
318
+ Processes single timestep with input-dependent parameters.
319
+ Unlike other modes, gate uses neural network projections for B, C, and dt.
320
+
321
+ Args:
322
+ u: Input tensor of shape (C_in,) - single timestep, single batch
323
+ A: State transition parameter of shape (N,) - in log space, represents decay rates
324
+ B_proj: nn.Module that projects (C_in,) -> (N,)
325
+ C_proj: nn.Module that projects (N,) -> (D,)
326
+ log_dt_proj: nn.Module that projects (C_in,) -> (N,)
327
+ state: Previous state of shape (N,) - real-valued state
328
+
329
+ Returns:
330
+ y: Output tensor of shape (D,)
331
+ new_state: Updated state of shape (N,)
332
+
333
+ State update formula:
334
+ log_dt = log_dt_proj(u)
335
+ dt = softplus(log_dt)
336
+ u_proj = B_proj(u)
337
+ dta = exp(-dt * exp(A)) # discretized decay
338
+ x_new = dta * x_old + dt * u_proj
339
+ y = C_proj(x_new)
340
+ """
341
+ # Get input-dependent projections
342
+ u_proj = B_proj(u) # (N,)
343
+ log_dt = log_dt_proj(u) # (N,)
344
+
345
+ # Discretization
346
+ dt = F.softplus(log_dt) # (N,)
347
+ exp_A = torch.exp(A) # (N,) - decay rate
348
+ dta = torch.exp(-dt * exp_A) # (N,) - discretized decay factor
349
+
350
+ # State update: x_new = dta * x_old + dt * u_proj
351
+ u_dt = u_proj * dt # (N,)
352
+ new_state = dta * state + u_dt # (N,)
353
+
354
+ # Output projection
355
+ y = C_proj(new_state) # (D,)
356
+
357
+ return y, new_state
358
+
359
+
360
+ def recurrent_gate(u, A, B_proj, C_proj, log_dt_proj, state=None):
361
+ """
362
+ Gate-style SSM using sequential scan for streaming inference.
363
+
364
+ Args:
365
+ u: Input tensor of shape (T, C_in) or (B, T, C_in)
366
+ A: State transition parameter of shape (N,) - in log space
367
+ B_proj: nn.Module or callable that projects (*, C_in) -> (*, N)
368
+ C_proj: nn.Module or callable that projects (*, N) -> (*, D)
369
+ log_dt_proj: nn.Module or callable that projects (*, C_in) -> (*, N)
370
+ state: Optional previous state of shape (N,) or (B, N)
371
+
372
+ Returns:
373
+ y: Output tensor of shape (D, T) or (B, D, T)
374
+ state: Updated state of shape (N,) or (B, N)
375
+ """
376
+ # Handle batched input
377
+ if u.dim() == 2:
378
+ u = u.unsqueeze(0) # (T, C_in) -> (1, T, C_in)
379
+ squeeze_batch = True
380
+ else:
381
+ squeeze_batch = False
382
+
383
+ B_batch, T, C_in = u.shape
384
+
385
+ # Reshape to (B*T, C_in) for vectorized projection
386
+ u_flat = u.reshape(B_batch * T, C_in)
387
+
388
+ # Get projections
389
+ u_proj = B_proj(u_flat) # (B*T, N)
390
+ log_dt = log_dt_proj(u_flat) # (B*T, N)
391
+
392
+ N = u_proj.shape[1]
393
+
394
+ # Reshape back to (B, T, N)
395
+ u_proj = u_proj.reshape(B_batch, T, N)
396
+ log_dt = log_dt.reshape(B_batch, T, N)
397
+
398
+ # Discretize (vectorized across batch)
399
+ dt = F.softplus(log_dt).to(torch.float32) # (B, T, N)
400
+ exp_A = torch.exp(A).to(torch.float32) # (N,)
401
+ log_dta = -dt * exp_A[None, None, :] # (B, T, N)
402
+ dta = torch.exp(log_dta) # (B, T, N)
403
+
404
+ # Prepare scan input
405
+ u_dt = u_proj * dt # (B, T, N)
406
+
407
+ # Initialize state
408
+ if state is None:
409
+ x = torch.zeros((B_batch, N), dtype=torch.float32, device=u.device)
410
+ else:
411
+ if state.dim() == 1:
412
+ x = state.unsqueeze(0).expand(B_batch, -1)
413
+ else:
414
+ x = state
415
+
416
+ # Output accumulator
417
+ states = torch.zeros((B_batch, T, N), dtype=torch.float32, device=u.device)
418
+
419
+ # Sequential scan over time (vectorized over batch)
420
+ for t in range(T):
421
+ x = dta[:, t] * x + u_dt[:, t] # (B, N)
422
+ states[:, t] = x
423
+
424
+ # Apply C projection: (B*T, N) -> (B*T, D)
425
+ states_flat = states.reshape(B_batch * T, N)
426
+ y_flat = C_proj(states_flat) # (B*T, D)
427
+ D = y_flat.shape[1]
428
+ y = y_flat.reshape(B_batch, T, D) # (B, T, D)
429
+
430
+ # Return format
431
+ y = y.transpose(1, 2) # (B, D, T)
432
+
433
+ if squeeze_batch:
434
+ y = y.squeeze(0) # (D, T)
435
+ x = x.squeeze(0) # (N,)
436
+
437
+ return y, x
tenns_core/scan_ops.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parallel scan operations for gate mode SSM.
3
+
4
+ Implements parallel prefix scan with custom autograd for training support.
5
+ Uses Triton kernels when available on CUDA, falls back to pure PyTorch otherwise.
6
+ """
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ try:
13
+ import triton
14
+ import triton.language as tl
15
+
16
+ _HAS_TRITON = hasattr(tl, 'associative_scan')
17
+ except ImportError:
18
+ _HAS_TRITON = False
19
+
20
+
21
+ # ----------------------------
22
+ # Utility
23
+ # ----------------------------
24
+ def _tp(x: torch.Tensor) -> torch.Tensor:
25
+ """(B, L, N) -> (B, N, L) contiguous."""
26
+ return x.moveaxis(-1, -2).contiguous()
27
+
28
+
29
+ # ----------------------------
30
+ # Reference (naive) scan
31
+ # ----------------------------
32
+ def scan_naive(input, log_dt, A, state=None, dim=-1):
33
+ """Naive sequential scan implementation.
34
+
35
+ Useful for testing and understanding, but slow (O(N) sequential steps).
36
+
37
+ Args:
38
+ input: Input tensor
39
+ log_dt: Log timestep parameters
40
+ A: State decay parameters
41
+ state: Optional initial state
42
+ dim: Dimension to scan over
43
+
44
+ Returns:
45
+ Scanned output tensor
46
+ """
47
+ dt = F.softplus(log_dt)
48
+ log_dta = -dt * A.exp()[..., None]
49
+ a = log_dta.exp()
50
+
51
+ if state is None:
52
+ state = 0
53
+ output = []
54
+ u = input * dt
55
+
56
+ for ui, ai in zip(u.moveaxis(dim, 0), a.moveaxis(dim, 0), strict=True):
57
+ state = ai * state + ui
58
+ output.append(state)
59
+
60
+ return torch.stack(output, dim=dim)
61
+
62
+
63
+ # ----------------------------
64
+ # PyTorch parallel scan
65
+ # ----------------------------
66
+ class ParallelScan(torch.autograd.Function):
67
+ """Parallel prefix scan with custom autograd.
68
+
69
+ Implements the associative scan operation:
70
+ state[t] = a[t] * state[t-1] + u[t]
71
+
72
+ In O(log N) parallel depth instead of O(N) sequential steps.
73
+
74
+ Note: This uses the naive sequential scan for backward pass to ensure
75
+ correctness. For production use with very long sequences, a parallel
76
+ backward scan could be implemented.
77
+ """
78
+
79
+ @staticmethod
80
+ def forward(ctx, u, a):
81
+ """Forward pass: parallel prefix scan.
82
+
83
+ Args:
84
+ u: Input values (batch, N, length)
85
+ a: Decay factors (batch, N, length)
86
+
87
+ Returns:
88
+ Scanned output (batch, N, length)
89
+ """
90
+ length = u.shape[-1]
91
+ strides = [2**i for i in range((length - 1).bit_length())]
92
+
93
+ # Save original inputs for backward
94
+ u_original = u.clone()
95
+ a_original = a.clone()
96
+
97
+ # Clone to avoid in-place modifications
98
+ u = u.clone()
99
+ a = a.clone()
100
+
101
+ for stride in strides:
102
+ u[..., stride:] = u[..., stride:] + u[..., :-stride] * a[..., stride:]
103
+ a[..., stride:] = a[..., stride:] * a[..., :-stride]
104
+
105
+ ctx.save_for_backward(u_original, a_original, u)
106
+ return u
107
+
108
+ @staticmethod
109
+ def backward(ctx, grad_output):
110
+ """Backward pass using sequential scan for correctness.
111
+
112
+ For production, this could be parallelized, but sequential is more
113
+ numerically stable and easier to verify.
114
+ """
115
+ u_original, a_original, y = ctx.saved_tensors
116
+
117
+ # Compute gradients using reverse-mode automatic differentiation
118
+ # by recomputing forward pass while tracking dependencies
119
+
120
+ grad_u = torch.zeros_like(u_original)
121
+ grad_a = torch.zeros_like(a_original)
122
+
123
+ # Backward scan: process from right to left
124
+ length = u_original.shape[-1]
125
+
126
+ # Accumulator for gradient flowing backward through time
127
+ grad_state = torch.zeros_like(u_original[..., 0:1])
128
+
129
+ for t in range(length - 1, -1, -1):
130
+ # Gradient from output at time t
131
+ grad_y_t = grad_output[..., t : t + 1]
132
+
133
+ # Total gradient flowing into state[t]
134
+ grad_state_t = grad_y_t + grad_state
135
+
136
+ # Gradients w.r.t. inputs
137
+ grad_u[..., t : t + 1] = grad_state_t
138
+ if t > 0:
139
+ grad_a[..., t : t + 1] = grad_state_t * y[..., t - 1 : t]
140
+
141
+ # Propagate gradient to previous state
142
+ if t > 0:
143
+ grad_state = grad_state_t * a_original[..., t : t + 1]
144
+
145
+ return grad_u, grad_a
146
+
147
+
148
+ def parallel_scan_pytorch(input, log_dt, A, state=None):
149
+ """Pure PyTorch parallel scan for SSM.
150
+
151
+ Args:
152
+ input: Input tensor (batch, length, N)
153
+ log_dt: Log timestep parameters (batch, length, N)
154
+ A: State decay parameters (N,)
155
+ state: Optional initial state (N,)
156
+
157
+ Returns:
158
+ Scanned output (batch, length, N)
159
+ """
160
+ dt = F.softplus(log_dt)
161
+ log_dta = -dt * A.exp()[None, None, :]
162
+ a = log_dta.exp()
163
+ u = input * dt
164
+
165
+ # Fold initial state into first timestep
166
+ if state is not None:
167
+ u = u.clone()
168
+ u[:, 0, :] = u[:, 0, :] + state * a[:, 0, :]
169
+
170
+ # Transpose for scan: (batch, N, length)
171
+ u = u.transpose(-1, -2)
172
+ a = a.transpose(-1, -2)
173
+
174
+ # Apply parallel scan
175
+ output = ParallelScan.apply(u, a)
176
+
177
+ # Transpose back: (batch, length, N)
178
+ return output.transpose(-1, -2)
179
+
180
+
181
+ # ----------------------------
182
+ # Triton kernels (guarded)
183
+ # ----------------------------
184
+ if _HAS_TRITON:
185
+
186
+ @triton.jit
187
+ def _roll_op(x1, y1, x2, y2):
188
+ return x2, tl.where(y2 == float('inf'), x1, y2)
189
+
190
+ @triton.jit
191
+ def roll(u, length: tl.constexpr, reverse: tl.constexpr = 0):
192
+ if reverse:
193
+ _, u_rol = tl.associative_scan((u, float('inf') + u), 0, _roll_op, reverse=1)
194
+ u_rol = tl.where(tl.arange(0, length) < length - 1, u_rol, 0)
195
+ else:
196
+ _, u_rol = tl.associative_scan((u, float('inf') + u), 0, _roll_op)
197
+ u_rol = tl.where(tl.arange(0, length) > 0, u_rol, 0)
198
+ return u_rol
199
+
200
+ @triton.jit
201
+ def _scan_op(a1, x1, a2, x2):
202
+ return a1 * a2, a2 * x1 + x2
203
+
204
+ @triton.jit
205
+ def softplus_tl(x):
206
+ return tl.where(x < 20, tl.log(1 + tl.exp(x)), x)
207
+
208
+ @triton.jit
209
+ def scan_heisen_fwd_triton(
210
+ u_ptr,
211
+ log_dt_ptr,
212
+ A_ptr,
213
+ y_ptr,
214
+ state_ptr,
215
+ L,
216
+ N: tl.constexpr,
217
+ MAX_L: tl.constexpr,
218
+ INIT_STATE: tl.constexpr = 1,
219
+ ):
220
+ id_BATCH, id_N = tl.program_id(0), tl.program_id(1)
221
+ id_sample = id_BATCH * N + id_N
222
+
223
+ lrange = tl.arange(0, MAX_L)
224
+ offsets = id_sample * L + lrange
225
+ mask = lrange < L
226
+
227
+ A = tl.load(A_ptr + id_N)
228
+ if INIT_STATE:
229
+ state = tl.load(state_ptr + id_N)
230
+
231
+ u = tl.load(u_ptr + offsets, mask, 0).to(tl.float32)
232
+ log_dt = tl.load(log_dt_ptr + offsets, mask, 0).to(tl.float32)
233
+
234
+ dt = softplus_tl(log_dt)
235
+ log_dta = -1.0 * dt * tl.exp(A)
236
+ dta = tl.exp(log_dta)
237
+
238
+ if INIT_STATE:
239
+ u_dt = tl.where(lrange > 0, u * dt, u * dt + state * dta)
240
+ else:
241
+ u_dt = u * dt
242
+
243
+ _, y = tl.associative_scan((dta, u_dt), 0, _scan_op)
244
+ tl.store(y_ptr + offsets, y, mask)
245
+
246
+ @triton.jit
247
+ def scan_heisen_bwd_triton(
248
+ u_ptr,
249
+ grad_x_ptr,
250
+ log_dt_ptr,
251
+ A_ptr,
252
+ state_ptr,
253
+ grad_u_ptr,
254
+ grad_log_dt_ptr,
255
+ grad_A_ptr,
256
+ grad_x0_ptr,
257
+ L,
258
+ N: tl.constexpr,
259
+ MAX_L: tl.constexpr,
260
+ INIT_STATE: tl.constexpr = 1,
261
+ ):
262
+ id_BATCH, id_N = tl.program_id(0), tl.program_id(1)
263
+ id_sample = id_BATCH * N + id_N
264
+
265
+ lrange = tl.arange(0, MAX_L)
266
+ offsets = id_sample * L + lrange
267
+ mask = lrange < L
268
+
269
+ A = tl.load(A_ptr + id_N)
270
+ exp_A = tl.exp(A)
271
+ if INIT_STATE:
272
+ state = tl.load(state_ptr + id_N)
273
+
274
+ u = tl.load(u_ptr + offsets, mask, 0).to(tl.float32)
275
+ log_dt = tl.load(log_dt_ptr + offsets, mask, 0).to(tl.float32)
276
+
277
+ dt = softplus_tl(log_dt)
278
+ log_dta = -1.0 * dt * exp_A
279
+ dta = tl.exp(log_dta)
280
+
281
+ if INIT_STATE:
282
+ u_dt = tl.where(lrange > 0, u * dt, u * dt + state * dta)
283
+ else:
284
+ u_dt = u * dt
285
+
286
+ _, x = tl.associative_scan((dta, u_dt), 0, _scan_op)
287
+ x_rol = roll(x, MAX_L)
288
+
289
+ grad_x = tl.load(grad_x_ptr + offsets, mask, 0).to(tl.float32)
290
+
291
+ if INIT_STATE:
292
+ log_dta_star = tl.cumsum(log_dta, 0)
293
+ dta_star = tl.exp(log_dta_star)
294
+ grad_x0 = tl.sum(grad_x * dta_star, 0)
295
+ tl.store(grad_x0_ptr + id_sample, grad_x0)
296
+ x_rol = tl.where(lrange > 0, x_rol, state)
297
+
298
+ dta_rol = roll(dta, MAX_L, reverse=1)
299
+ _, grad_x = tl.associative_scan((dta_rol, grad_x), 0, _scan_op, reverse=1)
300
+
301
+ grad_u = grad_x * dt
302
+ tl.store(grad_u_ptr + offsets, grad_u, mask)
303
+
304
+ grad_dta = grad_x * x_rol
305
+ grad_log_dta = tl.exp(log_dta) * grad_dta
306
+
307
+ grad_log_dt = (-1.0 * grad_log_dta * exp_A + u * grad_x) * tl.sigmoid(log_dt)
308
+ tl.store(grad_log_dt_ptr + offsets, grad_log_dt, mask)
309
+
310
+ grad_A = tl.sum(grad_log_dta * log_dta, 0)
311
+ tl.store(grad_A_ptr + id_sample, grad_A)
312
+
313
+ class FusedScanTriton(torch.autograd.Function):
314
+ @staticmethod
315
+ @torch.compiler.disable
316
+ @torch.amp.custom_fwd(device_type='cuda')
317
+ def forward(ctx, u, T1, T2, logdt_bias, A, B1, B2, state=None):
318
+ INIT_STATE = state is not None
319
+
320
+ uh = u.half()
321
+ T1, T2, logdt_bias, B1 = T1.half(), T2.half(), logdt_bias.half(), B1.half()
322
+ if B2 is not None:
323
+ B2 = B2.half()
324
+
325
+ if B2 is not None:
326
+ u1 = F.linear(uh, B1)
327
+ u2_tp = _tp(F.linear(u1, B2))
328
+ else:
329
+ u2_tp = _tp(F.linear(uh, B1))
330
+
331
+ logdt_1 = F.linear(uh, T1)
332
+ logdt_tp = _tp(F.linear(logdt_1, T2, bias=logdt_bias))
333
+
334
+ x_tp = torch.empty_like(u2_tp, dtype=torch.float32)
335
+
336
+ BATCH, N, L = u2_tp.shape
337
+ grid = (BATCH, N)
338
+ max_L = triton.next_power_of_2(L)
339
+ num_warps = max(max_L // 1024, 1)
340
+
341
+ scan_heisen_fwd_triton[grid](
342
+ u2_tp,
343
+ logdt_tp,
344
+ A,
345
+ x_tp,
346
+ state,
347
+ L,
348
+ N,
349
+ max_L,
350
+ INIT_STATE=INIT_STATE,
351
+ num_warps=num_warps,
352
+ num_stages=3,
353
+ )
354
+
355
+ if B2 is not None:
356
+ ctx.save_for_backward(uh, state, A, T1, T2, logdt_bias, B1, B2)
357
+ ctx.B2_flag = True
358
+ else:
359
+ ctx.save_for_backward(uh, u2_tp, state, A, T1, T2, logdt_bias, B1)
360
+ ctx.B2_flag = False
361
+
362
+ return x_tp.moveaxis(-1, -2) # (B, L, N)
363
+
364
+ @staticmethod
365
+ @torch.compiler.disable
366
+ @torch.amp.custom_bwd(device_type='cuda')
367
+ def backward(ctx, grad_x):
368
+ def back_dot(x, y):
369
+ return torch.tensordot(x, y, dims=([0], [0]))
370
+
371
+ if ctx.B2_flag:
372
+ uh, state, A, T1, T2, logdt_bias, B1, B2 = ctx.saved_tensors
373
+ else:
374
+ uh, u2_tp, state, A, T1, T2, logdt_bias, B1 = ctx.saved_tensors
375
+ B2 = None
376
+
377
+ INIT_STATE = state is not None
378
+ grad_x_tp = _tp(grad_x)
379
+
380
+ if B2 is not None:
381
+ u1 = F.linear(uh, B1)
382
+ u2_tp = _tp(F.linear(u1, B2))
383
+
384
+ logdt1 = F.linear(uh, T1)
385
+ logdt2 = F.linear(logdt1, T2, bias=logdt_bias)
386
+ logdt2_tp = _tp(logdt2)
387
+
388
+ BATCH, N, L = u2_tp.shape
389
+ grid = (BATCH, N)
390
+
391
+ grad_u2_tp = torch.empty_like(u2_tp, dtype=torch.float32)
392
+ grad_logdt_tp = torch.empty_like(u2_tp, dtype=torch.float32)
393
+ grad_A = torch.empty(grid, dtype=torch.float32, device=A.device)
394
+ grad_x0 = (
395
+ torch.empty((BATCH, N), dtype=torch.float32, device=A.device)
396
+ if INIT_STATE
397
+ else None
398
+ )
399
+
400
+ max_L = triton.next_power_of_2(L)
401
+ num_warps = max(max_L // 1024, 1)
402
+
403
+ scan_heisen_bwd_triton[grid](
404
+ u2_tp,
405
+ grad_x_tp,
406
+ logdt2_tp,
407
+ A,
408
+ state,
409
+ grad_u2_tp,
410
+ grad_logdt_tp,
411
+ grad_A,
412
+ grad_x0,
413
+ L,
414
+ N,
415
+ max_L,
416
+ INIT_STATE=INIT_STATE,
417
+ num_warps=num_warps,
418
+ num_stages=3,
419
+ )
420
+
421
+ grad_A = grad_A.sum(0)
422
+ grad_init_state = grad_x0.sum(0) if INIT_STATE else None
423
+
424
+ uh2 = uh.view(-1, N)
425
+ grad_u2 = _tp(grad_u2_tp).view(-1, N)
426
+
427
+ if B2 is not None:
428
+ grad_u1 = grad_u2 @ B2
429
+ grad_B2 = back_dot(grad_u2, u1.view(BATCH * L, -1))
430
+ else:
431
+ grad_B2 = None
432
+ grad_u1 = grad_u2
433
+
434
+ grad_u = grad_u1 @ B1
435
+ grad_B1 = back_dot(grad_u1, uh2)
436
+
437
+ grad_logdt_bias = grad_logdt_tp.sum((0, 2))
438
+
439
+ grad_logdt = _tp(grad_logdt_tp).view(-1, N)
440
+ grad_logdt_1 = grad_logdt @ T2
441
+ grad_T2 = back_dot(grad_logdt, logdt1.view(BATCH * L, -1))
442
+
443
+ grad_u = grad_u + grad_logdt_1 @ T1
444
+ grad_T1 = back_dot(grad_logdt_1, uh2)
445
+
446
+ return (
447
+ grad_u.view(BATCH, L, N),
448
+ grad_T1,
449
+ grad_T2,
450
+ grad_logdt_bias,
451
+ grad_A,
452
+ grad_B1,
453
+ grad_B2,
454
+ grad_init_state,
455
+ )
456
+
457
+
458
+ # ----------------------------
459
+ # Unified API
460
+ # ----------------------------
461
+ def _can_use_triton(u: torch.Tensor) -> bool:
462
+ if not _HAS_TRITON:
463
+ return False
464
+ if not u.is_cuda:
465
+ return False
466
+ try:
467
+ major, _ = torch.cuda.get_device_capability(u.device)
468
+ if major < 7:
469
+ return False
470
+ except Exception:
471
+ pass
472
+ return True
473
+
474
+
475
+ def fused_scan(u, log_dt_proj, in_proj, A, state=None):
476
+ """Fused scan operation for gate mode SSM.
477
+
478
+ Uses Triton kernels on CUDA when available, falls back to PyTorch parallel scan.
479
+
480
+ Args:
481
+ u: Input tensor (batch, length, channels)
482
+ log_dt_proj: Sequential module for timestep projection
483
+ in_proj: Sequential or single module for input projection
484
+ A: State decay parameters (N,)
485
+ state: Optional initial state (N,)
486
+
487
+ Returns:
488
+ Scanned output (batch, length, N)
489
+ """
490
+ if _can_use_triton(u):
491
+ # Extract weights for Triton path
492
+ if isinstance(in_proj, nn.Linear):
493
+ B1, B2 = in_proj.weight, None
494
+ else:
495
+ B1, B2 = in_proj[0].weight, in_proj[1].weight
496
+
497
+ T1 = log_dt_proj[0].weight
498
+ T2 = log_dt_proj[1].weight
499
+ logdt_bias = log_dt_proj[1].bias
500
+
501
+ return FusedScanTriton.apply(
502
+ u.contiguous(),
503
+ T1.contiguous(),
504
+ T2.contiguous(),
505
+ logdt_bias.contiguous(),
506
+ A.contiguous(),
507
+ B1.contiguous(),
508
+ B2.contiguous() if B2 is not None else None,
509
+ state.contiguous() if state is not None else None,
510
+ )
511
+
512
+ # PyTorch fallback (CPU or CUDA without Triton)
513
+ u_proj = in_proj(u)
514
+ log_dt = log_dt_proj(u)
515
+ return parallel_scan_pytorch(u_proj, log_dt, A, state=state)
tenns_core/ssm.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ State Space Model (SSM) layers for sequence modeling.
3
+
4
+ This module provides SSMLayer, a flexible implementation of various SSM architectures
5
+ including S5, DWS, Neck, Full, and Gate modes. All implementations use pure PyTorch
6
+ with custom autograd functions for efficient training.
7
+ """
8
+
9
+ import math
10
+
11
+ import einops
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.parameter import Parameter
17
+
18
+ from .activations import get_activations
19
+ from .fft_ops import padded_fft_conv_opt
20
+ from .scan_ops import fused_scan
21
+
22
+
23
+ # Utility functions
24
+ def c2r(inputs):
25
+ return torch.view_as_real(inputs)
26
+
27
+
28
+ def r2c(inputs):
29
+ return torch.view_as_complex(inputs)
30
+
31
+
32
+ def inv_softplus(x):
33
+ return x + np.log(-np.expm1(-x))
34
+
35
+
36
+ class Kernelizer(nn.Module):
37
+ """Core module for SSM operations using FFT convolutions and parallel scans.
38
+
39
+ This is the base class that handles the actual SSM computation.
40
+ SSMLayer extends this with parameter initialization and training utilities.
41
+ """
42
+
43
+ def __init__(self, mode='s5', transposed=False, complex_proj=False, **kwargs):
44
+ """Initialize Kernelizer.
45
+
46
+ Args:
47
+ mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
48
+ transposed: Whether to use transposed operations (time-last vs channel-last)
49
+ complex_proj: Whether to use complex projections
50
+ """
51
+ super().__init__()
52
+
53
+ self.mode = mode
54
+ self.transposed = transposed
55
+ self.complex_proj = complex_proj
56
+
57
+ @torch.compiler.disable
58
+ def discretize(self, A: torch.Tensor, weight: torch.Tensor, log_dt: torch.Tensor):
59
+ """Discretize continuous-time SSM using zero-order-hold method.
60
+
61
+ Converts continuous-time parameters (A, B, dt) to discrete-time (A_bar, B_bar)
62
+ using the zero-order-hold discretization:
63
+ A_bar = exp(A * dt)
64
+ B_bar = B * dt
65
+
66
+ NOTE: Assumes diagonal state matrix A.
67
+
68
+ Args:
69
+ A: State matrix diagonal [real, imag] (shape varies by mode)
70
+ weight: Input weight matrix B or output weight E (shape varies by mode)
71
+ log_dt: Log of timestep parameters
72
+
73
+ Returns:
74
+ Tuple of (dtA_real, dtA_imag, weight_hat) discretized parameters
75
+ """
76
+ with torch.autocast('cuda', enabled=False):
77
+ A_real, A_imag = -F.softplus(A[..., 0]), A[..., 1]
78
+ dt = log_dt.exp()
79
+
80
+ match self.mode:
81
+ case 'neck':
82
+ dt = dt.unsqueeze(-1) # (R, :) -> (R, :, 1)
83
+ weight_hat = weight * dt
84
+ case 'full':
85
+ dt = dt.unsqueeze(-2) # (D, :) -> (D, 1, :)
86
+ weight_hat = weight * dt
87
+ case 'dws':
88
+ weight_hat = weight * dt # (C, N)
89
+ case _: # s5, gate
90
+ weight_hat = weight * dt.unsqueeze(-1) # (R*N, :) -> (R*N, C)
91
+
92
+ dtA_real, dtA_imag = dt * A_real, dt * A_imag
93
+
94
+ return dtA_real, dtA_imag, weight_hat
95
+
96
+ def forward(
97
+ self,
98
+ input: torch.Tensor,
99
+ A: torch.Tensor,
100
+ B: torch.Tensor,
101
+ C: torch.Tensor,
102
+ log_dt: torch.Tensor,
103
+ E: torch.Tensor,
104
+ state=None,
105
+ ):
106
+ """Forward pass through SSM layer.
107
+
108
+ Args:
109
+ input: Input tensor (batch, channels, length)
110
+ A: State matrix diagonal parameters
111
+ B: Input projection matrix (for s5/neck/gate modes)
112
+ C: Output projection matrix (for s5/neck modes) or module (for gate)
113
+ log_dt: Log timestep parameters
114
+ E: State mixing matrix (for dws/neck/full modes)
115
+ state: Optional initial state (for gate mode prefix tuning)
116
+
117
+ Returns:
118
+ Output tensor (batch, out_channels, length)
119
+ """
120
+ match self.mode:
121
+ case 's5' | 'neck':
122
+ dtA_real, dtA_imag, B_hat = self.discretize(A, B, log_dt)
123
+ return padded_fft_conv_opt(input, dtA_real, dtA_imag, B_hat, C, E)
124
+
125
+ case 'dws' | 'full':
126
+ dtA_real, dtA_imag, E_hat = self.discretize(A, E, log_dt)
127
+ return padded_fft_conv_opt(input, dtA_real, dtA_imag, None, None, E_hat)
128
+
129
+ case 'gate':
130
+ # Gate mode can work with both formats
131
+ # Transpose if needed: (B, C, T) -> (B, T, C)
132
+ if not self.transposed:
133
+ input = input.transpose(1, 2)
134
+
135
+ output = C(fused_scan(input, log_dt, B, A, state=state))
136
+
137
+ # Transpose back if needed: (B, T, D) -> (B, D, T)
138
+ if not self.transposed:
139
+ output = output.transpose(1, 2)
140
+
141
+ return output
142
+
143
+
144
+ class SSMLayer(Kernelizer):
145
+ """State Space Model layer with multiple architecture variants.
146
+
147
+ Extends Kernelizer with parameter initialization, activation layers,
148
+ and training utilities. Supports multiple SSM modes:
149
+
150
+ - **s5**: Standard S5 architecture with shared state space
151
+ - **dws**: Depthwise separable variant (per-channel state spaces)
152
+ - **neck**: Bottleneck architecture with low-rank state mixing
153
+ - **full**: Full parameterization (per-output-channel state spaces)
154
+ - **gate**: Input-dependent gating (Mamba-style selective SSM)
155
+
156
+ Mode Comparison:
157
+ ----------------
158
+ | Mode | Parameters | Best For | Speed |
159
+ |-------|------------|-----------------------------|---------|
160
+ | s5 | Medium | General sequence modeling | Fast |
161
+ | dws | Low | Efficient local processing | Fastest |
162
+ | neck | Low | Long sequences, low memory | Fast |
163
+ | full | High | Rich feature interactions | Medium |
164
+ | gate | High | Input-adaptive processing | Slow |
165
+
166
+ Usage Example:
167
+ --------------
168
+ >>> # S5 mode for sequence classification
169
+ >>> layer = SSMLayer(
170
+ ... num_coeffs=64, # State space dimension
171
+ ... in_channels=128, # Input features
172
+ ... out_channels=256, # Output features
173
+ ... mode='s5',
174
+ ... repeat=1, # Number of parallel SSMs
175
+ ... norm='layer',
176
+ ... postact='gelu'
177
+ ... )
178
+ >>> input = torch.randn(4, 128, 512) # (batch, channels, length)
179
+ >>> output = layer(input) # (4, 256, 512)
180
+ """
181
+
182
+ def __init__(
183
+ self,
184
+ num_coeffs: int,
185
+ in_channels: int,
186
+ out_channels: int,
187
+ repeat=None,
188
+ norm='batch',
189
+ postact='relu',
190
+ dropout=None,
191
+ dropout_dim=1,
192
+ use_activations=False,
193
+ **kwargs,
194
+ ):
195
+ """Initialize SSM layer.
196
+
197
+ Args:
198
+ num_coeffs: Dimension of state space (N in SSM notation)
199
+ in_channels: Number of input channels
200
+ out_channels: Number of output channels
201
+ repeat: Number of parallel SSM blocks (default: 1)
202
+ norm: Normalization type ('batch', 'layer', 'rms', None)
203
+ postact: Activation function ('relu', 'gelu', 'silu', None)
204
+ dropout: Dropout probability (None for no dropout)
205
+ dropout_dim: Dimension for dropout (0, 1, 2, or 3)
206
+ use_activations: Whether to apply activations to mixer output
207
+ **kwargs: Additional arguments (mode, transposed, complex_proj, etc.)
208
+ """
209
+ _VALID_MODES = {'s5', 'dws', 'neck', 'full', 'gate'}
210
+ _VALID_NORMS = {'batch', 'layer', 'layer-feature', 'rms', None}
211
+ _VALID_POSTACTS = {'relu', 'relu6', 'lelu', 'sigmoid', 'tanh', 'gelu', 'glu', 'silu', None}
212
+ _VALID_DROPOUT_DIMS = {0, 1, 2, 3}
213
+
214
+ mode = kwargs.get('mode', 's5')
215
+ if mode not in _VALID_MODES:
216
+ raise ValueError(f"Invalid mode '{mode}'. Must be one of {sorted(_VALID_MODES)}.")
217
+ if norm not in _VALID_NORMS:
218
+ raise ValueError(
219
+ f"Invalid norm '{norm}'. Must be one of {sorted(_VALID_NORMS, key=str)}."
220
+ )
221
+ if postact not in _VALID_POSTACTS:
222
+ raise ValueError(
223
+ f"Invalid postact '{postact}'. Must be one of {sorted(_VALID_POSTACTS, key=str)}."
224
+ )
225
+ if dropout_dim not in _VALID_DROPOUT_DIMS:
226
+ raise ValueError(
227
+ f'Invalid dropout_dim {dropout_dim}. Must be one of {sorted(_VALID_DROPOUT_DIMS)}.'
228
+ )
229
+ if num_coeffs < 1:
230
+ raise ValueError(f'num_coeffs must be >= 1, got {num_coeffs}.')
231
+ if in_channels < 1:
232
+ raise ValueError(f'in_channels must be >= 1, got {in_channels}.')
233
+ if out_channels < 1:
234
+ raise ValueError(f'out_channels must be >= 1, got {out_channels}.')
235
+
236
+ super().__init__(**kwargs)
237
+ self.in_channels = in_channels
238
+ self.out_channels = out_channels
239
+
240
+ self.repeat = 1 if repeat is None else repeat
241
+
242
+ self.norm = norm
243
+ self.postact = postact
244
+ self.dropout = dropout
245
+ self.dropout_dim = dropout_dim
246
+
247
+ self.bias = None
248
+ self.E = None
249
+
250
+ # Initialize state matrix A
251
+ if self.mode == 'gate':
252
+ # For gate mode: log-spaced initialization
253
+ A = np.arange(1, num_coeffs + 1)
254
+ A = np.log(A)
255
+ else:
256
+ # For FFT modes: complex eigenvalues
257
+ # Real part: decay rate, Imaginary part: frequency
258
+ A = np.stack([0.5 * np.ones(num_coeffs), math.pi * np.arange(num_coeffs)], -1)
259
+ A[..., 0] = inv_softplus(A[..., 0])
260
+
261
+ # Initialize timestep parameters
262
+ if self.mode in ['dws']:
263
+ dt = np.geomspace(1e-3, 1e-1, in_channels)
264
+ elif self.mode == 'full':
265
+ dt = np.geomspace(1e-3, 1e-1, out_channels)
266
+ else:
267
+ dt = np.geomspace(1e-3, 1e-1, self.repeat)
268
+
269
+ if self.mode == 'gate':
270
+ log_dt = inv_softplus(dt)
271
+ else:
272
+ log_dt = np.log(dt)
273
+
274
+ # Helper functions for parameter creation
275
+ def to_parameter(mat, is_complex=False, requires_grad=True):
276
+ if mat is None:
277
+ return None
278
+ tensor = torch.tensor(mat, dtype=torch.float)
279
+ if is_complex:
280
+ tensor = tensor.cfloat()
281
+ return Parameter(tensor, requires_grad=requires_grad)
282
+
283
+ def ones(shape, fan_in):
284
+ mat = np.ones(shape) / math.sqrt(fan_in)
285
+ return to_parameter(mat, is_complex=self.complex_proj)
286
+
287
+ def normal(shape, fan_in):
288
+ mat = np.random.randn(*shape) * math.sqrt(2 / fan_in)
289
+ return to_parameter(mat, is_complex=self.complex_proj)
290
+
291
+ tot_coeffs = self.repeat * num_coeffs
292
+
293
+ # Mode-specific parameter initialization
294
+ match self.mode:
295
+ case 'dws':
296
+ log_dt = einops.repeat(log_dt, 'c -> c n', n=num_coeffs)
297
+ A = einops.repeat(A, 'n i -> c n i', c=in_channels)
298
+ self.B = None
299
+ self.C = None
300
+ self.E = ones((in_channels, num_coeffs), num_coeffs)
301
+
302
+ case 's5':
303
+ log_dt = einops.repeat(log_dt, 'j -> (j n)', n=num_coeffs)
304
+ A = einops.repeat(A, 'n i -> (j n) i', j=self.repeat)
305
+ self.B = ones((tot_coeffs, in_channels), in_channels)
306
+ self.C = normal((out_channels, tot_coeffs), tot_coeffs)
307
+ self.E = None
308
+
309
+ case 'neck':
310
+ # Neck mode uses fewer repeated log_dt parameters
311
+ A = einops.repeat(A, 'n i -> r n i', r=self.repeat)
312
+ self.B = ones((self.repeat, in_channels), in_channels)
313
+ self.C = normal((out_channels, self.repeat), tot_coeffs)
314
+ self.E = normal((self.repeat, num_coeffs), 1)
315
+
316
+ case 'full':
317
+ log_dt = einops.repeat(log_dt, 'd -> d n', n=num_coeffs)
318
+ A = einops.repeat(A, 'n i -> d c n i', c=in_channels, d=out_channels)
319
+ self.B = None
320
+ self.C = None
321
+ self.E = ones((out_channels, in_channels, num_coeffs), in_channels)
322
+
323
+ case 'gate':
324
+ log_dt = einops.repeat(log_dt, 'j -> (j n)', n=num_coeffs)
325
+
326
+ # Timestep projection: learns input-dependent timesteps
327
+ self.log_dt = nn.Sequential(
328
+ nn.Linear(in_channels, self.repeat, bias=False),
329
+ nn.Linear(self.repeat, tot_coeffs, bias=True),
330
+ )
331
+ nn.init.zeros_(self.log_dt[-1].weight)
332
+ self.log_dt[-1].bias = to_parameter(log_dt)
333
+
334
+ # State decay parameters
335
+ A = einops.repeat(A, 'n -> (j n)', j=self.repeat)
336
+
337
+ # Input and output projections
338
+ self.B = nn.Sequential(
339
+ nn.Linear(in_channels, self.repeat, bias=False),
340
+ nn.Linear(self.repeat, tot_coeffs, bias=False),
341
+ )
342
+ self.C = nn.Linear(tot_coeffs, out_channels, bias=False)
343
+
344
+ # Register parameters
345
+ self.A = to_parameter(A)
346
+
347
+ if self.mode not in ['gate']:
348
+ self.log_dt = to_parameter(log_dt)
349
+
350
+ # Mark certain parameters as "sensitive" for optimizer
351
+ # (suggests using smaller learning rates for these)
352
+ match self.mode:
353
+ case 'dws' | 'full' | 'neck':
354
+ self._register_sensitives(self.log_dt, self.A)
355
+ case 'gate':
356
+ self._register_sensitives(self.A)
357
+
358
+ # Mixer layer: final projection and activations
359
+ if self.mode in ['dws']:
360
+ # DWS mode has explicit channel mixing
361
+ self.mixer = nn.Sequential(
362
+ self._make_activation_block(in_channels),
363
+ nn.Conv1d(in_channels, out_channels, 1, bias=False),
364
+ self._make_activation_block(out_channels) if use_activations else nn.Identity(),
365
+ )
366
+ else:
367
+ self.mixer = (
368
+ self._make_activation_block(out_channels) if use_activations else nn.Identity()
369
+ )
370
+
371
+ @staticmethod
372
+ def _register_sensitives(*args):
373
+ """Mark parameters as sensitive (for optimizer to use smaller learning rates).
374
+
375
+ Args:
376
+ *args: Parameters or modules to mark as sensitive
377
+ """
378
+ for arg in args:
379
+ if isinstance(arg, nn.Module):
380
+ for param in arg.parameters():
381
+ param.sensitive = True
382
+ continue
383
+ arg.sensitive = True
384
+
385
+ def get_param_groups(self, lr=1e-3, sensitive_lr_factor=0.1):
386
+ """Get optimizer parameter groups with separate learning rates.
387
+
388
+ Sensitive parameters (A matrix, log_dt) benefit from smaller learning
389
+ rates. This method returns ready-made param groups for the optimizer.
390
+
391
+ Args:
392
+ lr: Base learning rate for regular parameters
393
+ sensitive_lr_factor: Multiplier for sensitive parameter learning rate
394
+ (default: 0.1, i.e. 10x smaller than base lr)
395
+
396
+ Returns:
397
+ List of dicts suitable for torch.optim optimizers
398
+
399
+ Example:
400
+ >>> layer = SSMLayer(64, 128, 256, mode='s5')
401
+ >>> optimizer = torch.optim.AdamW(layer.get_param_groups(lr=1e-3))
402
+ """
403
+ regular, sensitive = [], []
404
+ for param in self.parameters():
405
+ if getattr(param, 'sensitive', False):
406
+ sensitive.append(param)
407
+ else:
408
+ regular.append(param)
409
+ groups = [{'params': regular, 'lr': lr}]
410
+ if sensitive:
411
+ groups.append({'params': sensitive, 'lr': lr * sensitive_lr_factor})
412
+ return groups
413
+
414
+ def _make_activation_block(self, num_features):
415
+ """Create normalization + activation + dropout block.
416
+
417
+ Args:
418
+ num_features: Number of features for norm/dropout
419
+
420
+ Returns:
421
+ Sequential module with norm, activation, dropout
422
+ """
423
+ return get_activations(
424
+ 1, num_features, self.norm, self.postact, self.dropout, self.dropout_dim
425
+ )
426
+
427
+ def forward(self, input):
428
+ """Forward pass through SSM layer.
429
+
430
+ Args:
431
+ input: Input tensor of shape (batch, in_channels, length)
432
+
433
+ Returns:
434
+ Output tensor of shape (batch, out_channels, length)
435
+ """
436
+ output = super().forward(input, self.A, self.B, self.C, self.log_dt, E=self.E)
437
+
438
+ if self.bias is not None:
439
+ output = output + self.bias
440
+
441
+ return self.mixer(output)
442
+
443
+ def to_inference(self):
444
+ """Convert to streaming inference mode.
445
+
446
+ Returns SSMLayerInference instance for low-latency streaming processing.
447
+ The inference layer maintains state across chunks for applications.
448
+
449
+ Returns:
450
+ SSMLayerInference: Inference layer with copied weights
451
+
452
+ Example:
453
+ >>> # After training
454
+ >>> train_layer = SSMLayer(64, 128, 256, mode='s5')
455
+ >>> # ... training ...
456
+ >>>
457
+ >>> # Convert for streaming
458
+ >>> infer_layer = train_layer.to_inference()
459
+ >>>
460
+ >>> # Process audio stream
461
+ >>> for chunk in audio_stream:
462
+ >>> output = infer_layer(chunk)
463
+ >>>
464
+ >>> # Reset between utterances
465
+ >>> infer_layer.reset_state()
466
+
467
+ Note:
468
+ The inference layer uses sequential scan which is slower than
469
+ FFT for full sequences but has lower latency for streaming.
470
+ """
471
+ from .inference import SSMLayerInference
472
+
473
+ return SSMLayerInference.from_training(self)
474
+
475
+ def __repr__(self):
476
+ """String representation showing parameters."""
477
+ param_info = []
478
+ for name, param in self.named_parameters():
479
+ if param.requires_grad:
480
+ param_info.append(f'{name}: {list(param.shape)}')
481
+ return f'{self.__class__.__name__}(\n ' + '\n '.join(param_info) + '\n)'
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<s>",
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "</s>",
7
+ "extra_special_tokens": [],
8
+ "is_local": false,
9
+ "legacy": false,
10
+ "model_max_length": 1000000000000000019884624838656,
11
+ "pad_token": "</s>",
12
+ "sp_model_kwargs": {},
13
+ "spaces_between_special_tokens": false,
14
+ "tokenizer_class": "TokenizersBackend",
15
+ "unk_token": "<unk>",
16
+ "use_default_system_prompt": false
17
+ }