Text Generation
Transformers
Safetensors
English
tenns_llm
ssm
causal-lm
custom-architecture
recurrent
custom_code
Instructions to use BrainChip-AI/tenns-llm-1b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BrainChip-AI/tenns-llm-1b with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="BrainChip-AI/tenns-llm-1b", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("BrainChip-AI/tenns-llm-1b", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use BrainChip-AI/tenns-llm-1b with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "BrainChip-AI/tenns-llm-1b" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "BrainChip-AI/tenns-llm-1b", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/BrainChip-AI/tenns-llm-1b
- SGLang
How to use BrainChip-AI/tenns-llm-1b with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "BrainChip-AI/tenns-llm-1b" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "BrainChip-AI/tenns-llm-1b", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "BrainChip-AI/tenns-llm-1b" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "BrainChip-AI/tenns-llm-1b", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use BrainChip-AI/tenns-llm-1b with Docker Model Runner:
docker model run hf.co/BrainChip-AI/tenns-llm-1b
Commit ·
c8c055f
0
Parent(s):
Duplicate from NickMarkovsky/tenns-llm-1b
Browse filesCo-authored-by: Nick Markovsky <NickMarkovsky@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +101 -0
- config.json +13 -0
- configuration_tenns_llm.py +30 -0
- model.safetensors +3 -0
- modeling_tenns_llm.py +289 -0
- tenns_core/__init__.py +50 -0
- tenns_core/__pycache__/__init__.cpython-310.pyc +0 -0
- tenns_core/__pycache__/__init__.cpython-312.pyc +0 -0
- tenns_core/__pycache__/activations.cpython-310.pyc +0 -0
- tenns_core/__pycache__/activations.cpython-312.pyc +0 -0
- tenns_core/__pycache__/fft_ops.cpython-310.pyc +0 -0
- tenns_core/__pycache__/fft_ops.cpython-312.pyc +0 -0
- tenns_core/__pycache__/inference.cpython-310.pyc +0 -0
- tenns_core/__pycache__/inference.cpython-312.pyc +0 -0
- tenns_core/__pycache__/recurrent_ops.cpython-310.pyc +0 -0
- tenns_core/__pycache__/recurrent_ops.cpython-312.pyc +0 -0
- tenns_core/__pycache__/scan_ops.cpython-310.pyc +0 -0
- tenns_core/__pycache__/scan_ops.cpython-312.pyc +0 -0
- tenns_core/__pycache__/ssm.cpython-310.pyc +0 -0
- tenns_core/__pycache__/ssm.cpython-312.pyc +0 -0
- tenns_core/activations.py +158 -0
- tenns_core/fft_ops.py +174 -0
- tenns_core/inference.py +540 -0
- tenns_core/recurrent_ops.py +437 -0
- tenns_core/scan_ops.py +515 -0
- tenns_core/ssm.py +481 -0
- tokenizer.json +0 -0
- tokenizer_config.json +17 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-nc-4.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
+
tags:
|
| 7 |
+
- ssm
|
| 8 |
+
- causal-lm
|
| 9 |
+
- custom-architecture
|
| 10 |
+
- recurrent
|
| 11 |
+
pipeline_tag: text-generation
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# TENNs LLM 1B
|
| 15 |
+
|
| 16 |
+
A 1-billion-parameter causal language model built on gate-mode SSM (State Space Model) layers from [TENNs Core](https://huggingface.co/BrainChipInc/tenns-llm-1b/tree/main/tenns_core). Uses recurrent inference instead of attention, making it efficient for streaming and long-context generation.
|
| 17 |
+
|
| 18 |
+
## Architecture
|
| 19 |
+
|
| 20 |
+
| Component | Details |
|
| 21 |
+
|-----------|---------|
|
| 22 |
+
| Layers | 24 × TENNsBlock (gate mode) |
|
| 23 |
+
| Hidden dim | 2048 |
|
| 24 |
+
| Inner dim | 4096 |
|
| 25 |
+
| Vocabulary | 32,000 (Mistral-7B tokenizer) |
|
| 26 |
+
| Parameters | ~1B |
|
| 27 |
+
|
| 28 |
+
Each TENNsBlock: `RMSNorm → in_proj → causal_conv(4) → SSM(gate) → out_proj → residual`
|
| 29 |
+
|
| 30 |
+
## Quick Start (Google Colab / any environment)
|
| 31 |
+
|
| 32 |
+
```python
|
| 33 |
+
!pip install transformers torch einops opt_einsum safetensors
|
| 34 |
+
|
| 35 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 36 |
+
|
| 37 |
+
tokenizer = AutoTokenizer.from_pretrained("BrainChipInc/tenns-llm-1b")
|
| 38 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 39 |
+
"BrainChipInc/tenns-llm-1b",
|
| 40 |
+
trust_remote_code=True,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
output = model.generate_text("The history of artificial intelligence", tokenizer, max_new_tokens=100)
|
| 44 |
+
print(output)
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
> **Do not use `pipeline()`** — this model uses a custom recurrent architecture that is not
|
| 48 |
+
> compatible with HuggingFace's standard text-generation pipeline.
|
| 49 |
+
|
| 50 |
+
## Installation
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
pip install transformers torch einops opt_einsum safetensors
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Usage
|
| 57 |
+
|
| 58 |
+
> **Note:** Do **not** use `pipeline()` — this model requires `model.generate_text()` instead of
|
| 59 |
+
> HuggingFace's standard `generate()`. The recurrent SSM architecture is not compatible with the
|
| 60 |
+
> attention KV-cache pipeline.
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 64 |
+
|
| 65 |
+
tokenizer = AutoTokenizer.from_pretrained("BrainChipInc/tenns-llm-1b")
|
| 66 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
"BrainChipInc/tenns-llm-1b",
|
| 68 |
+
trust_remote_code=True,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
output = model.generate_text("The history of artificial intelligence", tokenizer, max_new_tokens=100)
|
| 72 |
+
print(output)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Generation options
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
# Greedy decoding (default)
|
| 79 |
+
output = model.generate_text(prompt, tokenizer, max_new_tokens=50)
|
| 80 |
+
|
| 81 |
+
# Top-k sampling with temperature
|
| 82 |
+
output = model.generate_text(prompt, tokenizer, max_new_tokens=100, temperature=0.8, top_k=50)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## `trust_remote_code=True`
|
| 86 |
+
|
| 87 |
+
This model uses custom modeling code bundled in the repository
|
| 88 |
+
(`modeling_tenns_llm.py`, `configuration_tenns_llm.py`, `tenns_core/`).
|
| 89 |
+
Loading requires `trust_remote_code=True`. The bundled `tenns_core/` package
|
| 90 |
+
is a snapshot of the TENNs Core SSM library — no separate installation needed.
|
| 91 |
+
|
| 92 |
+
## Training
|
| 93 |
+
|
| 94 |
+
Fine-tuned from a base TENNs gate-mode model using LoRA adapters on English instruction data.
|
| 95 |
+
LoRA adapters are merged into base weights at export time.
|
| 96 |
+
|
| 97 |
+
## Limitations
|
| 98 |
+
|
| 99 |
+
- English only
|
| 100 |
+
- No system prompt or chat template — plain completion model
|
| 101 |
+
- Recurrent state resets between calls to `generate_text()`
|
config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "tenns_llm",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoConfig": "configuration_tenns_llm.TennsLLMConfig",
|
| 5 |
+
"AutoModelForCausalLM": "modeling_tenns_llm.TennsLLMForCausalLM"
|
| 6 |
+
},
|
| 7 |
+
"vocab_size": 32000,
|
| 8 |
+
"channels": 2048,
|
| 9 |
+
"num_blocks": 24,
|
| 10 |
+
"num_coeffs": 16,
|
| 11 |
+
"repeat": 256,
|
| 12 |
+
"transformers_version": "4.40.0"
|
| 13 |
+
}
|
configuration_tenns_llm.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from transformers import PretrainedConfig
|
| 5 |
+
|
| 6 |
+
# Inject the repo directory into sys.path so the bundled tenns_core/ is
|
| 7 |
+
# importable without a pip install, both locally and when loaded from HF hub.
|
| 8 |
+
_HERE = os.path.dirname(os.path.abspath(__file__))
|
| 9 |
+
if _HERE not in sys.path:
|
| 10 |
+
sys.path.insert(0, _HERE)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TennsLLMConfig(PretrainedConfig):
|
| 14 |
+
model_type = "tenns_llm"
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
vocab_size=32000,
|
| 19 |
+
channels=2048,
|
| 20 |
+
num_blocks=24,
|
| 21 |
+
num_coeffs=16,
|
| 22 |
+
repeat=256,
|
| 23 |
+
**kwargs,
|
| 24 |
+
):
|
| 25 |
+
super().__init__(**kwargs)
|
| 26 |
+
self.vocab_size = vocab_size
|
| 27 |
+
self.channels = channels
|
| 28 |
+
self.num_blocks = num_blocks
|
| 29 |
+
self.num_coeffs = num_coeffs
|
| 30 |
+
self.repeat = repeat
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:695805667bf74d3bb24b8fc0c676e75c26c21191ebe91326429d4f61e43740ff
|
| 3 |
+
size 4957835584
|
modeling_tenns_llm.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn import RMSNorm
|
| 9 |
+
|
| 10 |
+
from transformers import PreTrainedModel
|
| 11 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 12 |
+
|
| 13 |
+
from configuration_tenns_llm import TennsLLMConfig
|
| 14 |
+
|
| 15 |
+
def _get_tenns_core_path():
|
| 16 |
+
"""Return a directory that contains tenns_core/.
|
| 17 |
+
|
| 18 |
+
HF's from_pretrained only downloads the .py files listed in auto_map —
|
| 19 |
+
it does not download subdirectories like tenns_core/. We use
|
| 20 |
+
snapshot_download (with local cache) to ensure tenns_core/ is present.
|
| 21 |
+
The first call downloads it; subsequent calls are instant cache hits.
|
| 22 |
+
"""
|
| 23 |
+
# Derive the repo_id from __file__ path in the HF modules cache:
|
| 24 |
+
# .../modules/transformers_modules/ORG/REPO_SLUG/HASH/modeling_tenns_llm.py
|
| 25 |
+
here = os.path.dirname(os.path.abspath(__file__))
|
| 26 |
+
parts = here.replace("\\", "/").split("/")
|
| 27 |
+
try:
|
| 28 |
+
idx = next(i for i, p in enumerate(parts) if p == "transformers_modules")
|
| 29 |
+
org_id = parts[idx + 1].replace("_hyphen_", "-")
|
| 30 |
+
repo_id = parts[idx + 2].replace("_hyphen_", "-")
|
| 31 |
+
except (StopIteration, IndexError):
|
| 32 |
+
return here # not in HF cache — assume tenns_core/ is next to this file
|
| 33 |
+
|
| 34 |
+
from huggingface_hub import snapshot_download
|
| 35 |
+
snapshot = snapshot_download(
|
| 36 |
+
f"{org_id}/{repo_id}",
|
| 37 |
+
allow_patterns=["tenns_core/**"],
|
| 38 |
+
)
|
| 39 |
+
return snapshot
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
_tenns_core_dir = _get_tenns_core_path()
|
| 43 |
+
if _tenns_core_dir not in sys.path:
|
| 44 |
+
sys.path.insert(0, _tenns_core_dir)
|
| 45 |
+
|
| 46 |
+
_tc = importlib.import_module("tenns_core")
|
| 47 |
+
_rc = importlib.import_module("tenns_core.recurrent_ops")
|
| 48 |
+
SSMLayer = _tc.SSMLayer
|
| 49 |
+
recurrent_gate = _rc.recurrent_gate
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ============================================================================
|
| 53 |
+
# Model Components (from tenns_llm.py)
|
| 54 |
+
# ============================================================================
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CausalConvDwFast(nn.Module):
|
| 58 |
+
"""Holds depthwise causal convolution weights for TENNs blocks."""
|
| 59 |
+
def __init__(self, coeffs, kernel_size):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.weight = nn.Parameter(torch.rand(kernel_size, coeffs))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PassthroughConv(nn.Module):
|
| 65 |
+
"""Applies causal convolution via FIFO buffer for streaming inference."""
|
| 66 |
+
def __init__(self, causal_conv, d_inner):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.causal_conv = causal_conv
|
| 69 |
+
self.d_inner = d_inner
|
| 70 |
+
self.fifo = None
|
| 71 |
+
|
| 72 |
+
def apply_conv(self, x):
|
| 73 |
+
"""Apply causal convolution. x: (B, T, C) -> (B, T, C)"""
|
| 74 |
+
B, T, C = x.shape
|
| 75 |
+
|
| 76 |
+
if self.fifo is None or self.fifo.shape[0] != B:
|
| 77 |
+
self.fifo = torch.zeros(B, C, 4, device=x.device, dtype=x.dtype)
|
| 78 |
+
|
| 79 |
+
conv_weight = self.causal_conv.weight.squeeze().T # (C, 4)
|
| 80 |
+
|
| 81 |
+
x_conv = []
|
| 82 |
+
for t in range(T):
|
| 83 |
+
self.fifo = self.fifo.roll(-1, dims=-1)
|
| 84 |
+
self.fifo[:, :, -1] = x[:, t, :]
|
| 85 |
+
x_t = (self.fifo * conv_weight).sum(-1)
|
| 86 |
+
x_conv.append(x_t)
|
| 87 |
+
|
| 88 |
+
x_conv = torch.stack(x_conv, dim=1)
|
| 89 |
+
x_conv = F.silu(x_conv)
|
| 90 |
+
return x_conv
|
| 91 |
+
|
| 92 |
+
def reset_states(self):
|
| 93 |
+
if self.fifo is not None:
|
| 94 |
+
self.fifo.zero_()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TENNsBlock(nn.Module):
|
| 98 |
+
"""TENNs block with gate-mode SSM for LLM inference."""
|
| 99 |
+
def __init__(self, channels, num_coeffs, repeat, mode='gate'):
|
| 100 |
+
super().__init__()
|
| 101 |
+
d_inner = channels * 2
|
| 102 |
+
self.d_inner = d_inner
|
| 103 |
+
|
| 104 |
+
self.pre_norm = RMSNorm(channels, elementwise_affine=True)
|
| 105 |
+
self.pre_conv = CausalConvDwFast(d_inner, 4)
|
| 106 |
+
self.in_proj = nn.Linear(channels, d_inner * 2, bias=True)
|
| 107 |
+
self.out_proj = nn.Linear(d_inner, channels, bias=True)
|
| 108 |
+
|
| 109 |
+
self.ssm_layer = SSMLayer(num_coeffs, d_inner, d_inner,
|
| 110 |
+
repeat=repeat, mode=mode, transposed=True)
|
| 111 |
+
|
| 112 |
+
self.ssm_layer.register_buffer('state_lora', torch.zeros(d_inner))
|
| 113 |
+
|
| 114 |
+
self.D = nn.Parameter(torch.ones(d_inner, dtype=torch.float))
|
| 115 |
+
|
| 116 |
+
self._conv_handler = None
|
| 117 |
+
self.state = None
|
| 118 |
+
|
| 119 |
+
def forward(self, input):
|
| 120 |
+
x = self.pre_norm(input)
|
| 121 |
+
x_and_res = self.in_proj(x)
|
| 122 |
+
x, res = x_and_res.split([self.d_inner, self.d_inner], -1)
|
| 123 |
+
|
| 124 |
+
if self._conv_handler is None:
|
| 125 |
+
self._conv_handler = PassthroughConv(self.pre_conv, self.d_inner)
|
| 126 |
+
|
| 127 |
+
x_conv = self._conv_handler.apply_conv(x)
|
| 128 |
+
|
| 129 |
+
state = self.state
|
| 130 |
+
if state is None:
|
| 131 |
+
state = self.ssm_layer.state_lora
|
| 132 |
+
|
| 133 |
+
y, self.state = recurrent_gate(
|
| 134 |
+
x_conv,
|
| 135 |
+
self.ssm_layer.A,
|
| 136 |
+
self.ssm_layer.B,
|
| 137 |
+
self.ssm_layer.C,
|
| 138 |
+
self.ssm_layer.log_dt,
|
| 139 |
+
state
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
y = y.transpose(1, 2)
|
| 143 |
+
y = y + self.D * x_conv
|
| 144 |
+
output = self.out_proj(y * F.silu(res))
|
| 145 |
+
|
| 146 |
+
return input + output
|
| 147 |
+
|
| 148 |
+
def reset_states(self):
|
| 149 |
+
if self._conv_handler is not None:
|
| 150 |
+
self._conv_handler.reset_states()
|
| 151 |
+
self.state = None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class TENNsLLM(nn.Module):
|
| 155 |
+
"""TENNs-based language model for autoregressive text generation."""
|
| 156 |
+
def __init__(self, vocab_size=32000, channels=2048, num_blocks=24,
|
| 157 |
+
num_coeffs=16, repeat=256):
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.channels = channels
|
| 160 |
+
self.embedding = nn.Embedding(vocab_size, channels)
|
| 161 |
+
self.backbone = nn.Sequential(
|
| 162 |
+
*[TENNsBlock(channels, num_coeffs, repeat, mode='gate')
|
| 163 |
+
for _ in range(num_blocks)]
|
| 164 |
+
)
|
| 165 |
+
self.head = nn.Sequential(
|
| 166 |
+
RMSNorm(channels, elementwise_affine=False),
|
| 167 |
+
nn.Linear(channels, vocab_size, bias=False),
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
def forward(self, tokens):
|
| 171 |
+
x = self.embedding(tokens)
|
| 172 |
+
x = self.backbone(x)
|
| 173 |
+
return self.head(x)
|
| 174 |
+
|
| 175 |
+
def reset_states(self):
|
| 176 |
+
for module in self.modules():
|
| 177 |
+
if isinstance(module, TENNsBlock):
|
| 178 |
+
module.reset_states()
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ============================================================================
|
| 182 |
+
# HuggingFace wrapper
|
| 183 |
+
# ============================================================================
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class TennsLLMForCausalLM(PreTrainedModel):
|
| 187 |
+
"""HuggingFace PreTrainedModel wrapper for TENNsLLM.
|
| 188 |
+
|
| 189 |
+
Load with:
|
| 190 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 191 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 192 |
+
"aliborji/tenns-llm-1b", trust_remote_code=True
|
| 193 |
+
)
|
| 194 |
+
tokenizer = AutoTokenizer.from_pretrained("aliborji/tenns-llm-1b")
|
| 195 |
+
|
| 196 |
+
Generate with:
|
| 197 |
+
output = model.generate_text("Hello, world!", tokenizer, max_new_tokens=50)
|
| 198 |
+
print(output)
|
| 199 |
+
|
| 200 |
+
Note: This model uses recurrent SSM states. Use generate_text() rather than
|
| 201 |
+
model.generate(), which is designed for attention-based KV-cache models.
|
| 202 |
+
"""
|
| 203 |
+
config_class = TennsLLMConfig
|
| 204 |
+
# Weights are saved without a 'model.' prefix — flatten components directly
|
| 205 |
+
# onto this class so state dict keys match the safetensors file exactly.
|
| 206 |
+
_tied_weights_keys = []
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def all_tied_weights_keys(self):
|
| 210 |
+
return {}
|
| 211 |
+
|
| 212 |
+
def __init__(self, config: TennsLLMConfig):
|
| 213 |
+
super().__init__(config)
|
| 214 |
+
# Assign TENNsLLM components directly (not as self.model) so that
|
| 215 |
+
# state dict keys match the safetensors: embedding.weight, backbone.0...
|
| 216 |
+
_backbone = TENNsLLM(
|
| 217 |
+
vocab_size=config.vocab_size,
|
| 218 |
+
channels=config.channels,
|
| 219 |
+
num_blocks=config.num_blocks,
|
| 220 |
+
num_coeffs=config.num_coeffs,
|
| 221 |
+
repeat=config.repeat,
|
| 222 |
+
)
|
| 223 |
+
self.embedding = _backbone.embedding
|
| 224 |
+
self.backbone = _backbone.backbone
|
| 225 |
+
self.head = _backbone.head
|
| 226 |
+
|
| 227 |
+
def _reset_states(self):
|
| 228 |
+
for module in self.modules():
|
| 229 |
+
if isinstance(module, TENNsBlock):
|
| 230 |
+
module.reset_states()
|
| 231 |
+
|
| 232 |
+
def forward(self, input_ids, **kwargs):
|
| 233 |
+
x = self.embedding(input_ids)
|
| 234 |
+
x = self.backbone(x)
|
| 235 |
+
logits = self.head(x)
|
| 236 |
+
return CausalLMOutputWithPast(logits=logits)
|
| 237 |
+
|
| 238 |
+
@torch.no_grad()
|
| 239 |
+
def generate_text(self, prompt, tokenizer, max_new_tokens=50,
|
| 240 |
+
temperature=1.0, top_k=None):
|
| 241 |
+
"""Autoregressive text generation.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
prompt: Input text string
|
| 245 |
+
tokenizer: HuggingFace tokenizer
|
| 246 |
+
max_new_tokens: Maximum number of tokens to generate
|
| 247 |
+
temperature: Sampling temperature (lower = more deterministic)
|
| 248 |
+
top_k: If set, sample from top-k tokens; otherwise greedy argmax
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Generated text string (not including the prompt)
|
| 252 |
+
"""
|
| 253 |
+
self.eval()
|
| 254 |
+
self._reset_states()
|
| 255 |
+
|
| 256 |
+
input_ids = tokenizer(prompt, return_tensors='pt',
|
| 257 |
+
add_special_tokens=False)['input_ids'].squeeze()
|
| 258 |
+
input_ids = input_ids.to(self.device)
|
| 259 |
+
|
| 260 |
+
# Ingest prompt tokens
|
| 261 |
+
for token in input_ids:
|
| 262 |
+
logits = self.forward(token.view(1, 1)).logits
|
| 263 |
+
probs = F.softmax(logits[0, -1], dim=-1)
|
| 264 |
+
next_token = torch.argmax(probs).item()
|
| 265 |
+
|
| 266 |
+
# Autoregressive generation
|
| 267 |
+
output_ids = []
|
| 268 |
+
token = next_token
|
| 269 |
+
for _ in range(max_new_tokens):
|
| 270 |
+
logits = self.forward(torch.tensor([[token]], device=self.device)).logits
|
| 271 |
+
next_logits = logits[0, -1]
|
| 272 |
+
|
| 273 |
+
if temperature != 1.0:
|
| 274 |
+
next_logits = next_logits / temperature
|
| 275 |
+
|
| 276 |
+
if top_k is not None:
|
| 277 |
+
v, _ = torch.topk(next_logits, top_k)
|
| 278 |
+
next_logits[next_logits < v[-1]] = float('-inf')
|
| 279 |
+
|
| 280 |
+
probs = F.softmax(next_logits, dim=-1)
|
| 281 |
+
token = (torch.multinomial(probs, 1).item() if top_k is not None
|
| 282 |
+
else torch.argmax(probs).item())
|
| 283 |
+
|
| 284 |
+
if token == tokenizer.eos_token_id:
|
| 285 |
+
break
|
| 286 |
+
|
| 287 |
+
output_ids.append(token)
|
| 288 |
+
|
| 289 |
+
return tokenizer.decode(output_ids)
|
tenns_core/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TENNs Core: Efficient State Space Models for Sequence Modeling
|
| 3 |
+
|
| 4 |
+
A standalone library providing various SSM (State Space Model) architectures
|
| 5 |
+
for deep learning on sequences. Includes S5, DWS, Neck, Full, and Gate modes
|
| 6 |
+
all implemented in pure PyTorch.
|
| 7 |
+
|
| 8 |
+
Quick Start - Training:
|
| 9 |
+
----------------------
|
| 10 |
+
>>> from tenns_core import SSMLayer
|
| 11 |
+
>>> import torch
|
| 12 |
+
>>>
|
| 13 |
+
>>> # Create S5-mode SSM layer
|
| 14 |
+
>>> layer = SSMLayer(
|
| 15 |
+
... num_coeffs=64,
|
| 16 |
+
... in_channels=128,
|
| 17 |
+
... out_channels=256,
|
| 18 |
+
... mode='s5',
|
| 19 |
+
... norm='layer',
|
| 20 |
+
... postact='gelu'
|
| 21 |
+
... )
|
| 22 |
+
>>>
|
| 23 |
+
>>> # Forward pass (training mode - FFT convolution)
|
| 24 |
+
>>> x = torch.randn(4, 128, 512) # (batch, channels, length)
|
| 25 |
+
>>> y = layer(x) # (4, 256, 512)
|
| 26 |
+
|
| 27 |
+
Quick Start - Streaming Inference:
|
| 28 |
+
----------------------------------
|
| 29 |
+
>>> # Convert trained model to streaming inference
|
| 30 |
+
>>> infer_layer = layer.to_inference()
|
| 31 |
+
>>>
|
| 32 |
+
>>> # Process audio stream chunk-by-chunk
|
| 33 |
+
>>> for chunk in audio_stream:
|
| 34 |
+
>>> output = infer_layer(chunk) # State maintained automatically
|
| 35 |
+
>>>
|
| 36 |
+
>>> # Reset state between utterances
|
| 37 |
+
>>> infer_layer.reset_state()
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
from importlib.metadata import PackageNotFoundError, version
|
| 41 |
+
|
| 42 |
+
from .inference import SSMLayerInference
|
| 43 |
+
from .ssm import Kernelizer, SSMLayer
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
__version__ = version('tenns-core')
|
| 47 |
+
except PackageNotFoundError:
|
| 48 |
+
__version__ = '0.0.0+unknown'
|
| 49 |
+
|
| 50 |
+
__all__ = ['Kernelizer', 'SSMLayer', 'SSMLayerInference']
|
tenns_core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
tenns_core/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.65 kB). View file
|
|
|
tenns_core/__pycache__/activations.cpython-310.pyc
ADDED
|
Binary file (4.51 kB). View file
|
|
|
tenns_core/__pycache__/activations.cpython-312.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
tenns_core/__pycache__/fft_ops.cpython-310.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
tenns_core/__pycache__/fft_ops.cpython-312.pyc
ADDED
|
Binary file (8.45 kB). View file
|
|
|
tenns_core/__pycache__/inference.cpython-310.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
tenns_core/__pycache__/inference.cpython-312.pyc
ADDED
|
Binary file (24.4 kB). View file
|
|
|
tenns_core/__pycache__/recurrent_ops.cpython-310.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
tenns_core/__pycache__/recurrent_ops.cpython-312.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
tenns_core/__pycache__/scan_ops.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
tenns_core/__pycache__/scan_ops.cpython-312.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
tenns_core/__pycache__/ssm.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
tenns_core/__pycache__/ssm.cpython-312.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
tenns_core/activations.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Activation, normalization, and dropout utilities for SSM layers.
|
| 3 |
+
|
| 4 |
+
Extracted from tenns.models.utils to provide activation layer construction.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import RMSNorm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LayerNormFeature(nn.LayerNorm):
|
| 12 |
+
"""LayerNorm that operates on the feature dimension (dim=-2) instead of time (dim=-1)."""
|
| 13 |
+
|
| 14 |
+
def forward(self, input):
|
| 15 |
+
return super().forward(input.moveaxis(-1, -2)).moveaxis(-1, -2)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RmsNormFeature(nn.Module):
|
| 19 |
+
"""RMSNorm that operates on the feature dimension (dim=-2) instead of time (dim=-1)."""
|
| 20 |
+
|
| 21 |
+
def __init__(self, features):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.rms_norm = RMSNorm(features)
|
| 24 |
+
|
| 25 |
+
def forward(self, input):
|
| 26 |
+
return self.rms_norm(input.moveaxis(-1, -2)).moveaxis(-1, -2)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_norm(norm, num_features, ndim=2):
|
| 30 |
+
"""Get normalization layer by name.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
norm: Normalization type ('batch', 'layer', 'layer-feature', 'rms', None)
|
| 34 |
+
num_features: Number of features/channels
|
| 35 |
+
ndim: Number of dimensions (1, 2, or 3)
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Normalization layer module
|
| 39 |
+
"""
|
| 40 |
+
match norm:
|
| 41 |
+
case 'batch':
|
| 42 |
+
match ndim:
|
| 43 |
+
case 1:
|
| 44 |
+
return nn.BatchNorm1d(num_features)
|
| 45 |
+
case 2:
|
| 46 |
+
return nn.BatchNorm2d(num_features)
|
| 47 |
+
case 3:
|
| 48 |
+
return nn.BatchNorm3d(num_features)
|
| 49 |
+
case _:
|
| 50 |
+
raise ValueError(f'Invalid dimensions: {ndim}')
|
| 51 |
+
|
| 52 |
+
case 'layer':
|
| 53 |
+
return nn.LayerNorm(num_features)
|
| 54 |
+
|
| 55 |
+
case 'layer-feature':
|
| 56 |
+
if num_features > 1:
|
| 57 |
+
return LayerNormFeature(num_features)
|
| 58 |
+
else:
|
| 59 |
+
return nn.Identity()
|
| 60 |
+
|
| 61 |
+
case 'rms':
|
| 62 |
+
if num_features > 1:
|
| 63 |
+
return RmsNormFeature(num_features)
|
| 64 |
+
else:
|
| 65 |
+
return nn.Identity()
|
| 66 |
+
|
| 67 |
+
case None:
|
| 68 |
+
return nn.Identity()
|
| 69 |
+
|
| 70 |
+
case _:
|
| 71 |
+
raise ValueError(f'Invalid normalization type: {norm}')
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_postact(postact):
|
| 75 |
+
"""Get activation function by name.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
postact: Activation type ('relu', 'gelu', 'silu', etc., or None)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Activation function module
|
| 82 |
+
"""
|
| 83 |
+
if postact is None:
|
| 84 |
+
return nn.Identity()
|
| 85 |
+
|
| 86 |
+
postact_registry = {
|
| 87 |
+
'relu': nn.ReLU(),
|
| 88 |
+
'relu6': nn.ReLU6(),
|
| 89 |
+
'lelu': nn.LeakyReLU(0.1),
|
| 90 |
+
'sigmoid': nn.Sigmoid(),
|
| 91 |
+
'tanh': nn.Tanh(),
|
| 92 |
+
'gelu': nn.GELU(),
|
| 93 |
+
'glu': nn.GLU(dim=1),
|
| 94 |
+
'silu': nn.SiLU(),
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
if postact in postact_registry:
|
| 98 |
+
return postact_registry[postact]
|
| 99 |
+
else:
|
| 100 |
+
raise ValueError(f'Invalid activation name: {postact}')
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_dropout(p, dropout_dim, num_features):
|
| 104 |
+
"""Get dropout layer by dimension.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
p: Dropout probability (None for no dropout)
|
| 108 |
+
dropout_dim: Dimension of dropout (0 for standard, 1 for 1d, etc.)
|
| 109 |
+
num_features: Number of features (used to determine if dropout should be applied)
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Dropout module
|
| 113 |
+
"""
|
| 114 |
+
if p is None:
|
| 115 |
+
return nn.Identity()
|
| 116 |
+
|
| 117 |
+
dropout_registry = {
|
| 118 |
+
0: nn.Dropout,
|
| 119 |
+
1: nn.Dropout1d,
|
| 120 |
+
2: nn.Dropout2d,
|
| 121 |
+
3: nn.Dropout3d,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if dropout_dim in dropout_registry:
|
| 125 |
+
# Only apply dropout if we have enough features
|
| 126 |
+
if dropout_dim == 0 or num_features >= 16:
|
| 127 |
+
return dropout_registry[dropout_dim](p)
|
| 128 |
+
else:
|
| 129 |
+
return nn.Identity()
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f'Invalid dropout dimension: {dropout_dim}')
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_activations(ndim, num_features, norm=None, postact=None, p=None, dropout_dim=0):
|
| 135 |
+
"""Build a sequential module with normalization, activation, and dropout.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
ndim: Number of dimensions (1, 2, or 3)
|
| 139 |
+
num_features: Number of features/channels
|
| 140 |
+
norm: Normalization type (None, 'batch', 'layer', 'layer-feature', 'rms')
|
| 141 |
+
postact: Activation function type (None, 'relu', 'gelu', 'silu', etc.)
|
| 142 |
+
p: Dropout probability (None for no dropout)
|
| 143 |
+
dropout_dim: Dimension of dropout (0, 1, 2, or 3)
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Sequential module combining norm, activation, and dropout
|
| 147 |
+
"""
|
| 148 |
+
if (norm is None) and (postact is None) and (p is None):
|
| 149 |
+
return nn.Identity()
|
| 150 |
+
|
| 151 |
+
activations = nn.Sequential()
|
| 152 |
+
if norm is not None:
|
| 153 |
+
activations.append(get_norm(norm, num_features, ndim))
|
| 154 |
+
if postact is not None:
|
| 155 |
+
activations.append(get_postact(postact))
|
| 156 |
+
if p is not None:
|
| 157 |
+
activations.append(get_dropout(p, dropout_dim, num_features))
|
| 158 |
+
return activations
|
tenns_core/fft_ops.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FFT-based convolution operations for SSM layers.
|
| 3 |
+
|
| 4 |
+
This module provides optimized FFT convolution operations used in SSM training,
|
| 5 |
+
combining functionality from fft_utils.py and fft_utils_opt.py.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.amp import custom_bwd, custom_fwd
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PaddedFFTConv(torch.autograd.Function):
|
| 13 |
+
"""Custom autograd function for padded FFT convolution with efficient gradients.
|
| 14 |
+
|
| 15 |
+
Supports both depthwise ('dw') and full ('full') convolution modes.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
@staticmethod
|
| 19 |
+
@torch.compiler.disable
|
| 20 |
+
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
| 21 |
+
def forward(ctx, u, k, n, mode, is_complex=False):
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
u: Input tensor
|
| 25 |
+
k: Kernel tensor
|
| 26 |
+
n: Sequence length
|
| 27 |
+
mode: 'dw' for depthwise or 'full' for full convolution
|
| 28 |
+
is_complex: Whether to use complex FFT
|
| 29 |
+
"""
|
| 30 |
+
if is_complex:
|
| 31 |
+
uf = torch.fft.fft(u, 2 * n)
|
| 32 |
+
kf = torch.fft.fft(k, 2 * n)
|
| 33 |
+
else:
|
| 34 |
+
uf = torch.fft.rfft(u, 2 * n)
|
| 35 |
+
kf = torch.fft.rfft(k, 2 * n)
|
| 36 |
+
|
| 37 |
+
if mode == 'dw':
|
| 38 |
+
yf = uf * kf
|
| 39 |
+
elif mode == 'full':
|
| 40 |
+
yf = torch.einsum('bcl,dcl->bdl', uf, kf)
|
| 41 |
+
|
| 42 |
+
ctx.is_complex = is_complex
|
| 43 |
+
ctx.mode = mode
|
| 44 |
+
ctx.n = n
|
| 45 |
+
ctx.save_for_backward(u, k)
|
| 46 |
+
|
| 47 |
+
if is_complex:
|
| 48 |
+
return torch.fft.ifft(yf)[..., :n]
|
| 49 |
+
else:
|
| 50 |
+
return torch.fft.irfft(yf)[..., :n]
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
@torch.compiler.disable
|
| 54 |
+
@custom_bwd(device_type='cuda')
|
| 55 |
+
def backward(ctx, grad_output):
|
| 56 |
+
is_complex = ctx.is_complex
|
| 57 |
+
mode = ctx.mode
|
| 58 |
+
n = ctx.n
|
| 59 |
+
u, k = ctx.saved_tensors
|
| 60 |
+
|
| 61 |
+
if is_complex:
|
| 62 |
+
uf = torch.fft.fft(u, 2 * n)
|
| 63 |
+
kf = torch.fft.fft(k, 2 * n)
|
| 64 |
+
grad_yf = torch.fft.fft(grad_output, 2 * n)
|
| 65 |
+
else:
|
| 66 |
+
uf = torch.fft.rfft(u, 2 * n)
|
| 67 |
+
kf = torch.fft.rfft(k, 2 * n)
|
| 68 |
+
grad_yf = torch.fft.rfft(grad_output, 2 * n)
|
| 69 |
+
|
| 70 |
+
if mode == 'dw':
|
| 71 |
+
grad_uf = grad_yf * torch.conj(kf)
|
| 72 |
+
elif mode == 'full':
|
| 73 |
+
grad_uf = torch.einsum('bdl,dcl->bcl', grad_yf, torch.conj(kf))
|
| 74 |
+
|
| 75 |
+
if is_complex:
|
| 76 |
+
grad_u = torch.fft.ifft(grad_uf, 2 * n)[..., :n]
|
| 77 |
+
else:
|
| 78 |
+
grad_u = torch.fft.irfft(grad_uf, 2 * n)[..., :n]
|
| 79 |
+
|
| 80 |
+
if mode == 'dw':
|
| 81 |
+
grad_kf = torch.einsum('bnl,bnl->nl', grad_yf, torch.conj(uf))
|
| 82 |
+
elif mode == 'full':
|
| 83 |
+
grad_kf = torch.einsum('bdl,bcl->dcl', grad_yf, torch.conj(uf))
|
| 84 |
+
|
| 85 |
+
if is_complex:
|
| 86 |
+
grad_k = torch.fft.ifft(grad_kf, 2 * n)[..., :n]
|
| 87 |
+
else:
|
| 88 |
+
grad_k = torch.fft.irfft(grad_kf, 2 * n)[..., :n]
|
| 89 |
+
|
| 90 |
+
return grad_u, grad_k, None, None, None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _K(dtA_real, dtA_imag, length, weight=None, dim=-2, complex_proj=False, l_shift=0):
|
| 94 |
+
"""Generate SSM convolution kernel from discretized state matrix.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
dtA_real: Real part of discretized state matrix diagonal
|
| 98 |
+
dtA_imag: Imaginary part of discretized state matrix diagonal
|
| 99 |
+
length: Sequence length
|
| 100 |
+
weight: Optional weight matrix to apply
|
| 101 |
+
dim: Dimension to reduce over if weight is provided
|
| 102 |
+
complex_proj: Whether to use complex projection
|
| 103 |
+
l_shift: Shift amount for the range
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
SSM convolution kernel of shape (..., length)
|
| 107 |
+
"""
|
| 108 |
+
device = dtA_real.device
|
| 109 |
+
lrange = torch.arange(l_shift, length + l_shift, device=device)
|
| 110 |
+
|
| 111 |
+
with torch.autocast('cuda', enabled=False):
|
| 112 |
+
dtA_real, dtA_imag = dtA_real.float(), dtA_imag.float()
|
| 113 |
+
if complex_proj:
|
| 114 |
+
K = (torch.complex(dtA_real, dtA_imag)[..., None] * lrange).exp()
|
| 115 |
+
else:
|
| 116 |
+
K = (dtA_real[..., None] * lrange).exp() * torch.cos(dtA_imag[..., None] * lrange)
|
| 117 |
+
|
| 118 |
+
if weight is not None:
|
| 119 |
+
return (K * weight[..., None]).sum(dim)
|
| 120 |
+
else:
|
| 121 |
+
return K
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _full_k(dtA_real, dtA_imag, B, C, E, length):
|
| 125 |
+
"""Generate full SSM kernel by combining B, C, and state kernel.
|
| 126 |
+
|
| 127 |
+
Used for optimizing s5/neck mode when full kernel is more efficient.
|
| 128 |
+
"""
|
| 129 |
+
K = _K(dtA_real, dtA_imag, length, weight=E)
|
| 130 |
+
return (B[..., None] * C[..., None, None] * K[:, None, :]).sum(1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def padded_fft_conv_opt(input, dtA_real, dtA_imag, B, C, E):
|
| 134 |
+
"""Optimized padded FFT convolution for SSM layers.
|
| 135 |
+
|
| 136 |
+
Automatically chooses between naive and optimized contraction based on
|
| 137 |
+
tensor shapes to minimize computation.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
input: Input tensor of shape (batch, in_channels, length)
|
| 141 |
+
dtA_real: Real part of discretized A matrix
|
| 142 |
+
dtA_imag: Imaginary part of discretized A matrix
|
| 143 |
+
B: Input projection matrix (None for dws/full modes)
|
| 144 |
+
C: Output projection matrix (None for dws/full modes)
|
| 145 |
+
E: State projection matrix (None for s5/neck modes)
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Output tensor of shape (batch, out_channels, length)
|
| 149 |
+
"""
|
| 150 |
+
batch, chin, length = input.shape
|
| 151 |
+
|
| 152 |
+
# DWS/Full mode: no B/C matrices
|
| 153 |
+
if B is None:
|
| 154 |
+
K = _K(dtA_real, dtA_imag, length, weight=E)
|
| 155 |
+
if K.ndim == 3:
|
| 156 |
+
return PaddedFFTConv.apply(input, K, length, 'full', False)
|
| 157 |
+
elif K.ndim == 2:
|
| 158 |
+
return PaddedFFTConv.apply(input, K, length, 'dw', False)
|
| 159 |
+
|
| 160 |
+
# S5/Neck mode: has B/C matrices
|
| 161 |
+
chout, coeffs = C.shape
|
| 162 |
+
|
| 163 |
+
# Choose contraction order based on efficiency
|
| 164 |
+
# Compare cost of: (1) fusing B,C,K vs (2) separate contractions
|
| 165 |
+
if (1 / chin + 1 / chout) > (1 / batch + 1 / coeffs):
|
| 166 |
+
# Fuse full kernel and apply single convolution
|
| 167 |
+
kernel = _full_k(dtA_real, dtA_imag, B, C, E, length)
|
| 168 |
+
return PaddedFFTConv.apply(input, kernel, length, 'full', False)
|
| 169 |
+
else:
|
| 170 |
+
# Separate: project input, convolve, then project output
|
| 171 |
+
K = _K(dtA_real, dtA_imag, length, weight=E)
|
| 172 |
+
x = torch.einsum('bcl,nc->bnl', input, B)
|
| 173 |
+
x = PaddedFFTConv.apply(x, K, length, 'dw', False)
|
| 174 |
+
return torch.einsum('bnl,dn->bdl', x, C)
|
tenns_core/inference.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference mode for SSM layers.
|
| 3 |
+
|
| 4 |
+
Provides streaming/online inference with stateful processing for real-time applications.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from .recurrent_ops import (
|
| 11 |
+
discretize_dws,
|
| 12 |
+
discretize_full,
|
| 13 |
+
discretize_neck,
|
| 14 |
+
discretize_s5,
|
| 15 |
+
recurrent_gate,
|
| 16 |
+
recurrent_gate_single_step,
|
| 17 |
+
step_dws,
|
| 18 |
+
step_full,
|
| 19 |
+
step_neck,
|
| 20 |
+
step_s5,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SSMLayerInference(nn.Module):
|
| 25 |
+
"""Streaming inference wrapper for SSMLayer.
|
| 26 |
+
|
| 27 |
+
Provides stateful recurrent inference for real-time applications.
|
| 28 |
+
Maintains internal state across chunks for low-latency streaming.
|
| 29 |
+
|
| 30 |
+
Discretization (Ad, B_hat, etc.) is precomputed once at construction time
|
| 31 |
+
from the raw SSM parameters, so only the per-timestep step function runs
|
| 32 |
+
during forward passes.
|
| 33 |
+
|
| 34 |
+
Example:
|
| 35 |
+
>>> # After training
|
| 36 |
+
>>> train_layer = SSMLayer(64, 128, 256, mode='s5')
|
| 37 |
+
>>> # ... training ...
|
| 38 |
+
>>>
|
| 39 |
+
>>> # Convert to inference mode
|
| 40 |
+
>>> infer_layer = SSMLayerInference.from_training(train_layer)
|
| 41 |
+
>>>
|
| 42 |
+
>>> # Process streaming chunks (state maintained automatically)
|
| 43 |
+
>>> for chunk in audio_stream:
|
| 44 |
+
>>> output = infer_layer(chunk)
|
| 45 |
+
>>>
|
| 46 |
+
>>> # Reset state when starting new utterance
|
| 47 |
+
>>> infer_layer.reset_state()
|
| 48 |
+
|
| 49 |
+
Note:
|
| 50 |
+
- Inference mode uses sequential scan (O(T) per chunk)
|
| 51 |
+
- Training mode uses FFT (O(T log T) for full sequence)
|
| 52 |
+
- For streaming, inference mode has lower latency
|
| 53 |
+
- For batch processing full sequences, training mode is faster
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, mode, in_channels, out_channels, **kwargs):
|
| 57 |
+
"""Initialize inference layer.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
|
| 61 |
+
in_channels: Number of input channels
|
| 62 |
+
out_channels: Number of output channels
|
| 63 |
+
**kwargs: Mode-specific parameters (Ad, B_hat, C, dt, B, E, A, log_dt, mixer, etc.)
|
| 64 |
+
"""
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.mode = mode
|
| 67 |
+
self.in_channels = in_channels
|
| 68 |
+
self.out_channels = out_channels
|
| 69 |
+
|
| 70 |
+
if mode == 's5':
|
| 71 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 72 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 73 |
+
self.register_buffer('C', kwargs['C'])
|
| 74 |
+
elif mode == 'dws':
|
| 75 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 76 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 77 |
+
elif mode == 'neck':
|
| 78 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 79 |
+
self.register_buffer('dt', kwargs['dt'])
|
| 80 |
+
self.register_buffer('B', kwargs['B'])
|
| 81 |
+
self.register_buffer('C', kwargs['C'])
|
| 82 |
+
self.register_buffer('E', kwargs['E'])
|
| 83 |
+
elif mode == 'full':
|
| 84 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 85 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 86 |
+
elif mode == 'gate':
|
| 87 |
+
# Gate mode: input-dependent discretization, store raw params
|
| 88 |
+
self.register_buffer('A', kwargs['A'])
|
| 89 |
+
self.B = kwargs['B'] # nn.Module
|
| 90 |
+
self.C = kwargs['C'] # nn.Module
|
| 91 |
+
self.log_dt = kwargs['log_dt'] # nn.Module
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f'Unknown mode: {mode}')
|
| 94 |
+
|
| 95 |
+
# Mixer module (for DWS mode to project channels)
|
| 96 |
+
self.mixer = kwargs.get('mixer') or nn.Identity()
|
| 97 |
+
|
| 98 |
+
# Internal state
|
| 99 |
+
self.state = None
|
| 100 |
+
|
| 101 |
+
@classmethod
|
| 102 |
+
def from_training(cls, ssm_layer):
|
| 103 |
+
"""Create inference layer from trained SSMLayer.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
ssm_layer: Trained SSMLayer instance
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
SSMLayerInference instance with precomputed discretized weights
|
| 110 |
+
|
| 111 |
+
Example:
|
| 112 |
+
>>> train_layer = SSMLayer(64, 128, 256, mode='s5')
|
| 113 |
+
>>> infer_layer = SSMLayerInference.from_training(train_layer)
|
| 114 |
+
"""
|
| 115 |
+
mode = ssm_layer.mode
|
| 116 |
+
kwargs = {}
|
| 117 |
+
|
| 118 |
+
if mode == 's5':
|
| 119 |
+
Ad, B_hat = discretize_s5(
|
| 120 |
+
ssm_layer.A.detach().clone(),
|
| 121 |
+
ssm_layer.B.detach().clone(),
|
| 122 |
+
ssm_layer.log_dt.detach().clone(),
|
| 123 |
+
)
|
| 124 |
+
kwargs['Ad'] = Ad
|
| 125 |
+
kwargs['B_hat'] = B_hat
|
| 126 |
+
kwargs['C'] = ssm_layer.C.detach().clone()
|
| 127 |
+
|
| 128 |
+
elif mode == 'dws':
|
| 129 |
+
Ad, B_hat = discretize_dws(
|
| 130 |
+
ssm_layer.A.detach().clone(),
|
| 131 |
+
ssm_layer.E.detach().clone(),
|
| 132 |
+
ssm_layer.log_dt.detach().clone(),
|
| 133 |
+
)
|
| 134 |
+
kwargs['Ad'] = Ad
|
| 135 |
+
kwargs['B_hat'] = B_hat
|
| 136 |
+
kwargs['mixer'] = ssm_layer.mixer
|
| 137 |
+
|
| 138 |
+
elif mode == 'neck':
|
| 139 |
+
Ad, dt = discretize_neck(
|
| 140 |
+
ssm_layer.A.detach().clone(),
|
| 141 |
+
ssm_layer.log_dt.detach().clone(),
|
| 142 |
+
)
|
| 143 |
+
kwargs['Ad'] = Ad
|
| 144 |
+
kwargs['dt'] = dt
|
| 145 |
+
kwargs['B'] = ssm_layer.B.detach().clone()
|
| 146 |
+
kwargs['C'] = ssm_layer.C.detach().clone()
|
| 147 |
+
kwargs['E'] = ssm_layer.E.detach().clone()
|
| 148 |
+
|
| 149 |
+
elif mode == 'full':
|
| 150 |
+
Ad, B_hat = discretize_full(
|
| 151 |
+
ssm_layer.A.detach().clone(),
|
| 152 |
+
ssm_layer.E.detach().clone(),
|
| 153 |
+
ssm_layer.log_dt.detach().clone(),
|
| 154 |
+
)
|
| 155 |
+
kwargs['Ad'] = Ad
|
| 156 |
+
kwargs['B_hat'] = B_hat
|
| 157 |
+
|
| 158 |
+
elif mode == 'gate':
|
| 159 |
+
kwargs['A'] = ssm_layer.A.detach().clone()
|
| 160 |
+
kwargs['B'] = ssm_layer.B
|
| 161 |
+
kwargs['C'] = ssm_layer.C
|
| 162 |
+
kwargs['log_dt'] = ssm_layer.log_dt
|
| 163 |
+
kwargs['mixer'] = ssm_layer.mixer
|
| 164 |
+
|
| 165 |
+
else:
|
| 166 |
+
raise ValueError(f'Unknown mode: {mode}')
|
| 167 |
+
|
| 168 |
+
return cls(
|
| 169 |
+
mode=mode,
|
| 170 |
+
in_channels=ssm_layer.in_channels,
|
| 171 |
+
out_channels=ssm_layer.out_channels,
|
| 172 |
+
**kwargs,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def forward(self, input, return_state=False):
|
| 176 |
+
"""Forward pass with stateful processing.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
input: Input tensor of shape (B, C, T) or (C, T) for single sample
|
| 180 |
+
return_state: If True, return (output, state) tuple
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
output: Output tensor of shape (B, D, T) or (D, T)
|
| 184 |
+
state (optional): Internal state if return_state=True
|
| 185 |
+
|
| 186 |
+
Note:
|
| 187 |
+
State is maintained internally across calls. Use reset_state()
|
| 188 |
+
to clear it.
|
| 189 |
+
"""
|
| 190 |
+
# Handle input format
|
| 191 |
+
squeeze_batch = False
|
| 192 |
+
if input.dim() == 2:
|
| 193 |
+
input = input.unsqueeze(0) # (C, T) -> (1, C, T)
|
| 194 |
+
squeeze_batch = True
|
| 195 |
+
|
| 196 |
+
B_batch, _C, T = input.shape
|
| 197 |
+
# Transpose to (B, T, C) for step functions
|
| 198 |
+
input = input.transpose(1, 2)
|
| 199 |
+
|
| 200 |
+
if self.mode == 'gate':
|
| 201 |
+
output, self.state = recurrent_gate(
|
| 202 |
+
input, self.A, self.B, self.C, self.log_dt, self.state
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
# Non-gate modes: loop over timesteps with precomputed discretization
|
| 206 |
+
outputs = []
|
| 207 |
+
for b in range(B_batch):
|
| 208 |
+
batch_outputs = []
|
| 209 |
+
# Use per-batch state or init
|
| 210 |
+
if self.state is not None and self.state.dim() > len(self._state_shape()):
|
| 211 |
+
x = self.state[b]
|
| 212 |
+
else:
|
| 213 |
+
x = self.state
|
| 214 |
+
|
| 215 |
+
for t in range(T):
|
| 216 |
+
u_t = input[b, t] # (C_in,)
|
| 217 |
+
y_t, x = self._step(u_t, x)
|
| 218 |
+
batch_outputs.append(y_t)
|
| 219 |
+
|
| 220 |
+
# Update state
|
| 221 |
+
if b == 0:
|
| 222 |
+
self.state = x.unsqueeze(0) if B_batch > 1 else x
|
| 223 |
+
elif B_batch > 1:
|
| 224 |
+
self.state = torch.cat([self.state, x.unsqueeze(0)], dim=0)
|
| 225 |
+
|
| 226 |
+
outputs.append(torch.stack(batch_outputs, dim=1)) # (D, T)
|
| 227 |
+
|
| 228 |
+
output = torch.stack(outputs, dim=0) # (B, D, T)
|
| 229 |
+
|
| 230 |
+
# Apply mixer (important for DWS mode which projects channels)
|
| 231 |
+
output = self.mixer(output)
|
| 232 |
+
|
| 233 |
+
if squeeze_batch:
|
| 234 |
+
output = output.squeeze(0)
|
| 235 |
+
if self.state is not None and self.state.dim() > len(self._state_shape()):
|
| 236 |
+
self.state = self.state.squeeze(0)
|
| 237 |
+
|
| 238 |
+
if return_state:
|
| 239 |
+
return output, self.state
|
| 240 |
+
return output
|
| 241 |
+
|
| 242 |
+
def _step(self, u, state):
|
| 243 |
+
"""Dispatch to mode-specific step function."""
|
| 244 |
+
if self.mode == 's5':
|
| 245 |
+
return step_s5(u, self.Ad, self.B_hat, self.C, state)
|
| 246 |
+
elif self.mode == 'dws':
|
| 247 |
+
return step_dws(u, self.Ad, self.B_hat, state)
|
| 248 |
+
elif self.mode == 'neck':
|
| 249 |
+
return step_neck(u, self.Ad, self.dt, self.B, self.C, self.E, state)
|
| 250 |
+
elif self.mode == 'full':
|
| 251 |
+
return step_full(u, self.Ad, self.B_hat, state)
|
| 252 |
+
|
| 253 |
+
def _state_shape(self):
|
| 254 |
+
"""Return expected unbatched state shape for current mode."""
|
| 255 |
+
if self.mode == 's5':
|
| 256 |
+
return self.Ad.shape # (N, 2)
|
| 257 |
+
elif self.mode == 'dws':
|
| 258 |
+
return self.Ad.shape # (C, N, 2)
|
| 259 |
+
elif self.mode == 'neck':
|
| 260 |
+
return self.Ad.shape # (R, N, 2)
|
| 261 |
+
elif self.mode == 'full':
|
| 262 |
+
return self.Ad.shape # (D, C, N, 2)
|
| 263 |
+
elif self.mode == 'gate':
|
| 264 |
+
return (self.A.shape[0],) # (N,)
|
| 265 |
+
|
| 266 |
+
def reset_state(self):
|
| 267 |
+
"""Reset internal state.
|
| 268 |
+
|
| 269 |
+
Call this when starting a new sequence.
|
| 270 |
+
|
| 271 |
+
Example:
|
| 272 |
+
>>> for utterance in utterances:
|
| 273 |
+
>>> infer_layer.reset_state() # Clear state
|
| 274 |
+
>>> for chunk in utterance:
|
| 275 |
+
>>> output = infer_layer(chunk)
|
| 276 |
+
"""
|
| 277 |
+
if self.state is not None:
|
| 278 |
+
self.state.zero_()
|
| 279 |
+
else:
|
| 280 |
+
self.state = None
|
| 281 |
+
|
| 282 |
+
def get_state(self):
|
| 283 |
+
"""Get current internal state for checkpointing or branching.
|
| 284 |
+
|
| 285 |
+
Returns a clone of the state to prevent accidental mutations.
|
| 286 |
+
Useful for beam search, hypothesis tracking, or state snapshots.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
state: Cloned state tensor or None if no state exists
|
| 290 |
+
|
| 291 |
+
Example:
|
| 292 |
+
>>> # Save state for beam search
|
| 293 |
+
>>> saved_state = infer_layer.get_state()
|
| 294 |
+
>>> # Process hypothesis 1
|
| 295 |
+
>>> output1 = infer_layer(chunk1)
|
| 296 |
+
>>> # Restore and try hypothesis 2
|
| 297 |
+
>>> infer_layer.set_state(saved_state)
|
| 298 |
+
>>> output2 = infer_layer(chunk2)
|
| 299 |
+
"""
|
| 300 |
+
return self.state.clone() if self.state is not None else None
|
| 301 |
+
|
| 302 |
+
def set_state(self, state):
|
| 303 |
+
"""Restore internal state from checkpoint.
|
| 304 |
+
|
| 305 |
+
Sets the state to a clone of the provided tensor to prevent
|
| 306 |
+
accidental mutations. Useful for restoring checkpoints or
|
| 307 |
+
branching hypotheses in beam search.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
state: State tensor (shape depends on mode) or None to reset
|
| 311 |
+
|
| 312 |
+
Example:
|
| 313 |
+
>>> # Checkpoint state before branching
|
| 314 |
+
>>> checkpoint = infer_layer.get_state()
|
| 315 |
+
>>> # ... process some data ...
|
| 316 |
+
>>> # Restore to checkpoint
|
| 317 |
+
>>> infer_layer.set_state(checkpoint)
|
| 318 |
+
"""
|
| 319 |
+
self.state = state.clone() if state is not None else None
|
| 320 |
+
|
| 321 |
+
def __repr__(self):
|
| 322 |
+
return (
|
| 323 |
+
f'SSMLayerInference(mode={self.mode}, '
|
| 324 |
+
f'in_channels={self.in_channels}, '
|
| 325 |
+
f'out_channels={self.out_channels}, '
|
| 326 |
+
f'state={"active" if self.state is not None else "reset"})'
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class SSMLayerExportable(nn.Module):
|
| 331 |
+
"""Single-timestep exportable SSM layer for ONNX export (B=1, T=1).
|
| 332 |
+
|
| 333 |
+
This class processes one timestep at a time with explicit state input/output,
|
| 334 |
+
enabling export to ONNX by eliminating dynamic control flow and
|
| 335 |
+
complex number dtypes.
|
| 336 |
+
|
| 337 |
+
Discretization is precomputed at construction time, so the forward pass
|
| 338 |
+
only runs the step function.
|
| 339 |
+
|
| 340 |
+
Currently supports S5, DWS, Neck, Full, and Gate modes. State is represented as real tensors (..., 2)
|
| 341 |
+
where [..., 0] is the real part and [..., 1] is the imaginary part.
|
| 342 |
+
|
| 343 |
+
Example:
|
| 344 |
+
>>> # After training
|
| 345 |
+
>>> train_layer = SSMLayer(num_coeffs=64, in_channels=32, out_channels=32, mode='s5')
|
| 346 |
+
>>> # ... training ...
|
| 347 |
+
>>>
|
| 348 |
+
>>> # Convert to exportable inference mode
|
| 349 |
+
>>> export_layer = SSMLayerExportable.from_training(train_layer)
|
| 350 |
+
>>>
|
| 351 |
+
>>> # Export to ONNX
|
| 352 |
+
>>> dummy_input = torch.randn(32)
|
| 353 |
+
>>> torch.onnx.export(export_layer, (dummy_input, None), "model.onnx")
|
| 354 |
+
>>>
|
| 355 |
+
>>> # Use in streaming application (external loop)
|
| 356 |
+
>>> state = None
|
| 357 |
+
>>> for t in range(audio_length):
|
| 358 |
+
>>> output, state = export_layer(audio[t], state)
|
| 359 |
+
|
| 360 |
+
Note:
|
| 361 |
+
- Processes single sample (B=1), single timestep (T=1) per call
|
| 362 |
+
- State is automatically initialized to zeros if None
|
| 363 |
+
- Loop over time must be external to the model
|
| 364 |
+
- Complex numbers represented as (..., 2) real tensors
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
def __init__(self, mode, in_channels, out_channels, **kwargs):
|
| 368 |
+
"""Initialize exportable SSM layer.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
|
| 372 |
+
in_channels: Number of input channels
|
| 373 |
+
out_channels: Number of output channels
|
| 374 |
+
**kwargs: Mode-specific discretized parameters
|
| 375 |
+
"""
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.mode = mode
|
| 378 |
+
self.in_channels = in_channels
|
| 379 |
+
self.out_channels = out_channels
|
| 380 |
+
|
| 381 |
+
if mode == 's5':
|
| 382 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 383 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 384 |
+
self.register_buffer('C', kwargs['C'])
|
| 385 |
+
elif mode == 'dws':
|
| 386 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 387 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 388 |
+
elif mode == 'neck':
|
| 389 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 390 |
+
self.register_buffer('dt', kwargs['dt'])
|
| 391 |
+
self.register_buffer('B', kwargs['B'])
|
| 392 |
+
self.register_buffer('C', kwargs['C'])
|
| 393 |
+
self.register_buffer('E', kwargs['E'])
|
| 394 |
+
elif mode == 'full':
|
| 395 |
+
self.register_buffer('Ad', kwargs['Ad'])
|
| 396 |
+
self.register_buffer('B_hat', kwargs['B_hat'])
|
| 397 |
+
elif mode == 'gate':
|
| 398 |
+
self.register_buffer('A', kwargs['A'])
|
| 399 |
+
self.B = kwargs['B'] # nn.Module
|
| 400 |
+
self.C = kwargs['C'] # nn.Module
|
| 401 |
+
self.log_dt = kwargs['log_dt'] # nn.Module
|
| 402 |
+
else:
|
| 403 |
+
raise ValueError(f'Unknown mode: {mode}')
|
| 404 |
+
|
| 405 |
+
# Mixer module (for DWS mode to project channels)
|
| 406 |
+
self.mixer = kwargs.get('mixer') or nn.Identity()
|
| 407 |
+
|
| 408 |
+
@classmethod
|
| 409 |
+
def from_training(cls, ssm_layer):
|
| 410 |
+
"""Create exportable layer from trained SSMLayer.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
ssm_layer: Trained SSMLayer instance
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
SSMLayerExportable instance with precomputed discretized weights
|
| 417 |
+
|
| 418 |
+
Raises:
|
| 419 |
+
ValueError: If ssm_layer.mode is not supported
|
| 420 |
+
|
| 421 |
+
Example:
|
| 422 |
+
>>> train_layer = SSMLayer(num_coeffs=64, in_channels=32, out_channels=32, mode='s5')
|
| 423 |
+
>>> export_layer = SSMLayerExportable.from_training(train_layer)
|
| 424 |
+
"""
|
| 425 |
+
mode = ssm_layer.mode
|
| 426 |
+
kwargs = {}
|
| 427 |
+
|
| 428 |
+
if mode == 's5':
|
| 429 |
+
Ad, B_hat = discretize_s5(
|
| 430 |
+
ssm_layer.A.detach().clone(),
|
| 431 |
+
ssm_layer.B.detach().clone(),
|
| 432 |
+
ssm_layer.log_dt.detach().clone(),
|
| 433 |
+
)
|
| 434 |
+
kwargs['Ad'] = Ad
|
| 435 |
+
kwargs['B_hat'] = B_hat
|
| 436 |
+
kwargs['C'] = ssm_layer.C.detach().clone()
|
| 437 |
+
|
| 438 |
+
elif mode == 'dws':
|
| 439 |
+
Ad, B_hat = discretize_dws(
|
| 440 |
+
ssm_layer.A.detach().clone(),
|
| 441 |
+
ssm_layer.E.detach().clone(),
|
| 442 |
+
ssm_layer.log_dt.detach().clone(),
|
| 443 |
+
)
|
| 444 |
+
kwargs['Ad'] = Ad
|
| 445 |
+
kwargs['B_hat'] = B_hat
|
| 446 |
+
kwargs['mixer'] = ssm_layer.mixer
|
| 447 |
+
|
| 448 |
+
elif mode == 'neck':
|
| 449 |
+
Ad, dt = discretize_neck(
|
| 450 |
+
ssm_layer.A.detach().clone(),
|
| 451 |
+
ssm_layer.log_dt.detach().clone(),
|
| 452 |
+
)
|
| 453 |
+
kwargs['Ad'] = Ad
|
| 454 |
+
kwargs['dt'] = dt
|
| 455 |
+
kwargs['B'] = ssm_layer.B.detach().clone()
|
| 456 |
+
kwargs['C'] = ssm_layer.C.detach().clone()
|
| 457 |
+
kwargs['E'] = ssm_layer.E.detach().clone()
|
| 458 |
+
|
| 459 |
+
elif mode == 'full':
|
| 460 |
+
Ad, B_hat = discretize_full(
|
| 461 |
+
ssm_layer.A.detach().clone(),
|
| 462 |
+
ssm_layer.E.detach().clone(),
|
| 463 |
+
ssm_layer.log_dt.detach().clone(),
|
| 464 |
+
)
|
| 465 |
+
kwargs['Ad'] = Ad
|
| 466 |
+
kwargs['B_hat'] = B_hat
|
| 467 |
+
|
| 468 |
+
elif mode == 'gate':
|
| 469 |
+
kwargs['A'] = ssm_layer.A.detach().clone()
|
| 470 |
+
kwargs['B'] = ssm_layer.B
|
| 471 |
+
kwargs['C'] = ssm_layer.C
|
| 472 |
+
kwargs['log_dt'] = ssm_layer.log_dt
|
| 473 |
+
|
| 474 |
+
else:
|
| 475 |
+
raise ValueError(
|
| 476 |
+
f'SSMLayerExportable only supports S5, DWS, Neck, Full, and Gate modes, got {mode}'
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
return cls(
|
| 480 |
+
mode=mode,
|
| 481 |
+
in_channels=ssm_layer.in_channels,
|
| 482 |
+
out_channels=ssm_layer.out_channels,
|
| 483 |
+
**kwargs,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
def forward(self, input, state=None):
|
| 487 |
+
"""Forward pass for single timestep.
|
| 488 |
+
|
| 489 |
+
Args:
|
| 490 |
+
input: Input tensor of shape (C_in,) - single sample, single timestep
|
| 491 |
+
state: Optional state tensor - shape depends on mode:
|
| 492 |
+
- S5: (N, 2) real representation
|
| 493 |
+
- DWS: (C, N, 2) real representation
|
| 494 |
+
- Neck: (R, N, 2) real representation
|
| 495 |
+
- Full: (D, C, N, 2) real representation
|
| 496 |
+
- Gate: (N,) real-valued
|
| 497 |
+
If None, initializes to zeros internally
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
output: Output tensor of shape (D,)
|
| 501 |
+
new_state: Updated state - same shape as state input
|
| 502 |
+
|
| 503 |
+
Example:
|
| 504 |
+
>>> export_layer = SSMLayerExportable.from_training(trained_layer)
|
| 505 |
+
>>> x = torch.randn(32) # Single timestep input
|
| 506 |
+
>>> y, state = export_layer(x, None) # First call, state=None
|
| 507 |
+
>>> y2, state = export_layer(x2, state) # Subsequent call with state
|
| 508 |
+
"""
|
| 509 |
+
if self.mode == 's5':
|
| 510 |
+
output, new_state = step_s5(input, self.Ad, self.B_hat, self.C, state)
|
| 511 |
+
elif self.mode == 'dws':
|
| 512 |
+
output, new_state = step_dws(input, self.Ad, self.B_hat, state)
|
| 513 |
+
# Apply mixer for DWS mode (channel projection)
|
| 514 |
+
# Mixer expects (B, C, T) format, we have (C,) single timestep
|
| 515 |
+
output = (
|
| 516 |
+
self.mixer(output.unsqueeze(0).unsqueeze(-1)).squeeze(0).squeeze(-1)
|
| 517 |
+
) # (C,) -> (1, C, 1) -> (1, D, 1) -> (D,)
|
| 518 |
+
elif self.mode == 'neck':
|
| 519 |
+
output, new_state = step_neck(input, self.Ad, self.dt, self.B, self.C, self.E, state)
|
| 520 |
+
elif self.mode == 'full':
|
| 521 |
+
output, new_state = step_full(input, self.Ad, self.B_hat, state)
|
| 522 |
+
elif self.mode == 'gate':
|
| 523 |
+
# Initialize state if None
|
| 524 |
+
if state is None:
|
| 525 |
+
N = self.A.shape[0]
|
| 526 |
+
state = torch.zeros(N, dtype=torch.float32, device=input.device)
|
| 527 |
+
output, new_state = recurrent_gate_single_step(
|
| 528 |
+
input, self.A, self.B, self.C, self.log_dt, state
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
raise ValueError(f'Unsupported mode: {self.mode}')
|
| 532 |
+
|
| 533 |
+
return output, new_state
|
| 534 |
+
|
| 535 |
+
def __repr__(self):
|
| 536 |
+
return (
|
| 537 |
+
f'SSMLayerExportable(mode={self.mode}, '
|
| 538 |
+
f'in_channels={self.in_channels}, '
|
| 539 |
+
f'out_channels={self.out_channels})'
|
| 540 |
+
)
|
tenns_core/recurrent_ops.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Recurrent operations for streaming SSM inference.
|
| 3 |
+
|
| 4 |
+
Provides discretize_* functions (called once at init) and step_* functions
|
| 5 |
+
(called per timestep) for each SSM mode, enabling low-latency streaming
|
| 6 |
+
inference by maintaining state across chunks.
|
| 7 |
+
|
| 8 |
+
Gate mode is special: its discretization is input-dependent, so it keeps
|
| 9 |
+
combined recurrent_gate / recurrent_gate_single_step functions.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
# ============================================================================
|
| 16 |
+
# Complex arithmetic helpers for real representation (ONNX compat)
|
| 17 |
+
# ============================================================================
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def complex_mul_real(a, b):
|
| 21 |
+
"""Multiply two complex numbers in real representation (..., 2).
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
a: Complex tensor as real representation (..., 2) where [..., 0] is real, [..., 1] is imag
|
| 25 |
+
b: Complex tensor as real representation (..., 2)
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Complex product as real representation (..., 2)
|
| 29 |
+
Formula: (a_r + i*a_i) * (b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r)
|
| 30 |
+
"""
|
| 31 |
+
a_real = a[..., 0]
|
| 32 |
+
a_imag = a[..., 1]
|
| 33 |
+
b_real = b[..., 0]
|
| 34 |
+
b_imag = b[..., 1]
|
| 35 |
+
|
| 36 |
+
result_real = a_real * b_real - a_imag * b_imag
|
| 37 |
+
result_imag = a_real * b_imag + a_imag * b_real
|
| 38 |
+
|
| 39 |
+
return torch.stack([result_real, result_imag], dim=-1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ============================================================================
|
| 43 |
+
# S5 mode
|
| 44 |
+
# ============================================================================
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def discretize_s5(A, B, log_dt):
|
| 48 |
+
"""Precompute discretized parameters for S5 mode.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
A: State transition parameter of shape (N, 2) - real repr of complex
|
| 52 |
+
B: Input projection of shape (N, C_in)
|
| 53 |
+
log_dt: Time step of shape (N,)
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Ad: Discretized state transition of shape (N, 2)
|
| 57 |
+
B_hat: Discretized input projection of shape (N, C_in)
|
| 58 |
+
"""
|
| 59 |
+
A_real = -F.softplus(A[:, 0]) # (N,)
|
| 60 |
+
A_imag = A[:, 1] # (N,)
|
| 61 |
+
|
| 62 |
+
dt = torch.exp(log_dt) # (N,)
|
| 63 |
+
scaled_real = dt * A_real
|
| 64 |
+
scaled_imag = dt * A_imag
|
| 65 |
+
exp_scaled_real = torch.exp(scaled_real)
|
| 66 |
+
Ad = torch.stack(
|
| 67 |
+
[
|
| 68 |
+
exp_scaled_real * torch.cos(scaled_imag),
|
| 69 |
+
exp_scaled_real * torch.sin(scaled_imag),
|
| 70 |
+
],
|
| 71 |
+
dim=-1,
|
| 72 |
+
) # (N, 2)
|
| 73 |
+
|
| 74 |
+
B_hat = dt[:, None] * B # (N, C_in)
|
| 75 |
+
|
| 76 |
+
return Ad, B_hat
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def step_s5(u, Ad, B_hat, C, state):
|
| 80 |
+
"""Single timestep for S5 mode using pre-discretized parameters.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
u: Input tensor of shape (C_in,)
|
| 84 |
+
Ad: Discretized state transition of shape (N, 2)
|
| 85 |
+
B_hat: Discretized input projection of shape (N, C_in)
|
| 86 |
+
C: Output projection of shape (D, N)
|
| 87 |
+
state: Previous state of shape (N, 2), or None for zero init
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
y: Output tensor of shape (D,)
|
| 91 |
+
new_state: Updated state of shape (N, 2)
|
| 92 |
+
"""
|
| 93 |
+
if state is None:
|
| 94 |
+
N = Ad.shape[0]
|
| 95 |
+
state = torch.zeros((N, 2), dtype=torch.float32, device=u.device)
|
| 96 |
+
|
| 97 |
+
# State update: x = Ad * x + B_hat @ u
|
| 98 |
+
x_new = complex_mul_real(Ad, state) # (N, 2)
|
| 99 |
+
Bu = B_hat @ u # (N,)
|
| 100 |
+
x_new[..., 0] = x_new[..., 0] + Bu
|
| 101 |
+
|
| 102 |
+
# Output: y = C @ real(x)
|
| 103 |
+
y = C @ x_new[..., 0] # (D,)
|
| 104 |
+
|
| 105 |
+
return y, x_new
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ============================================================================
|
| 109 |
+
# DWS mode
|
| 110 |
+
# ============================================================================
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def discretize_dws(A, E, log_dt):
|
| 114 |
+
"""Precompute discretized parameters for DWS mode.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
A: State parameter of shape (C, N, 2) - real repr of complex
|
| 118 |
+
E: Weight matrix of shape (C, N)
|
| 119 |
+
log_dt: Time step of shape (C, N)
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Ad: Discretized state transition of shape (C, N, 2)
|
| 123 |
+
B_hat: Discretized input projection of shape (C, N)
|
| 124 |
+
"""
|
| 125 |
+
A_real = -F.softplus(A[..., 0]) # (C, N)
|
| 126 |
+
A_imag = A[..., 1] # (C, N)
|
| 127 |
+
|
| 128 |
+
dt = torch.exp(log_dt) # (C, N)
|
| 129 |
+
scaled_real = dt * A_real
|
| 130 |
+
scaled_imag = dt * A_imag
|
| 131 |
+
exp_scaled_real = torch.exp(scaled_real)
|
| 132 |
+
Ad = torch.stack(
|
| 133 |
+
[
|
| 134 |
+
exp_scaled_real * torch.cos(scaled_imag),
|
| 135 |
+
exp_scaled_real * torch.sin(scaled_imag),
|
| 136 |
+
],
|
| 137 |
+
dim=-1,
|
| 138 |
+
) # (C, N, 2)
|
| 139 |
+
|
| 140 |
+
B_hat = E * dt # (C, N)
|
| 141 |
+
|
| 142 |
+
return Ad, B_hat
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def step_dws(u, Ad, B_hat, state):
|
| 146 |
+
"""Single timestep for DWS mode using pre-discretized parameters.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
u: Input tensor of shape (C,)
|
| 150 |
+
Ad: Discretized state transition of shape (C, N, 2)
|
| 151 |
+
B_hat: Discretized input projection of shape (C, N)
|
| 152 |
+
state: Previous state of shape (C, N, 2), or None for zero init
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
y: Output tensor of shape (C,)
|
| 156 |
+
new_state: Updated state of shape (C, N, 2)
|
| 157 |
+
"""
|
| 158 |
+
if state is None:
|
| 159 |
+
C, N = B_hat.shape
|
| 160 |
+
state = torch.zeros((C, N, 2), dtype=torch.float32, device=u.device)
|
| 161 |
+
|
| 162 |
+
# State update: x = Ad * x + B_hat * u
|
| 163 |
+
x_new = complex_mul_real(Ad, state) # (C, N, 2)
|
| 164 |
+
Bu = B_hat * u.unsqueeze(1) # (C, N)
|
| 165 |
+
x_new[..., 0] = x_new[..., 0] + Bu
|
| 166 |
+
|
| 167 |
+
# Output: y = sum(real(x), dim=1)
|
| 168 |
+
y = torch.sum(x_new[..., 0], dim=1) # (C,)
|
| 169 |
+
|
| 170 |
+
return y, x_new
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ============================================================================
|
| 174 |
+
# Neck mode
|
| 175 |
+
# ============================================================================
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def discretize_neck(A, log_dt):
|
| 179 |
+
"""Precompute discretized parameters for Neck mode.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
A: State transition parameter of shape (R, N, 2) - real repr of complex
|
| 183 |
+
log_dt: Time step of shape (R,)
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Ad: Discretized state transition of shape (R, N, 2)
|
| 187 |
+
dt: Discretized time step of shape (R, 1) - needed for input scaling
|
| 188 |
+
"""
|
| 189 |
+
A_real = -F.softplus(A[..., 0]) # (R, N)
|
| 190 |
+
A_imag = A[..., 1] # (R, N)
|
| 191 |
+
|
| 192 |
+
dt = torch.exp(log_dt).reshape(-1, 1) # (R, 1)
|
| 193 |
+
scaled_real = dt * A_real
|
| 194 |
+
scaled_imag = dt * A_imag
|
| 195 |
+
exp_scaled_real = torch.exp(scaled_real)
|
| 196 |
+
Ad = torch.stack(
|
| 197 |
+
[
|
| 198 |
+
exp_scaled_real * torch.cos(scaled_imag),
|
| 199 |
+
exp_scaled_real * torch.sin(scaled_imag),
|
| 200 |
+
],
|
| 201 |
+
dim=-1,
|
| 202 |
+
) # (R, N, 2)
|
| 203 |
+
|
| 204 |
+
return Ad, dt
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def step_neck(u, Ad, dt, B, C, E, state):
|
| 208 |
+
"""Single timestep for Neck mode using pre-discretized parameters.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
u: Input tensor of shape (C_in,)
|
| 212 |
+
Ad: Discretized state transition of shape (R, N, 2)
|
| 213 |
+
dt: Discretized time step of shape (R, 1)
|
| 214 |
+
B: Input projection of shape (R, C_in)
|
| 215 |
+
C: Output projection of shape (D, R)
|
| 216 |
+
E: State mixing matrix of shape (R, N)
|
| 217 |
+
state: Previous state of shape (R, N, 2), or None for zero init
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
y: Output tensor of shape (D,)
|
| 221 |
+
new_state: Updated state of shape (R, N, 2)
|
| 222 |
+
"""
|
| 223 |
+
if state is None:
|
| 224 |
+
R, N = Ad.shape[0], Ad.shape[1]
|
| 225 |
+
state = torch.zeros((R, N, 2), dtype=torch.float32, device=u.device)
|
| 226 |
+
|
| 227 |
+
# Input projection: v = dt * B @ u
|
| 228 |
+
v = dt.squeeze(1) * (B @ u) # (R,)
|
| 229 |
+
|
| 230 |
+
# State update: x = Ad * x + v
|
| 231 |
+
x_new = complex_mul_real(Ad, state) # (R, N, 2)
|
| 232 |
+
x_new[..., 0] = x_new[..., 0] + v.unsqueeze(1)
|
| 233 |
+
|
| 234 |
+
# Output: z = real((x * E).sum(N)), y = C @ z
|
| 235 |
+
E_cplx = torch.stack([E, torch.zeros_like(E)], dim=-1) # (R, N, 2)
|
| 236 |
+
z = torch.sum(complex_mul_real(x_new, E_cplx)[..., 0], dim=1) # (R,)
|
| 237 |
+
y = C @ z # (D,)
|
| 238 |
+
|
| 239 |
+
return y, x_new
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ============================================================================
|
| 243 |
+
# Full mode
|
| 244 |
+
# ============================================================================
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def discretize_full(A, E, log_dt):
|
| 248 |
+
"""Precompute discretized parameters for Full mode.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
A: State parameter of shape (D, C, N, 2) - real repr of complex
|
| 252 |
+
E: Weight matrix of shape (D, C, N)
|
| 253 |
+
log_dt: Time step of shape (D, N)
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
Ad: Discretized state transition of shape (D, C, N, 2)
|
| 257 |
+
B_hat: Discretized input projection of shape (D, C, N)
|
| 258 |
+
"""
|
| 259 |
+
A_real = -F.softplus(A[..., 0]) # (D, C, N)
|
| 260 |
+
A_imag = A[..., 1] # (D, C, N)
|
| 261 |
+
|
| 262 |
+
dt = torch.exp(log_dt) # (D, N)
|
| 263 |
+
dt_exp = dt[:, None, :] # (D, 1, N)
|
| 264 |
+
scaled_real = dt_exp * A_real
|
| 265 |
+
scaled_imag = dt_exp * A_imag
|
| 266 |
+
exp_scaled_real = torch.exp(scaled_real)
|
| 267 |
+
Ad = torch.stack(
|
| 268 |
+
[
|
| 269 |
+
exp_scaled_real * torch.cos(scaled_imag),
|
| 270 |
+
exp_scaled_real * torch.sin(scaled_imag),
|
| 271 |
+
],
|
| 272 |
+
dim=-1,
|
| 273 |
+
) # (D, C, N, 2)
|
| 274 |
+
|
| 275 |
+
B_hat = E * dt_exp # (D, C, N)
|
| 276 |
+
|
| 277 |
+
return Ad, B_hat
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def step_full(u, Ad, B_hat, state):
|
| 281 |
+
"""Single timestep for Full mode using pre-discretized parameters.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
u: Input tensor of shape (C,)
|
| 285 |
+
Ad: Discretized state transition of shape (D, C, N, 2)
|
| 286 |
+
B_hat: Discretized input projection of shape (D, C, N)
|
| 287 |
+
state: Previous state of shape (D, C, N, 2), or None for zero init
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
y: Output tensor of shape (D,)
|
| 291 |
+
new_state: Updated state of shape (D, C, N, 2)
|
| 292 |
+
"""
|
| 293 |
+
if state is None:
|
| 294 |
+
D, C, N = B_hat.shape
|
| 295 |
+
state = torch.zeros((D, C, N, 2), dtype=torch.float32, device=u.device)
|
| 296 |
+
|
| 297 |
+
# State update: x = Ad * x + B_hat * u
|
| 298 |
+
x_new = complex_mul_real(Ad, state) # (D, C, N, 2)
|
| 299 |
+
u_broadcast = u.unsqueeze(0).unsqueeze(2) # (1, C, 1)
|
| 300 |
+
Bu = B_hat * u_broadcast # (D, C, N)
|
| 301 |
+
x_new[..., 0] = x_new[..., 0] + Bu
|
| 302 |
+
|
| 303 |
+
# Output: y = sum(real(x), dim=(1, 2))
|
| 304 |
+
y = torch.sum(x_new[..., 0], dim=(1, 2)) # (D,)
|
| 305 |
+
|
| 306 |
+
return y, x_new
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ============================================================================
|
| 310 |
+
# Gate mode (input-dependent discretization — cannot precompute)
|
| 311 |
+
# ============================================================================
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def recurrent_gate_single_step(u, A, B_proj, C_proj, log_dt_proj, state):
|
| 315 |
+
"""
|
| 316 |
+
Gate-style SSM single timestep for ONNX export.
|
| 317 |
+
|
| 318 |
+
Processes single timestep with input-dependent parameters.
|
| 319 |
+
Unlike other modes, gate uses neural network projections for B, C, and dt.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
u: Input tensor of shape (C_in,) - single timestep, single batch
|
| 323 |
+
A: State transition parameter of shape (N,) - in log space, represents decay rates
|
| 324 |
+
B_proj: nn.Module that projects (C_in,) -> (N,)
|
| 325 |
+
C_proj: nn.Module that projects (N,) -> (D,)
|
| 326 |
+
log_dt_proj: nn.Module that projects (C_in,) -> (N,)
|
| 327 |
+
state: Previous state of shape (N,) - real-valued state
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
y: Output tensor of shape (D,)
|
| 331 |
+
new_state: Updated state of shape (N,)
|
| 332 |
+
|
| 333 |
+
State update formula:
|
| 334 |
+
log_dt = log_dt_proj(u)
|
| 335 |
+
dt = softplus(log_dt)
|
| 336 |
+
u_proj = B_proj(u)
|
| 337 |
+
dta = exp(-dt * exp(A)) # discretized decay
|
| 338 |
+
x_new = dta * x_old + dt * u_proj
|
| 339 |
+
y = C_proj(x_new)
|
| 340 |
+
"""
|
| 341 |
+
# Get input-dependent projections
|
| 342 |
+
u_proj = B_proj(u) # (N,)
|
| 343 |
+
log_dt = log_dt_proj(u) # (N,)
|
| 344 |
+
|
| 345 |
+
# Discretization
|
| 346 |
+
dt = F.softplus(log_dt) # (N,)
|
| 347 |
+
exp_A = torch.exp(A) # (N,) - decay rate
|
| 348 |
+
dta = torch.exp(-dt * exp_A) # (N,) - discretized decay factor
|
| 349 |
+
|
| 350 |
+
# State update: x_new = dta * x_old + dt * u_proj
|
| 351 |
+
u_dt = u_proj * dt # (N,)
|
| 352 |
+
new_state = dta * state + u_dt # (N,)
|
| 353 |
+
|
| 354 |
+
# Output projection
|
| 355 |
+
y = C_proj(new_state) # (D,)
|
| 356 |
+
|
| 357 |
+
return y, new_state
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def recurrent_gate(u, A, B_proj, C_proj, log_dt_proj, state=None):
|
| 361 |
+
"""
|
| 362 |
+
Gate-style SSM using sequential scan for streaming inference.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
u: Input tensor of shape (T, C_in) or (B, T, C_in)
|
| 366 |
+
A: State transition parameter of shape (N,) - in log space
|
| 367 |
+
B_proj: nn.Module or callable that projects (*, C_in) -> (*, N)
|
| 368 |
+
C_proj: nn.Module or callable that projects (*, N) -> (*, D)
|
| 369 |
+
log_dt_proj: nn.Module or callable that projects (*, C_in) -> (*, N)
|
| 370 |
+
state: Optional previous state of shape (N,) or (B, N)
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
y: Output tensor of shape (D, T) or (B, D, T)
|
| 374 |
+
state: Updated state of shape (N,) or (B, N)
|
| 375 |
+
"""
|
| 376 |
+
# Handle batched input
|
| 377 |
+
if u.dim() == 2:
|
| 378 |
+
u = u.unsqueeze(0) # (T, C_in) -> (1, T, C_in)
|
| 379 |
+
squeeze_batch = True
|
| 380 |
+
else:
|
| 381 |
+
squeeze_batch = False
|
| 382 |
+
|
| 383 |
+
B_batch, T, C_in = u.shape
|
| 384 |
+
|
| 385 |
+
# Reshape to (B*T, C_in) for vectorized projection
|
| 386 |
+
u_flat = u.reshape(B_batch * T, C_in)
|
| 387 |
+
|
| 388 |
+
# Get projections
|
| 389 |
+
u_proj = B_proj(u_flat) # (B*T, N)
|
| 390 |
+
log_dt = log_dt_proj(u_flat) # (B*T, N)
|
| 391 |
+
|
| 392 |
+
N = u_proj.shape[1]
|
| 393 |
+
|
| 394 |
+
# Reshape back to (B, T, N)
|
| 395 |
+
u_proj = u_proj.reshape(B_batch, T, N)
|
| 396 |
+
log_dt = log_dt.reshape(B_batch, T, N)
|
| 397 |
+
|
| 398 |
+
# Discretize (vectorized across batch)
|
| 399 |
+
dt = F.softplus(log_dt).to(torch.float32) # (B, T, N)
|
| 400 |
+
exp_A = torch.exp(A).to(torch.float32) # (N,)
|
| 401 |
+
log_dta = -dt * exp_A[None, None, :] # (B, T, N)
|
| 402 |
+
dta = torch.exp(log_dta) # (B, T, N)
|
| 403 |
+
|
| 404 |
+
# Prepare scan input
|
| 405 |
+
u_dt = u_proj * dt # (B, T, N)
|
| 406 |
+
|
| 407 |
+
# Initialize state
|
| 408 |
+
if state is None:
|
| 409 |
+
x = torch.zeros((B_batch, N), dtype=torch.float32, device=u.device)
|
| 410 |
+
else:
|
| 411 |
+
if state.dim() == 1:
|
| 412 |
+
x = state.unsqueeze(0).expand(B_batch, -1)
|
| 413 |
+
else:
|
| 414 |
+
x = state
|
| 415 |
+
|
| 416 |
+
# Output accumulator
|
| 417 |
+
states = torch.zeros((B_batch, T, N), dtype=torch.float32, device=u.device)
|
| 418 |
+
|
| 419 |
+
# Sequential scan over time (vectorized over batch)
|
| 420 |
+
for t in range(T):
|
| 421 |
+
x = dta[:, t] * x + u_dt[:, t] # (B, N)
|
| 422 |
+
states[:, t] = x
|
| 423 |
+
|
| 424 |
+
# Apply C projection: (B*T, N) -> (B*T, D)
|
| 425 |
+
states_flat = states.reshape(B_batch * T, N)
|
| 426 |
+
y_flat = C_proj(states_flat) # (B*T, D)
|
| 427 |
+
D = y_flat.shape[1]
|
| 428 |
+
y = y_flat.reshape(B_batch, T, D) # (B, T, D)
|
| 429 |
+
|
| 430 |
+
# Return format
|
| 431 |
+
y = y.transpose(1, 2) # (B, D, T)
|
| 432 |
+
|
| 433 |
+
if squeeze_batch:
|
| 434 |
+
y = y.squeeze(0) # (D, T)
|
| 435 |
+
x = x.squeeze(0) # (N,)
|
| 436 |
+
|
| 437 |
+
return y, x
|
tenns_core/scan_ops.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Parallel scan operations for gate mode SSM.
|
| 3 |
+
|
| 4 |
+
Implements parallel prefix scan with custom autograd for training support.
|
| 5 |
+
Uses Triton kernels when available on CUDA, falls back to pure PyTorch otherwise.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import triton
|
| 14 |
+
import triton.language as tl
|
| 15 |
+
|
| 16 |
+
_HAS_TRITON = hasattr(tl, 'associative_scan')
|
| 17 |
+
except ImportError:
|
| 18 |
+
_HAS_TRITON = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ----------------------------
|
| 22 |
+
# Utility
|
| 23 |
+
# ----------------------------
|
| 24 |
+
def _tp(x: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
"""(B, L, N) -> (B, N, L) contiguous."""
|
| 26 |
+
return x.moveaxis(-1, -2).contiguous()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ----------------------------
|
| 30 |
+
# Reference (naive) scan
|
| 31 |
+
# ----------------------------
|
| 32 |
+
def scan_naive(input, log_dt, A, state=None, dim=-1):
|
| 33 |
+
"""Naive sequential scan implementation.
|
| 34 |
+
|
| 35 |
+
Useful for testing and understanding, but slow (O(N) sequential steps).
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
input: Input tensor
|
| 39 |
+
log_dt: Log timestep parameters
|
| 40 |
+
A: State decay parameters
|
| 41 |
+
state: Optional initial state
|
| 42 |
+
dim: Dimension to scan over
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Scanned output tensor
|
| 46 |
+
"""
|
| 47 |
+
dt = F.softplus(log_dt)
|
| 48 |
+
log_dta = -dt * A.exp()[..., None]
|
| 49 |
+
a = log_dta.exp()
|
| 50 |
+
|
| 51 |
+
if state is None:
|
| 52 |
+
state = 0
|
| 53 |
+
output = []
|
| 54 |
+
u = input * dt
|
| 55 |
+
|
| 56 |
+
for ui, ai in zip(u.moveaxis(dim, 0), a.moveaxis(dim, 0), strict=True):
|
| 57 |
+
state = ai * state + ui
|
| 58 |
+
output.append(state)
|
| 59 |
+
|
| 60 |
+
return torch.stack(output, dim=dim)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ----------------------------
|
| 64 |
+
# PyTorch parallel scan
|
| 65 |
+
# ----------------------------
|
| 66 |
+
class ParallelScan(torch.autograd.Function):
|
| 67 |
+
"""Parallel prefix scan with custom autograd.
|
| 68 |
+
|
| 69 |
+
Implements the associative scan operation:
|
| 70 |
+
state[t] = a[t] * state[t-1] + u[t]
|
| 71 |
+
|
| 72 |
+
In O(log N) parallel depth instead of O(N) sequential steps.
|
| 73 |
+
|
| 74 |
+
Note: This uses the naive sequential scan for backward pass to ensure
|
| 75 |
+
correctness. For production use with very long sequences, a parallel
|
| 76 |
+
backward scan could be implemented.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def forward(ctx, u, a):
|
| 81 |
+
"""Forward pass: parallel prefix scan.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
u: Input values (batch, N, length)
|
| 85 |
+
a: Decay factors (batch, N, length)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Scanned output (batch, N, length)
|
| 89 |
+
"""
|
| 90 |
+
length = u.shape[-1]
|
| 91 |
+
strides = [2**i for i in range((length - 1).bit_length())]
|
| 92 |
+
|
| 93 |
+
# Save original inputs for backward
|
| 94 |
+
u_original = u.clone()
|
| 95 |
+
a_original = a.clone()
|
| 96 |
+
|
| 97 |
+
# Clone to avoid in-place modifications
|
| 98 |
+
u = u.clone()
|
| 99 |
+
a = a.clone()
|
| 100 |
+
|
| 101 |
+
for stride in strides:
|
| 102 |
+
u[..., stride:] = u[..., stride:] + u[..., :-stride] * a[..., stride:]
|
| 103 |
+
a[..., stride:] = a[..., stride:] * a[..., :-stride]
|
| 104 |
+
|
| 105 |
+
ctx.save_for_backward(u_original, a_original, u)
|
| 106 |
+
return u
|
| 107 |
+
|
| 108 |
+
@staticmethod
|
| 109 |
+
def backward(ctx, grad_output):
|
| 110 |
+
"""Backward pass using sequential scan for correctness.
|
| 111 |
+
|
| 112 |
+
For production, this could be parallelized, but sequential is more
|
| 113 |
+
numerically stable and easier to verify.
|
| 114 |
+
"""
|
| 115 |
+
u_original, a_original, y = ctx.saved_tensors
|
| 116 |
+
|
| 117 |
+
# Compute gradients using reverse-mode automatic differentiation
|
| 118 |
+
# by recomputing forward pass while tracking dependencies
|
| 119 |
+
|
| 120 |
+
grad_u = torch.zeros_like(u_original)
|
| 121 |
+
grad_a = torch.zeros_like(a_original)
|
| 122 |
+
|
| 123 |
+
# Backward scan: process from right to left
|
| 124 |
+
length = u_original.shape[-1]
|
| 125 |
+
|
| 126 |
+
# Accumulator for gradient flowing backward through time
|
| 127 |
+
grad_state = torch.zeros_like(u_original[..., 0:1])
|
| 128 |
+
|
| 129 |
+
for t in range(length - 1, -1, -1):
|
| 130 |
+
# Gradient from output at time t
|
| 131 |
+
grad_y_t = grad_output[..., t : t + 1]
|
| 132 |
+
|
| 133 |
+
# Total gradient flowing into state[t]
|
| 134 |
+
grad_state_t = grad_y_t + grad_state
|
| 135 |
+
|
| 136 |
+
# Gradients w.r.t. inputs
|
| 137 |
+
grad_u[..., t : t + 1] = grad_state_t
|
| 138 |
+
if t > 0:
|
| 139 |
+
grad_a[..., t : t + 1] = grad_state_t * y[..., t - 1 : t]
|
| 140 |
+
|
| 141 |
+
# Propagate gradient to previous state
|
| 142 |
+
if t > 0:
|
| 143 |
+
grad_state = grad_state_t * a_original[..., t : t + 1]
|
| 144 |
+
|
| 145 |
+
return grad_u, grad_a
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def parallel_scan_pytorch(input, log_dt, A, state=None):
|
| 149 |
+
"""Pure PyTorch parallel scan for SSM.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
input: Input tensor (batch, length, N)
|
| 153 |
+
log_dt: Log timestep parameters (batch, length, N)
|
| 154 |
+
A: State decay parameters (N,)
|
| 155 |
+
state: Optional initial state (N,)
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Scanned output (batch, length, N)
|
| 159 |
+
"""
|
| 160 |
+
dt = F.softplus(log_dt)
|
| 161 |
+
log_dta = -dt * A.exp()[None, None, :]
|
| 162 |
+
a = log_dta.exp()
|
| 163 |
+
u = input * dt
|
| 164 |
+
|
| 165 |
+
# Fold initial state into first timestep
|
| 166 |
+
if state is not None:
|
| 167 |
+
u = u.clone()
|
| 168 |
+
u[:, 0, :] = u[:, 0, :] + state * a[:, 0, :]
|
| 169 |
+
|
| 170 |
+
# Transpose for scan: (batch, N, length)
|
| 171 |
+
u = u.transpose(-1, -2)
|
| 172 |
+
a = a.transpose(-1, -2)
|
| 173 |
+
|
| 174 |
+
# Apply parallel scan
|
| 175 |
+
output = ParallelScan.apply(u, a)
|
| 176 |
+
|
| 177 |
+
# Transpose back: (batch, length, N)
|
| 178 |
+
return output.transpose(-1, -2)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ----------------------------
|
| 182 |
+
# Triton kernels (guarded)
|
| 183 |
+
# ----------------------------
|
| 184 |
+
if _HAS_TRITON:
|
| 185 |
+
|
| 186 |
+
@triton.jit
|
| 187 |
+
def _roll_op(x1, y1, x2, y2):
|
| 188 |
+
return x2, tl.where(y2 == float('inf'), x1, y2)
|
| 189 |
+
|
| 190 |
+
@triton.jit
|
| 191 |
+
def roll(u, length: tl.constexpr, reverse: tl.constexpr = 0):
|
| 192 |
+
if reverse:
|
| 193 |
+
_, u_rol = tl.associative_scan((u, float('inf') + u), 0, _roll_op, reverse=1)
|
| 194 |
+
u_rol = tl.where(tl.arange(0, length) < length - 1, u_rol, 0)
|
| 195 |
+
else:
|
| 196 |
+
_, u_rol = tl.associative_scan((u, float('inf') + u), 0, _roll_op)
|
| 197 |
+
u_rol = tl.where(tl.arange(0, length) > 0, u_rol, 0)
|
| 198 |
+
return u_rol
|
| 199 |
+
|
| 200 |
+
@triton.jit
|
| 201 |
+
def _scan_op(a1, x1, a2, x2):
|
| 202 |
+
return a1 * a2, a2 * x1 + x2
|
| 203 |
+
|
| 204 |
+
@triton.jit
|
| 205 |
+
def softplus_tl(x):
|
| 206 |
+
return tl.where(x < 20, tl.log(1 + tl.exp(x)), x)
|
| 207 |
+
|
| 208 |
+
@triton.jit
|
| 209 |
+
def scan_heisen_fwd_triton(
|
| 210 |
+
u_ptr,
|
| 211 |
+
log_dt_ptr,
|
| 212 |
+
A_ptr,
|
| 213 |
+
y_ptr,
|
| 214 |
+
state_ptr,
|
| 215 |
+
L,
|
| 216 |
+
N: tl.constexpr,
|
| 217 |
+
MAX_L: tl.constexpr,
|
| 218 |
+
INIT_STATE: tl.constexpr = 1,
|
| 219 |
+
):
|
| 220 |
+
id_BATCH, id_N = tl.program_id(0), tl.program_id(1)
|
| 221 |
+
id_sample = id_BATCH * N + id_N
|
| 222 |
+
|
| 223 |
+
lrange = tl.arange(0, MAX_L)
|
| 224 |
+
offsets = id_sample * L + lrange
|
| 225 |
+
mask = lrange < L
|
| 226 |
+
|
| 227 |
+
A = tl.load(A_ptr + id_N)
|
| 228 |
+
if INIT_STATE:
|
| 229 |
+
state = tl.load(state_ptr + id_N)
|
| 230 |
+
|
| 231 |
+
u = tl.load(u_ptr + offsets, mask, 0).to(tl.float32)
|
| 232 |
+
log_dt = tl.load(log_dt_ptr + offsets, mask, 0).to(tl.float32)
|
| 233 |
+
|
| 234 |
+
dt = softplus_tl(log_dt)
|
| 235 |
+
log_dta = -1.0 * dt * tl.exp(A)
|
| 236 |
+
dta = tl.exp(log_dta)
|
| 237 |
+
|
| 238 |
+
if INIT_STATE:
|
| 239 |
+
u_dt = tl.where(lrange > 0, u * dt, u * dt + state * dta)
|
| 240 |
+
else:
|
| 241 |
+
u_dt = u * dt
|
| 242 |
+
|
| 243 |
+
_, y = tl.associative_scan((dta, u_dt), 0, _scan_op)
|
| 244 |
+
tl.store(y_ptr + offsets, y, mask)
|
| 245 |
+
|
| 246 |
+
@triton.jit
|
| 247 |
+
def scan_heisen_bwd_triton(
|
| 248 |
+
u_ptr,
|
| 249 |
+
grad_x_ptr,
|
| 250 |
+
log_dt_ptr,
|
| 251 |
+
A_ptr,
|
| 252 |
+
state_ptr,
|
| 253 |
+
grad_u_ptr,
|
| 254 |
+
grad_log_dt_ptr,
|
| 255 |
+
grad_A_ptr,
|
| 256 |
+
grad_x0_ptr,
|
| 257 |
+
L,
|
| 258 |
+
N: tl.constexpr,
|
| 259 |
+
MAX_L: tl.constexpr,
|
| 260 |
+
INIT_STATE: tl.constexpr = 1,
|
| 261 |
+
):
|
| 262 |
+
id_BATCH, id_N = tl.program_id(0), tl.program_id(1)
|
| 263 |
+
id_sample = id_BATCH * N + id_N
|
| 264 |
+
|
| 265 |
+
lrange = tl.arange(0, MAX_L)
|
| 266 |
+
offsets = id_sample * L + lrange
|
| 267 |
+
mask = lrange < L
|
| 268 |
+
|
| 269 |
+
A = tl.load(A_ptr + id_N)
|
| 270 |
+
exp_A = tl.exp(A)
|
| 271 |
+
if INIT_STATE:
|
| 272 |
+
state = tl.load(state_ptr + id_N)
|
| 273 |
+
|
| 274 |
+
u = tl.load(u_ptr + offsets, mask, 0).to(tl.float32)
|
| 275 |
+
log_dt = tl.load(log_dt_ptr + offsets, mask, 0).to(tl.float32)
|
| 276 |
+
|
| 277 |
+
dt = softplus_tl(log_dt)
|
| 278 |
+
log_dta = -1.0 * dt * exp_A
|
| 279 |
+
dta = tl.exp(log_dta)
|
| 280 |
+
|
| 281 |
+
if INIT_STATE:
|
| 282 |
+
u_dt = tl.where(lrange > 0, u * dt, u * dt + state * dta)
|
| 283 |
+
else:
|
| 284 |
+
u_dt = u * dt
|
| 285 |
+
|
| 286 |
+
_, x = tl.associative_scan((dta, u_dt), 0, _scan_op)
|
| 287 |
+
x_rol = roll(x, MAX_L)
|
| 288 |
+
|
| 289 |
+
grad_x = tl.load(grad_x_ptr + offsets, mask, 0).to(tl.float32)
|
| 290 |
+
|
| 291 |
+
if INIT_STATE:
|
| 292 |
+
log_dta_star = tl.cumsum(log_dta, 0)
|
| 293 |
+
dta_star = tl.exp(log_dta_star)
|
| 294 |
+
grad_x0 = tl.sum(grad_x * dta_star, 0)
|
| 295 |
+
tl.store(grad_x0_ptr + id_sample, grad_x0)
|
| 296 |
+
x_rol = tl.where(lrange > 0, x_rol, state)
|
| 297 |
+
|
| 298 |
+
dta_rol = roll(dta, MAX_L, reverse=1)
|
| 299 |
+
_, grad_x = tl.associative_scan((dta_rol, grad_x), 0, _scan_op, reverse=1)
|
| 300 |
+
|
| 301 |
+
grad_u = grad_x * dt
|
| 302 |
+
tl.store(grad_u_ptr + offsets, grad_u, mask)
|
| 303 |
+
|
| 304 |
+
grad_dta = grad_x * x_rol
|
| 305 |
+
grad_log_dta = tl.exp(log_dta) * grad_dta
|
| 306 |
+
|
| 307 |
+
grad_log_dt = (-1.0 * grad_log_dta * exp_A + u * grad_x) * tl.sigmoid(log_dt)
|
| 308 |
+
tl.store(grad_log_dt_ptr + offsets, grad_log_dt, mask)
|
| 309 |
+
|
| 310 |
+
grad_A = tl.sum(grad_log_dta * log_dta, 0)
|
| 311 |
+
tl.store(grad_A_ptr + id_sample, grad_A)
|
| 312 |
+
|
| 313 |
+
class FusedScanTriton(torch.autograd.Function):
|
| 314 |
+
@staticmethod
|
| 315 |
+
@torch.compiler.disable
|
| 316 |
+
@torch.amp.custom_fwd(device_type='cuda')
|
| 317 |
+
def forward(ctx, u, T1, T2, logdt_bias, A, B1, B2, state=None):
|
| 318 |
+
INIT_STATE = state is not None
|
| 319 |
+
|
| 320 |
+
uh = u.half()
|
| 321 |
+
T1, T2, logdt_bias, B1 = T1.half(), T2.half(), logdt_bias.half(), B1.half()
|
| 322 |
+
if B2 is not None:
|
| 323 |
+
B2 = B2.half()
|
| 324 |
+
|
| 325 |
+
if B2 is not None:
|
| 326 |
+
u1 = F.linear(uh, B1)
|
| 327 |
+
u2_tp = _tp(F.linear(u1, B2))
|
| 328 |
+
else:
|
| 329 |
+
u2_tp = _tp(F.linear(uh, B1))
|
| 330 |
+
|
| 331 |
+
logdt_1 = F.linear(uh, T1)
|
| 332 |
+
logdt_tp = _tp(F.linear(logdt_1, T2, bias=logdt_bias))
|
| 333 |
+
|
| 334 |
+
x_tp = torch.empty_like(u2_tp, dtype=torch.float32)
|
| 335 |
+
|
| 336 |
+
BATCH, N, L = u2_tp.shape
|
| 337 |
+
grid = (BATCH, N)
|
| 338 |
+
max_L = triton.next_power_of_2(L)
|
| 339 |
+
num_warps = max(max_L // 1024, 1)
|
| 340 |
+
|
| 341 |
+
scan_heisen_fwd_triton[grid](
|
| 342 |
+
u2_tp,
|
| 343 |
+
logdt_tp,
|
| 344 |
+
A,
|
| 345 |
+
x_tp,
|
| 346 |
+
state,
|
| 347 |
+
L,
|
| 348 |
+
N,
|
| 349 |
+
max_L,
|
| 350 |
+
INIT_STATE=INIT_STATE,
|
| 351 |
+
num_warps=num_warps,
|
| 352 |
+
num_stages=3,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if B2 is not None:
|
| 356 |
+
ctx.save_for_backward(uh, state, A, T1, T2, logdt_bias, B1, B2)
|
| 357 |
+
ctx.B2_flag = True
|
| 358 |
+
else:
|
| 359 |
+
ctx.save_for_backward(uh, u2_tp, state, A, T1, T2, logdt_bias, B1)
|
| 360 |
+
ctx.B2_flag = False
|
| 361 |
+
|
| 362 |
+
return x_tp.moveaxis(-1, -2) # (B, L, N)
|
| 363 |
+
|
| 364 |
+
@staticmethod
|
| 365 |
+
@torch.compiler.disable
|
| 366 |
+
@torch.amp.custom_bwd(device_type='cuda')
|
| 367 |
+
def backward(ctx, grad_x):
|
| 368 |
+
def back_dot(x, y):
|
| 369 |
+
return torch.tensordot(x, y, dims=([0], [0]))
|
| 370 |
+
|
| 371 |
+
if ctx.B2_flag:
|
| 372 |
+
uh, state, A, T1, T2, logdt_bias, B1, B2 = ctx.saved_tensors
|
| 373 |
+
else:
|
| 374 |
+
uh, u2_tp, state, A, T1, T2, logdt_bias, B1 = ctx.saved_tensors
|
| 375 |
+
B2 = None
|
| 376 |
+
|
| 377 |
+
INIT_STATE = state is not None
|
| 378 |
+
grad_x_tp = _tp(grad_x)
|
| 379 |
+
|
| 380 |
+
if B2 is not None:
|
| 381 |
+
u1 = F.linear(uh, B1)
|
| 382 |
+
u2_tp = _tp(F.linear(u1, B2))
|
| 383 |
+
|
| 384 |
+
logdt1 = F.linear(uh, T1)
|
| 385 |
+
logdt2 = F.linear(logdt1, T2, bias=logdt_bias)
|
| 386 |
+
logdt2_tp = _tp(logdt2)
|
| 387 |
+
|
| 388 |
+
BATCH, N, L = u2_tp.shape
|
| 389 |
+
grid = (BATCH, N)
|
| 390 |
+
|
| 391 |
+
grad_u2_tp = torch.empty_like(u2_tp, dtype=torch.float32)
|
| 392 |
+
grad_logdt_tp = torch.empty_like(u2_tp, dtype=torch.float32)
|
| 393 |
+
grad_A = torch.empty(grid, dtype=torch.float32, device=A.device)
|
| 394 |
+
grad_x0 = (
|
| 395 |
+
torch.empty((BATCH, N), dtype=torch.float32, device=A.device)
|
| 396 |
+
if INIT_STATE
|
| 397 |
+
else None
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
max_L = triton.next_power_of_2(L)
|
| 401 |
+
num_warps = max(max_L // 1024, 1)
|
| 402 |
+
|
| 403 |
+
scan_heisen_bwd_triton[grid](
|
| 404 |
+
u2_tp,
|
| 405 |
+
grad_x_tp,
|
| 406 |
+
logdt2_tp,
|
| 407 |
+
A,
|
| 408 |
+
state,
|
| 409 |
+
grad_u2_tp,
|
| 410 |
+
grad_logdt_tp,
|
| 411 |
+
grad_A,
|
| 412 |
+
grad_x0,
|
| 413 |
+
L,
|
| 414 |
+
N,
|
| 415 |
+
max_L,
|
| 416 |
+
INIT_STATE=INIT_STATE,
|
| 417 |
+
num_warps=num_warps,
|
| 418 |
+
num_stages=3,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
grad_A = grad_A.sum(0)
|
| 422 |
+
grad_init_state = grad_x0.sum(0) if INIT_STATE else None
|
| 423 |
+
|
| 424 |
+
uh2 = uh.view(-1, N)
|
| 425 |
+
grad_u2 = _tp(grad_u2_tp).view(-1, N)
|
| 426 |
+
|
| 427 |
+
if B2 is not None:
|
| 428 |
+
grad_u1 = grad_u2 @ B2
|
| 429 |
+
grad_B2 = back_dot(grad_u2, u1.view(BATCH * L, -1))
|
| 430 |
+
else:
|
| 431 |
+
grad_B2 = None
|
| 432 |
+
grad_u1 = grad_u2
|
| 433 |
+
|
| 434 |
+
grad_u = grad_u1 @ B1
|
| 435 |
+
grad_B1 = back_dot(grad_u1, uh2)
|
| 436 |
+
|
| 437 |
+
grad_logdt_bias = grad_logdt_tp.sum((0, 2))
|
| 438 |
+
|
| 439 |
+
grad_logdt = _tp(grad_logdt_tp).view(-1, N)
|
| 440 |
+
grad_logdt_1 = grad_logdt @ T2
|
| 441 |
+
grad_T2 = back_dot(grad_logdt, logdt1.view(BATCH * L, -1))
|
| 442 |
+
|
| 443 |
+
grad_u = grad_u + grad_logdt_1 @ T1
|
| 444 |
+
grad_T1 = back_dot(grad_logdt_1, uh2)
|
| 445 |
+
|
| 446 |
+
return (
|
| 447 |
+
grad_u.view(BATCH, L, N),
|
| 448 |
+
grad_T1,
|
| 449 |
+
grad_T2,
|
| 450 |
+
grad_logdt_bias,
|
| 451 |
+
grad_A,
|
| 452 |
+
grad_B1,
|
| 453 |
+
grad_B2,
|
| 454 |
+
grad_init_state,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# ----------------------------
|
| 459 |
+
# Unified API
|
| 460 |
+
# ----------------------------
|
| 461 |
+
def _can_use_triton(u: torch.Tensor) -> bool:
|
| 462 |
+
if not _HAS_TRITON:
|
| 463 |
+
return False
|
| 464 |
+
if not u.is_cuda:
|
| 465 |
+
return False
|
| 466 |
+
try:
|
| 467 |
+
major, _ = torch.cuda.get_device_capability(u.device)
|
| 468 |
+
if major < 7:
|
| 469 |
+
return False
|
| 470 |
+
except Exception:
|
| 471 |
+
pass
|
| 472 |
+
return True
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def fused_scan(u, log_dt_proj, in_proj, A, state=None):
|
| 476 |
+
"""Fused scan operation for gate mode SSM.
|
| 477 |
+
|
| 478 |
+
Uses Triton kernels on CUDA when available, falls back to PyTorch parallel scan.
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
u: Input tensor (batch, length, channels)
|
| 482 |
+
log_dt_proj: Sequential module for timestep projection
|
| 483 |
+
in_proj: Sequential or single module for input projection
|
| 484 |
+
A: State decay parameters (N,)
|
| 485 |
+
state: Optional initial state (N,)
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
Scanned output (batch, length, N)
|
| 489 |
+
"""
|
| 490 |
+
if _can_use_triton(u):
|
| 491 |
+
# Extract weights for Triton path
|
| 492 |
+
if isinstance(in_proj, nn.Linear):
|
| 493 |
+
B1, B2 = in_proj.weight, None
|
| 494 |
+
else:
|
| 495 |
+
B1, B2 = in_proj[0].weight, in_proj[1].weight
|
| 496 |
+
|
| 497 |
+
T1 = log_dt_proj[0].weight
|
| 498 |
+
T2 = log_dt_proj[1].weight
|
| 499 |
+
logdt_bias = log_dt_proj[1].bias
|
| 500 |
+
|
| 501 |
+
return FusedScanTriton.apply(
|
| 502 |
+
u.contiguous(),
|
| 503 |
+
T1.contiguous(),
|
| 504 |
+
T2.contiguous(),
|
| 505 |
+
logdt_bias.contiguous(),
|
| 506 |
+
A.contiguous(),
|
| 507 |
+
B1.contiguous(),
|
| 508 |
+
B2.contiguous() if B2 is not None else None,
|
| 509 |
+
state.contiguous() if state is not None else None,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# PyTorch fallback (CPU or CUDA without Triton)
|
| 513 |
+
u_proj = in_proj(u)
|
| 514 |
+
log_dt = log_dt_proj(u)
|
| 515 |
+
return parallel_scan_pytorch(u_proj, log_dt, A, state=state)
|
tenns_core/ssm.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
State Space Model (SSM) layers for sequence modeling.
|
| 3 |
+
|
| 4 |
+
This module provides SSMLayer, a flexible implementation of various SSM architectures
|
| 5 |
+
including S5, DWS, Neck, Full, and Gate modes. All implementations use pure PyTorch
|
| 6 |
+
with custom autograd functions for efficient training.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import einops
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
from torch.nn.parameter import Parameter
|
| 17 |
+
|
| 18 |
+
from .activations import get_activations
|
| 19 |
+
from .fft_ops import padded_fft_conv_opt
|
| 20 |
+
from .scan_ops import fused_scan
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Utility functions
|
| 24 |
+
def c2r(inputs):
|
| 25 |
+
return torch.view_as_real(inputs)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def r2c(inputs):
|
| 29 |
+
return torch.view_as_complex(inputs)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def inv_softplus(x):
|
| 33 |
+
return x + np.log(-np.expm1(-x))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Kernelizer(nn.Module):
|
| 37 |
+
"""Core module for SSM operations using FFT convolutions and parallel scans.
|
| 38 |
+
|
| 39 |
+
This is the base class that handles the actual SSM computation.
|
| 40 |
+
SSMLayer extends this with parameter initialization and training utilities.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, mode='s5', transposed=False, complex_proj=False, **kwargs):
|
| 44 |
+
"""Initialize Kernelizer.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
mode: SSM mode ('s5', 'dws', 'neck', 'full', 'gate')
|
| 48 |
+
transposed: Whether to use transposed operations (time-last vs channel-last)
|
| 49 |
+
complex_proj: Whether to use complex projections
|
| 50 |
+
"""
|
| 51 |
+
super().__init__()
|
| 52 |
+
|
| 53 |
+
self.mode = mode
|
| 54 |
+
self.transposed = transposed
|
| 55 |
+
self.complex_proj = complex_proj
|
| 56 |
+
|
| 57 |
+
@torch.compiler.disable
|
| 58 |
+
def discretize(self, A: torch.Tensor, weight: torch.Tensor, log_dt: torch.Tensor):
|
| 59 |
+
"""Discretize continuous-time SSM using zero-order-hold method.
|
| 60 |
+
|
| 61 |
+
Converts continuous-time parameters (A, B, dt) to discrete-time (A_bar, B_bar)
|
| 62 |
+
using the zero-order-hold discretization:
|
| 63 |
+
A_bar = exp(A * dt)
|
| 64 |
+
B_bar = B * dt
|
| 65 |
+
|
| 66 |
+
NOTE: Assumes diagonal state matrix A.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
A: State matrix diagonal [real, imag] (shape varies by mode)
|
| 70 |
+
weight: Input weight matrix B or output weight E (shape varies by mode)
|
| 71 |
+
log_dt: Log of timestep parameters
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Tuple of (dtA_real, dtA_imag, weight_hat) discretized parameters
|
| 75 |
+
"""
|
| 76 |
+
with torch.autocast('cuda', enabled=False):
|
| 77 |
+
A_real, A_imag = -F.softplus(A[..., 0]), A[..., 1]
|
| 78 |
+
dt = log_dt.exp()
|
| 79 |
+
|
| 80 |
+
match self.mode:
|
| 81 |
+
case 'neck':
|
| 82 |
+
dt = dt.unsqueeze(-1) # (R, :) -> (R, :, 1)
|
| 83 |
+
weight_hat = weight * dt
|
| 84 |
+
case 'full':
|
| 85 |
+
dt = dt.unsqueeze(-2) # (D, :) -> (D, 1, :)
|
| 86 |
+
weight_hat = weight * dt
|
| 87 |
+
case 'dws':
|
| 88 |
+
weight_hat = weight * dt # (C, N)
|
| 89 |
+
case _: # s5, gate
|
| 90 |
+
weight_hat = weight * dt.unsqueeze(-1) # (R*N, :) -> (R*N, C)
|
| 91 |
+
|
| 92 |
+
dtA_real, dtA_imag = dt * A_real, dt * A_imag
|
| 93 |
+
|
| 94 |
+
return dtA_real, dtA_imag, weight_hat
|
| 95 |
+
|
| 96 |
+
def forward(
|
| 97 |
+
self,
|
| 98 |
+
input: torch.Tensor,
|
| 99 |
+
A: torch.Tensor,
|
| 100 |
+
B: torch.Tensor,
|
| 101 |
+
C: torch.Tensor,
|
| 102 |
+
log_dt: torch.Tensor,
|
| 103 |
+
E: torch.Tensor,
|
| 104 |
+
state=None,
|
| 105 |
+
):
|
| 106 |
+
"""Forward pass through SSM layer.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
input: Input tensor (batch, channels, length)
|
| 110 |
+
A: State matrix diagonal parameters
|
| 111 |
+
B: Input projection matrix (for s5/neck/gate modes)
|
| 112 |
+
C: Output projection matrix (for s5/neck modes) or module (for gate)
|
| 113 |
+
log_dt: Log timestep parameters
|
| 114 |
+
E: State mixing matrix (for dws/neck/full modes)
|
| 115 |
+
state: Optional initial state (for gate mode prefix tuning)
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Output tensor (batch, out_channels, length)
|
| 119 |
+
"""
|
| 120 |
+
match self.mode:
|
| 121 |
+
case 's5' | 'neck':
|
| 122 |
+
dtA_real, dtA_imag, B_hat = self.discretize(A, B, log_dt)
|
| 123 |
+
return padded_fft_conv_opt(input, dtA_real, dtA_imag, B_hat, C, E)
|
| 124 |
+
|
| 125 |
+
case 'dws' | 'full':
|
| 126 |
+
dtA_real, dtA_imag, E_hat = self.discretize(A, E, log_dt)
|
| 127 |
+
return padded_fft_conv_opt(input, dtA_real, dtA_imag, None, None, E_hat)
|
| 128 |
+
|
| 129 |
+
case 'gate':
|
| 130 |
+
# Gate mode can work with both formats
|
| 131 |
+
# Transpose if needed: (B, C, T) -> (B, T, C)
|
| 132 |
+
if not self.transposed:
|
| 133 |
+
input = input.transpose(1, 2)
|
| 134 |
+
|
| 135 |
+
output = C(fused_scan(input, log_dt, B, A, state=state))
|
| 136 |
+
|
| 137 |
+
# Transpose back if needed: (B, T, D) -> (B, D, T)
|
| 138 |
+
if not self.transposed:
|
| 139 |
+
output = output.transpose(1, 2)
|
| 140 |
+
|
| 141 |
+
return output
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class SSMLayer(Kernelizer):
|
| 145 |
+
"""State Space Model layer with multiple architecture variants.
|
| 146 |
+
|
| 147 |
+
Extends Kernelizer with parameter initialization, activation layers,
|
| 148 |
+
and training utilities. Supports multiple SSM modes:
|
| 149 |
+
|
| 150 |
+
- **s5**: Standard S5 architecture with shared state space
|
| 151 |
+
- **dws**: Depthwise separable variant (per-channel state spaces)
|
| 152 |
+
- **neck**: Bottleneck architecture with low-rank state mixing
|
| 153 |
+
- **full**: Full parameterization (per-output-channel state spaces)
|
| 154 |
+
- **gate**: Input-dependent gating (Mamba-style selective SSM)
|
| 155 |
+
|
| 156 |
+
Mode Comparison:
|
| 157 |
+
----------------
|
| 158 |
+
| Mode | Parameters | Best For | Speed |
|
| 159 |
+
|-------|------------|-----------------------------|---------|
|
| 160 |
+
| s5 | Medium | General sequence modeling | Fast |
|
| 161 |
+
| dws | Low | Efficient local processing | Fastest |
|
| 162 |
+
| neck | Low | Long sequences, low memory | Fast |
|
| 163 |
+
| full | High | Rich feature interactions | Medium |
|
| 164 |
+
| gate | High | Input-adaptive processing | Slow |
|
| 165 |
+
|
| 166 |
+
Usage Example:
|
| 167 |
+
--------------
|
| 168 |
+
>>> # S5 mode for sequence classification
|
| 169 |
+
>>> layer = SSMLayer(
|
| 170 |
+
... num_coeffs=64, # State space dimension
|
| 171 |
+
... in_channels=128, # Input features
|
| 172 |
+
... out_channels=256, # Output features
|
| 173 |
+
... mode='s5',
|
| 174 |
+
... repeat=1, # Number of parallel SSMs
|
| 175 |
+
... norm='layer',
|
| 176 |
+
... postact='gelu'
|
| 177 |
+
... )
|
| 178 |
+
>>> input = torch.randn(4, 128, 512) # (batch, channels, length)
|
| 179 |
+
>>> output = layer(input) # (4, 256, 512)
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
num_coeffs: int,
|
| 185 |
+
in_channels: int,
|
| 186 |
+
out_channels: int,
|
| 187 |
+
repeat=None,
|
| 188 |
+
norm='batch',
|
| 189 |
+
postact='relu',
|
| 190 |
+
dropout=None,
|
| 191 |
+
dropout_dim=1,
|
| 192 |
+
use_activations=False,
|
| 193 |
+
**kwargs,
|
| 194 |
+
):
|
| 195 |
+
"""Initialize SSM layer.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
num_coeffs: Dimension of state space (N in SSM notation)
|
| 199 |
+
in_channels: Number of input channels
|
| 200 |
+
out_channels: Number of output channels
|
| 201 |
+
repeat: Number of parallel SSM blocks (default: 1)
|
| 202 |
+
norm: Normalization type ('batch', 'layer', 'rms', None)
|
| 203 |
+
postact: Activation function ('relu', 'gelu', 'silu', None)
|
| 204 |
+
dropout: Dropout probability (None for no dropout)
|
| 205 |
+
dropout_dim: Dimension for dropout (0, 1, 2, or 3)
|
| 206 |
+
use_activations: Whether to apply activations to mixer output
|
| 207 |
+
**kwargs: Additional arguments (mode, transposed, complex_proj, etc.)
|
| 208 |
+
"""
|
| 209 |
+
_VALID_MODES = {'s5', 'dws', 'neck', 'full', 'gate'}
|
| 210 |
+
_VALID_NORMS = {'batch', 'layer', 'layer-feature', 'rms', None}
|
| 211 |
+
_VALID_POSTACTS = {'relu', 'relu6', 'lelu', 'sigmoid', 'tanh', 'gelu', 'glu', 'silu', None}
|
| 212 |
+
_VALID_DROPOUT_DIMS = {0, 1, 2, 3}
|
| 213 |
+
|
| 214 |
+
mode = kwargs.get('mode', 's5')
|
| 215 |
+
if mode not in _VALID_MODES:
|
| 216 |
+
raise ValueError(f"Invalid mode '{mode}'. Must be one of {sorted(_VALID_MODES)}.")
|
| 217 |
+
if norm not in _VALID_NORMS:
|
| 218 |
+
raise ValueError(
|
| 219 |
+
f"Invalid norm '{norm}'. Must be one of {sorted(_VALID_NORMS, key=str)}."
|
| 220 |
+
)
|
| 221 |
+
if postact not in _VALID_POSTACTS:
|
| 222 |
+
raise ValueError(
|
| 223 |
+
f"Invalid postact '{postact}'. Must be one of {sorted(_VALID_POSTACTS, key=str)}."
|
| 224 |
+
)
|
| 225 |
+
if dropout_dim not in _VALID_DROPOUT_DIMS:
|
| 226 |
+
raise ValueError(
|
| 227 |
+
f'Invalid dropout_dim {dropout_dim}. Must be one of {sorted(_VALID_DROPOUT_DIMS)}.'
|
| 228 |
+
)
|
| 229 |
+
if num_coeffs < 1:
|
| 230 |
+
raise ValueError(f'num_coeffs must be >= 1, got {num_coeffs}.')
|
| 231 |
+
if in_channels < 1:
|
| 232 |
+
raise ValueError(f'in_channels must be >= 1, got {in_channels}.')
|
| 233 |
+
if out_channels < 1:
|
| 234 |
+
raise ValueError(f'out_channels must be >= 1, got {out_channels}.')
|
| 235 |
+
|
| 236 |
+
super().__init__(**kwargs)
|
| 237 |
+
self.in_channels = in_channels
|
| 238 |
+
self.out_channels = out_channels
|
| 239 |
+
|
| 240 |
+
self.repeat = 1 if repeat is None else repeat
|
| 241 |
+
|
| 242 |
+
self.norm = norm
|
| 243 |
+
self.postact = postact
|
| 244 |
+
self.dropout = dropout
|
| 245 |
+
self.dropout_dim = dropout_dim
|
| 246 |
+
|
| 247 |
+
self.bias = None
|
| 248 |
+
self.E = None
|
| 249 |
+
|
| 250 |
+
# Initialize state matrix A
|
| 251 |
+
if self.mode == 'gate':
|
| 252 |
+
# For gate mode: log-spaced initialization
|
| 253 |
+
A = np.arange(1, num_coeffs + 1)
|
| 254 |
+
A = np.log(A)
|
| 255 |
+
else:
|
| 256 |
+
# For FFT modes: complex eigenvalues
|
| 257 |
+
# Real part: decay rate, Imaginary part: frequency
|
| 258 |
+
A = np.stack([0.5 * np.ones(num_coeffs), math.pi * np.arange(num_coeffs)], -1)
|
| 259 |
+
A[..., 0] = inv_softplus(A[..., 0])
|
| 260 |
+
|
| 261 |
+
# Initialize timestep parameters
|
| 262 |
+
if self.mode in ['dws']:
|
| 263 |
+
dt = np.geomspace(1e-3, 1e-1, in_channels)
|
| 264 |
+
elif self.mode == 'full':
|
| 265 |
+
dt = np.geomspace(1e-3, 1e-1, out_channels)
|
| 266 |
+
else:
|
| 267 |
+
dt = np.geomspace(1e-3, 1e-1, self.repeat)
|
| 268 |
+
|
| 269 |
+
if self.mode == 'gate':
|
| 270 |
+
log_dt = inv_softplus(dt)
|
| 271 |
+
else:
|
| 272 |
+
log_dt = np.log(dt)
|
| 273 |
+
|
| 274 |
+
# Helper functions for parameter creation
|
| 275 |
+
def to_parameter(mat, is_complex=False, requires_grad=True):
|
| 276 |
+
if mat is None:
|
| 277 |
+
return None
|
| 278 |
+
tensor = torch.tensor(mat, dtype=torch.float)
|
| 279 |
+
if is_complex:
|
| 280 |
+
tensor = tensor.cfloat()
|
| 281 |
+
return Parameter(tensor, requires_grad=requires_grad)
|
| 282 |
+
|
| 283 |
+
def ones(shape, fan_in):
|
| 284 |
+
mat = np.ones(shape) / math.sqrt(fan_in)
|
| 285 |
+
return to_parameter(mat, is_complex=self.complex_proj)
|
| 286 |
+
|
| 287 |
+
def normal(shape, fan_in):
|
| 288 |
+
mat = np.random.randn(*shape) * math.sqrt(2 / fan_in)
|
| 289 |
+
return to_parameter(mat, is_complex=self.complex_proj)
|
| 290 |
+
|
| 291 |
+
tot_coeffs = self.repeat * num_coeffs
|
| 292 |
+
|
| 293 |
+
# Mode-specific parameter initialization
|
| 294 |
+
match self.mode:
|
| 295 |
+
case 'dws':
|
| 296 |
+
log_dt = einops.repeat(log_dt, 'c -> c n', n=num_coeffs)
|
| 297 |
+
A = einops.repeat(A, 'n i -> c n i', c=in_channels)
|
| 298 |
+
self.B = None
|
| 299 |
+
self.C = None
|
| 300 |
+
self.E = ones((in_channels, num_coeffs), num_coeffs)
|
| 301 |
+
|
| 302 |
+
case 's5':
|
| 303 |
+
log_dt = einops.repeat(log_dt, 'j -> (j n)', n=num_coeffs)
|
| 304 |
+
A = einops.repeat(A, 'n i -> (j n) i', j=self.repeat)
|
| 305 |
+
self.B = ones((tot_coeffs, in_channels), in_channels)
|
| 306 |
+
self.C = normal((out_channels, tot_coeffs), tot_coeffs)
|
| 307 |
+
self.E = None
|
| 308 |
+
|
| 309 |
+
case 'neck':
|
| 310 |
+
# Neck mode uses fewer repeated log_dt parameters
|
| 311 |
+
A = einops.repeat(A, 'n i -> r n i', r=self.repeat)
|
| 312 |
+
self.B = ones((self.repeat, in_channels), in_channels)
|
| 313 |
+
self.C = normal((out_channels, self.repeat), tot_coeffs)
|
| 314 |
+
self.E = normal((self.repeat, num_coeffs), 1)
|
| 315 |
+
|
| 316 |
+
case 'full':
|
| 317 |
+
log_dt = einops.repeat(log_dt, 'd -> d n', n=num_coeffs)
|
| 318 |
+
A = einops.repeat(A, 'n i -> d c n i', c=in_channels, d=out_channels)
|
| 319 |
+
self.B = None
|
| 320 |
+
self.C = None
|
| 321 |
+
self.E = ones((out_channels, in_channels, num_coeffs), in_channels)
|
| 322 |
+
|
| 323 |
+
case 'gate':
|
| 324 |
+
log_dt = einops.repeat(log_dt, 'j -> (j n)', n=num_coeffs)
|
| 325 |
+
|
| 326 |
+
# Timestep projection: learns input-dependent timesteps
|
| 327 |
+
self.log_dt = nn.Sequential(
|
| 328 |
+
nn.Linear(in_channels, self.repeat, bias=False),
|
| 329 |
+
nn.Linear(self.repeat, tot_coeffs, bias=True),
|
| 330 |
+
)
|
| 331 |
+
nn.init.zeros_(self.log_dt[-1].weight)
|
| 332 |
+
self.log_dt[-1].bias = to_parameter(log_dt)
|
| 333 |
+
|
| 334 |
+
# State decay parameters
|
| 335 |
+
A = einops.repeat(A, 'n -> (j n)', j=self.repeat)
|
| 336 |
+
|
| 337 |
+
# Input and output projections
|
| 338 |
+
self.B = nn.Sequential(
|
| 339 |
+
nn.Linear(in_channels, self.repeat, bias=False),
|
| 340 |
+
nn.Linear(self.repeat, tot_coeffs, bias=False),
|
| 341 |
+
)
|
| 342 |
+
self.C = nn.Linear(tot_coeffs, out_channels, bias=False)
|
| 343 |
+
|
| 344 |
+
# Register parameters
|
| 345 |
+
self.A = to_parameter(A)
|
| 346 |
+
|
| 347 |
+
if self.mode not in ['gate']:
|
| 348 |
+
self.log_dt = to_parameter(log_dt)
|
| 349 |
+
|
| 350 |
+
# Mark certain parameters as "sensitive" for optimizer
|
| 351 |
+
# (suggests using smaller learning rates for these)
|
| 352 |
+
match self.mode:
|
| 353 |
+
case 'dws' | 'full' | 'neck':
|
| 354 |
+
self._register_sensitives(self.log_dt, self.A)
|
| 355 |
+
case 'gate':
|
| 356 |
+
self._register_sensitives(self.A)
|
| 357 |
+
|
| 358 |
+
# Mixer layer: final projection and activations
|
| 359 |
+
if self.mode in ['dws']:
|
| 360 |
+
# DWS mode has explicit channel mixing
|
| 361 |
+
self.mixer = nn.Sequential(
|
| 362 |
+
self._make_activation_block(in_channels),
|
| 363 |
+
nn.Conv1d(in_channels, out_channels, 1, bias=False),
|
| 364 |
+
self._make_activation_block(out_channels) if use_activations else nn.Identity(),
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
self.mixer = (
|
| 368 |
+
self._make_activation_block(out_channels) if use_activations else nn.Identity()
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def _register_sensitives(*args):
|
| 373 |
+
"""Mark parameters as sensitive (for optimizer to use smaller learning rates).
|
| 374 |
+
|
| 375 |
+
Args:
|
| 376 |
+
*args: Parameters or modules to mark as sensitive
|
| 377 |
+
"""
|
| 378 |
+
for arg in args:
|
| 379 |
+
if isinstance(arg, nn.Module):
|
| 380 |
+
for param in arg.parameters():
|
| 381 |
+
param.sensitive = True
|
| 382 |
+
continue
|
| 383 |
+
arg.sensitive = True
|
| 384 |
+
|
| 385 |
+
def get_param_groups(self, lr=1e-3, sensitive_lr_factor=0.1):
|
| 386 |
+
"""Get optimizer parameter groups with separate learning rates.
|
| 387 |
+
|
| 388 |
+
Sensitive parameters (A matrix, log_dt) benefit from smaller learning
|
| 389 |
+
rates. This method returns ready-made param groups for the optimizer.
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
lr: Base learning rate for regular parameters
|
| 393 |
+
sensitive_lr_factor: Multiplier for sensitive parameter learning rate
|
| 394 |
+
(default: 0.1, i.e. 10x smaller than base lr)
|
| 395 |
+
|
| 396 |
+
Returns:
|
| 397 |
+
List of dicts suitable for torch.optim optimizers
|
| 398 |
+
|
| 399 |
+
Example:
|
| 400 |
+
>>> layer = SSMLayer(64, 128, 256, mode='s5')
|
| 401 |
+
>>> optimizer = torch.optim.AdamW(layer.get_param_groups(lr=1e-3))
|
| 402 |
+
"""
|
| 403 |
+
regular, sensitive = [], []
|
| 404 |
+
for param in self.parameters():
|
| 405 |
+
if getattr(param, 'sensitive', False):
|
| 406 |
+
sensitive.append(param)
|
| 407 |
+
else:
|
| 408 |
+
regular.append(param)
|
| 409 |
+
groups = [{'params': regular, 'lr': lr}]
|
| 410 |
+
if sensitive:
|
| 411 |
+
groups.append({'params': sensitive, 'lr': lr * sensitive_lr_factor})
|
| 412 |
+
return groups
|
| 413 |
+
|
| 414 |
+
def _make_activation_block(self, num_features):
|
| 415 |
+
"""Create normalization + activation + dropout block.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
num_features: Number of features for norm/dropout
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
Sequential module with norm, activation, dropout
|
| 422 |
+
"""
|
| 423 |
+
return get_activations(
|
| 424 |
+
1, num_features, self.norm, self.postact, self.dropout, self.dropout_dim
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
def forward(self, input):
|
| 428 |
+
"""Forward pass through SSM layer.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
input: Input tensor of shape (batch, in_channels, length)
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
Output tensor of shape (batch, out_channels, length)
|
| 435 |
+
"""
|
| 436 |
+
output = super().forward(input, self.A, self.B, self.C, self.log_dt, E=self.E)
|
| 437 |
+
|
| 438 |
+
if self.bias is not None:
|
| 439 |
+
output = output + self.bias
|
| 440 |
+
|
| 441 |
+
return self.mixer(output)
|
| 442 |
+
|
| 443 |
+
def to_inference(self):
|
| 444 |
+
"""Convert to streaming inference mode.
|
| 445 |
+
|
| 446 |
+
Returns SSMLayerInference instance for low-latency streaming processing.
|
| 447 |
+
The inference layer maintains state across chunks for applications.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
SSMLayerInference: Inference layer with copied weights
|
| 451 |
+
|
| 452 |
+
Example:
|
| 453 |
+
>>> # After training
|
| 454 |
+
>>> train_layer = SSMLayer(64, 128, 256, mode='s5')
|
| 455 |
+
>>> # ... training ...
|
| 456 |
+
>>>
|
| 457 |
+
>>> # Convert for streaming
|
| 458 |
+
>>> infer_layer = train_layer.to_inference()
|
| 459 |
+
>>>
|
| 460 |
+
>>> # Process audio stream
|
| 461 |
+
>>> for chunk in audio_stream:
|
| 462 |
+
>>> output = infer_layer(chunk)
|
| 463 |
+
>>>
|
| 464 |
+
>>> # Reset between utterances
|
| 465 |
+
>>> infer_layer.reset_state()
|
| 466 |
+
|
| 467 |
+
Note:
|
| 468 |
+
The inference layer uses sequential scan which is slower than
|
| 469 |
+
FFT for full sequences but has lower latency for streaming.
|
| 470 |
+
"""
|
| 471 |
+
from .inference import SSMLayerInference
|
| 472 |
+
|
| 473 |
+
return SSMLayerInference.from_training(self)
|
| 474 |
+
|
| 475 |
+
def __repr__(self):
|
| 476 |
+
"""String representation showing parameters."""
|
| 477 |
+
param_info = []
|
| 478 |
+
for name, param in self.named_parameters():
|
| 479 |
+
if param.requires_grad:
|
| 480 |
+
param_info.append(f'{name}: {list(param.shape)}')
|
| 481 |
+
return f'{self.__class__.__name__}(\n ' + '\n '.join(param_info) + '\n)'
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": null,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": "<s>",
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "</s>",
|
| 7 |
+
"extra_special_tokens": [],
|
| 8 |
+
"is_local": false,
|
| 9 |
+
"legacy": false,
|
| 10 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 11 |
+
"pad_token": "</s>",
|
| 12 |
+
"sp_model_kwargs": {},
|
| 13 |
+
"spaces_between_special_tokens": false,
|
| 14 |
+
"tokenizer_class": "TokenizersBackend",
|
| 15 |
+
"unk_token": "<unk>",
|
| 16 |
+
"use_default_system_prompt": false
|
| 17 |
+
}
|