# Loops, Recurrence, and Sequential Models in nFlow
## The Core Principle: DAGs Only, No Cycles
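nFlow's graph IR is a **directed acyclic graph (DAG)**. This is not a limitation: it is
the standard model for deep learning computation graphs. PyTorch, TensorFlow 2, ONNX, and
JAX all keep the graph itself acyclic. They express iteration either by unrolling or through a
structured loop op whose body is a subgraph (e.g. JAX's `lax.scan`, ONNX's `Loop`). A raw cycle
has no well-defined execution order, and it can't be differentiated or compiled until it is
unrolled or lowered into such an op.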
**What looks like a "loop" in RNNs is always one of two things:**
1. **Explicit unrolling**: you wire T time steps manually (classical RNNs)
2. **Scan/fold operation**: a single node that applies a cell along a sequence, threading the state from step to step (modern approach)
---
## Pattern 1: LSTM Cell (Single Step)
`nn.lstm_cell` computes one time step of an LSTM:
- **Inputs**: `x` [B, input_size], `h` [B, hidden_size], `c` [B, hidden_size]
- **Outputs**: `h_new` [B, hidden_size], `c_new` [B, hidden_size]
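For reference, the computation inside one `nn.lstm_cell` node follows the standard LSTM equations. The sketch below is pure Python for a single example (the batch dimension `B` is dropped). The per-gate weight dictionaries `W`, `U`, `b` and the helper names are hypothetical and do not reflect nFlow's actual parameter layout.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major, one row per hidden unit


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product m @ v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_cell(
    x: Vector,
    h: Vector,
    c: Vector,
    W: Dict[str, Matrix],  # hypothetical: input weights per gate, [hidden, input]
    U: Dict[str, Matrix],  # hypothetical: recurrent weights per gate, [hidden, hidden]
    b: Dict[str, Vector],  # hypothetical: bias per gate, [hidden]
) -> Tuple[Vector, Vector]:
    """One LSTM time step: (x, h, c) -> (h_new, c_new)."""

    def pre(gate: str) -> Vector:
        # Pre-activation W_g x + U_g h + b_g for gate g.
        return [a + r + bias for a, r, bias in zip(matvec(W[gate], x), matvec(U[gate], h), b[gate])]

    i = [sigmoid(z) for z in pre("i")]    # input gate
    f = [sigmoid(z) for z in pre("f")]    # forget gate
    g = [math.tanh(z) for z in pre("g")]  # candidate cell state
    o = [sigmoid(z) for z in pre("o")]    # output gate
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new
```

With all weights and biases at zero, every sigmoid gate is 0.5 and the candidate is tanh(0) = 0, so `c_new` is half of `c`; this is a handy sanity check when testing a real cell.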
To model a **3-step LSTM** in nFlow:
```
input[0] → lstm_cell → h1, c1
           ↑ h0, c0
input[1] → lstm_cell → h2, c2
           ↑ h1, c1
input[2] → lstm_cell → h3, c3
           ↑ h2, c2
```
You wire 3 `nn.lstm_cell` nodes in sequence; the "loop" is explicit at the graph level.
This mirrors how `torch.nn.LSTMCell` is used: the module computes a single step and the caller
drives the time loop, so the recorded graph is an unrolled DAG.
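The wiring above can be generated mechanically. The sketch below builds the node list for a T-step unrolled LSTM and checks that every edge points from an earlier node to a later one, which is what makes the result a DAG. The `Node` record and the port-reference strings (`"x[0]"`, `"lstm_cell_0.h_new"`) are a hypothetical schema for illustration, not the `.nfl` format.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    uid: str  # hypothetical node id, e.g. "lstm_cell_0"
    op: str   # primitive name, e.g. "nn.lstm_cell"
    inputs: Dict[str, str] = field(default_factory=dict)  # port -> "producer.port" or graph input


def unroll_lstm(steps: int) -> List[Node]:
    """Wire `steps` nn.lstm_cell nodes in sequence, threading (h, c)."""
    nodes: List[Node] = []
    h_src, c_src = "h0", "c0"  # initial state comes from graph inputs
    for t in range(steps):
        uid = f"lstm_cell_{t}"
        nodes.append(Node(uid, "nn.lstm_cell", {"x": f"x[{t}]", "h": h_src, "c": c_src}))
        # The next step reads this step's outputs: a forward edge, never a back-edge.
        h_src, c_src = f"{uid}.h_new", f"{uid}.c_new"
    return nodes


def is_forward_only(nodes: List[Node]) -> bool:
    """True if every input is a graph input or an output of an earlier node in the list."""
    seen = set()
    for node in nodes:
        for src in node.inputs.values():
            if "." in src and src.split(".")[0] not in seen:
                return False  # reads from a node that has not run yet
        seen.add(node.uid)
    return True
```

Calling `unroll_lstm(3)` yields exactly the three-cell chain in the diagram; the node count, and therefore graph size, grows linearly with T.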
**When to use**: Fixed-length sequences known at model-design time. Clean, simple,
easy to debug. Good for teaching, small models, and architectures with specific step counts.
Because the graph gains one cell node per step, this pattern does not scale to long sequences.
---
## Pattern 2: Scan / Fold Over Sequence (Recommended)
`nn.scan` applies a cell step by step along a variable-length sequence, threading the carry (ports shown for an LSTM cell):
- **Inputs**: `xs` [B, T, input_size], `h0` [B, hidden_size], `c0` [B, hidden_size]
- **Outputs**: `hs` [B, T, hidden_size], `h_T` [B, hidden_size], `c_T` [B, hidden_size]
- **Attrs**: `cell: str` β€” uid of the cell composite to scan over (e.g. `"user.lstm_cell"`)
This has the same semantics as JAX's `lax.scan` and Haskell's `mapAccumL` (a close cousin of
`scanl`), and it is the natural way to express the recurrent form of RWKV, Mamba, and other linear RNNs.
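To pin down the semantics, here is a minimal pure-Python sketch using the `jax.lax.scan` calling convention: the cell maps `(carry, x_t)` to `(new_carry, y_t)`, and the scan returns the final carry plus the per-step outputs. For an LSTM cell the carry is `(h, c)` and `y_t` is `h_t`, which gives the `hs`, `h_T`, `c_T` ports listed above. The function `scan` and the `toy_cell` stand-in are illustrative only; there is no batching or tensor math.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")


def scan(
    cell: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: Sequence[X],
) -> Tuple[Carry, List[Y]]:
    """Apply `cell` along `xs`, threading the carry (same contract as jax.lax.scan)."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = cell(carry, x)  # each step consumes the previous step's carry
        ys.append(y)
    return carry, ys


def toy_cell(carry: Tuple[float, float], x: float) -> Tuple[Tuple[float, float], float]:
    """Stand-in for an LSTM cell: carry = (h, c); emits h as the per-step output."""
    h, c = carry
    c_new = c + x        # running sum plays the role of the cell state
    h_new = 0.5 * c_new  # "hidden state" derived from the cell state
    return (h_new, c_new), h_new
```

For example, `scan(toy_cell, (0.0, 0.0), [1.0, 2.0, 3.0])` returns `((3.0, 6.0), [0.5, 1.5, 3.0])`, i.e. `(h_T, c_T)` and `hs`.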
```
xs [B,T,D] → nn.scan(cell="user.lstm_cell") → hs [B,T,H]
             ↑ h0, c0                         h_T, c_T
```
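One node represents the entire recurrence. This mirrors `torch.nn.LSTM` (a single module, backed
by cuDNN on GPU) and Mamba's selective-scan CUDA kernel, which both run the whole recurrence as one op.
The scan node exports to:
- **PyTorch**: `nn.LSTM(...)` / `nn.GRU(...)` for built-in cells, or a manual `for` loop over the time axis for custom cells
- **Keras**: `keras.layers.RNN(cell, return_sequences=True, return_state=True)` (`return_state` also returns the final states)
- **ONNX**: the `LSTM` / `GRU` operators for built-in cells, or the `Scan` / `Loop` control-flow operators with the cell as the loop body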
**When to use**: Variable-length sequences, production models, Mamba/RWKV/S4/linear
attention. The scan node abstracts away the loop entirely.
---
## Pattern 3: Custom SSM (Mamba-style)
For Mamba, S4, RWKV, and similar **linear recurrences** (`h_t = A·h_{t-1} + B·x_t`):
Build a `UserComposite` for the SSM kernel:
1. **Create** a `user.mamba_ssm` composite in the Custom Nodes panel
2. **Wire** the inner graph: `dt, A, B, C, x → selective_scan → y`
3. The selective scan is itself a `nn.scan` over a custom cell
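For a single channel with a scalar state, the recurrence fits in a few lines. The sketch below discretizes a continuous `A` with a per-step `dt` (zero-order hold, `A_bar = exp(dt·A)`) and uses the simplified `B_bar = dt·B` that Mamba's reference implementation applies; letting `dt`, `B`, and `C` vary per step is what makes the scan "selective". It is a scalar illustration under those assumptions, not Mamba's multi-channel, hardware-aware kernel, and `selective_scan_1d` is a hypothetical name.

```python
import math
from typing import List, Sequence


def selective_scan_1d(
    xs: Sequence[float],
    dts: Sequence[float],  # per-step step size (input-dependent in Mamba)
    A: float,              # continuous-time state coefficient (negative for decay)
    Bs: Sequence[float],   # per-step input projection
    Cs: Sequence[float],   # per-step output projection
    h0: float = 0.0,
) -> List[float]:
    """Scalar SSM: h_t = exp(dt_t*A)*h_{t-1} + dt_t*B_t*x_t, y_t = C_t*h_t."""
    h = h0
    ys: List[float] = []
    for x, dt, B, C in zip(xs, dts, Bs, Cs):
        a_bar = math.exp(dt * A)   # zero-order-hold discretization of A
        b_bar = dt * B             # simplified discretization of B
        h = a_bar * h + b_bar * x  # linear recurrence h_t = A_bar*h_{t-1} + B_bar*x_t
        ys.append(C * h)
    return ys
```

With `A = -ln 2`, `dt = 1`, `B = 1`, `C = 2`, and inputs `[1, 1, 1]`, the state decays by half each step (`h = 1, 1.5, 1.75`), giving outputs close to `[2.0, 3.0, 3.5]`.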
The `.nfl` file carries the `paper_url: "https://arxiv.org/abs/2312.00752"` so
collaborators always know the source architecture.
---
## Pattern 4: Bidirectional RNNs
Wire two scan nodes:
```
xs → scan_forward  → hs_fwd [B,T,H] ──┐
xs → scan_backward → hs_bwd [B,T,H] ──┴──→ concat → output
```
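`scan_backward` can reference the same cell composite with a `direction: "backward"` attr.
Bidirectional RNNs normally give each direction its own weights, as `torch.nn.LSTM(bidirectional=True)` does.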
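One detail the diagram leaves implicit: if the backward direction is computed by scanning the reversed sequence, its outputs come out in reversed time order and must be flipped back before concatenation, so that position t of both halves refers to the same input step. A minimal pure-Python sketch, assuming a hypothetical `run_scan` helper and a toy scalar cell in place of real cells:

```python
from typing import Callable, List, Sequence, Tuple

Cell = Callable[[float, float], Tuple[float, float]]  # (h, x_t) -> (h_new, y_t)


def run_scan(cell: Cell, h0: float, xs: Sequence[float]) -> List[float]:
    """Forward scan returning per-step outputs (final carry dropped for brevity)."""
    h = h0
    ys: List[float] = []
    for x in xs:
        h, y = cell(h, x)
        ys.append(y)
    return ys


def bidirectional(fwd: Cell, bwd: Cell, xs: Sequence[float]) -> List[Tuple[float, float]]:
    """Pair forward and backward outputs per time step (the "concat")."""
    hs_fwd = run_scan(fwd, 0.0, xs)  # hs_fwd[t] summarizes xs[:t+1]
    # Scan the reversed sequence, then flip the outputs back so that
    # hs_bwd[t] summarizes xs[t:] and lines up with hs_fwd[t].
    hs_bwd = run_scan(bwd, 0.0, list(reversed(xs)))[::-1]
    return list(zip(hs_fwd, hs_bwd))


def running_sum(h: float, x: float) -> Tuple[float, float]:
    """Toy cell: carry and output are both the running sum."""
    return h + x, h + x
```

On `[1.0, 2.0, 3.0]` with `running_sum` for both directions, this yields `[(1.0, 6.0), (3.0, 5.0), (6.0, 3.0)]`: each pair holds the prefix sum and the suffix sum at the same step. Passing two different cells is how separate per-direction weights would show up.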
---
## Why Not Explicit Cycles?
Some graph editors allow "feedback edges" that create cycles. This is fundamentally
broken for ML:
1. **No well-defined topo order** → can't run inference, can't compute shapes
2. **Not differentiable** as written (need time-unrolling to get gradients)
3. **ONNX, PyTorch JIT (TorchScript), TensorRT** don't accept cyclic graphs; they express
   iteration as structured loop ops, so cycles would need a special lowering pass that nFlow avoids by design
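The first point is easy to demonstrate. Kahn's algorithm, the standard way to compute a topological (execution) order, repeatedly runs any node whose inputs are all ready; on a graph with a feedback edge, the nodes on the cycle never become ready, so no order exists. A minimal sketch over a plain adjacency dict (not nFlow's IR):

```python
from collections import deque
from typing import Dict, List, Optional


def topo_order(edges: Dict[str, List[str]]) -> Optional[List[str]]:
    """Kahn's algorithm; `edges` maps node -> successors.

    Returns an execution order, or None if the graph contains a cycle.
    """
    nodes = set(edges) | {v for succs in edges.values() for v in succs}
    indegree = {n: 0 for n in nodes}
    for succs in edges.values():
        for v in succs:
            indegree[v] += 1
    ready = deque(sorted(n for n in nodes if indegree[n] == 0))  # sorted for determinism
    order: List[str] = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for v in edges.get(n, []):
            indegree[v] -= 1
            if indegree[v] == 0:
                ready.append(v)
    # Nodes on a cycle never reach in-degree 0, so they never enter `order`.
    return order if len(order) == len(nodes) else None
```

An unrolled chain such as `{"x0": ["cell0"], "cell0": ["cell1"], "x1": ["cell1"], "cell1": ["out"]}` gets an order, while adding a feedback edge (`{"x": ["cell"], "cell": ["cell_out"], "cell_out": ["cell"]}`) makes `topo_order` return `None`, so there is no step on which to run inference or propagate shapes.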
The DAG-based approach (explicit unrolling + scan) is the standard in production ML
frameworks and compilers, and nFlow follows it.
---
## Quick Reference
| Architecture | nFlow Pattern |
|---|---|
| LSTM (fixed T steps) | T × `nn.lstm_cell` wired in sequence |
| LSTM (variable length) | 1 × `nn.scan(cell="user.lstm_cell")` |
| GRU | `nn.gru_cell` or `nn.scan(cell="user.gru_cell")` |
| Mamba SSM | `user.mamba_ssm` composite with `nn.scan` inside |
| RWKV | `user.rwkv_block` composite with `nn.scan` inside |
| Transformer (no recurrence) | Standard DAG β€” MHA + FFN blocks |
| Bidirectional | Two parallel `nn.scan` nodes + concat |
| Temporal Conv (TCN) | Standard conv nodes β€” fully DAG-compatible |
---
## Implementing `nn.scan` in nFlow
The `nn.scan` primitive is registered in `nflow-ops/src/prims.rs`. It:
- Declares attrs: `cell: str` (the composite uid to scan), `axis: int = 1`
- Shape inference: reads the cell's declared `ports_out` and infers the stacked sequence output shapes
- Codegen: emits `nn.LSTM` / `nn.GRU` for built-in cells and a manual `for` loop over the time axis
  for custom cells
- Execution: the CPU evaluator calls `infer_user_composite` once per time step, feeding each step's carry into the next
This is the recommended way to add any sequential model to nFlow.
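The shape-inference rule in the list above can be sketched as follows, assuming the cell's `ports_out` shapes are per-step shapes, that some of those ports are stacked over time (gaining a `T` dimension at `axis`), and that all of them also surface as final carries with an unchanged shape. The `stacked` parameter, the function name, and the shape representation (lists with `None` for a dynamic dimension) are hypothetical; the document does not specify how `prims.rs` distinguishes the two kinds of output.

```python
from typing import Dict, List, Optional, Tuple

Shape = List[Optional[int]]  # None marks a dynamic (unknown) dimension


def infer_scan_shapes(
    xs_shape: Shape,
    cell_ports_out: Dict[str, Shape],  # per-step output shapes declared by the cell
    stacked: List[str],                # hypothetical: ports collected over time
    axis: int = 1,
) -> Tuple[Dict[str, Shape], Dict[str, Shape]]:
    """Return (sequence_outputs, final_outputs) for a scan over `axis` of xs."""
    if not 0 <= axis < len(xs_shape):
        raise ValueError(f"axis {axis} out of range for xs of rank {len(xs_shape)}")
    seq_len = xs_shape[axis]  # T, possibly None for variable-length input
    # Final carries keep the per-step shape, e.g. h_T: [B, H].
    finals = {port: list(shape) for port, shape in cell_ports_out.items()}
    # Stacked outputs gain the time axis at `axis`, e.g. hs: [B, T, H].
    seqs = {
        port: list(shape[:axis]) + [seq_len] + list(shape[axis:])
        for port, shape in cell_ports_out.items()
        if port in stacked
    }
    return seqs, finals
```

For the LSTM example with `xs` of shape `[8, None, 16]` and cell outputs `h_new`, `c_new` of shape `[8, 32]`, stacking `h_new` gives `hs` as `[8, None, 32]`, while `h_T` and `c_T` stay `[8, 32]`.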