| # Loops, Recurrence, and Sequential Models in nFlow |
|
|
| ## The Core Principle: DAGs Only, No Cycles |
|
|
| nFlow's graph IR is a **directed acyclic graph (DAG)**. This is not a limitation; |
| it is the standard model for deep learning computation graphs. PyTorch, TensorFlow, |
| ONNX, and JAX all build on DAGs. Without extra semantics, a cycle in a computation |
| graph has no well-defined execution order, so it can't be differentiated or compiled |
| as written. |
|
|
| **What looks like a "loop" in RNNs is always one of two things:** |
| 1. **Explicit unrolling**: you wire T time steps manually (classical RNNs) |
| 2. **Scan/fold operation**: a single node that steps over a sequence (modern approach) |
|
|
| --- |
|
|
| ## Pattern 1: LSTM Cell (Single Step) |
|
|
| `nn.lstm_cell` computes one time step of an LSTM: |
| - **Inputs**: `x` [B, input_size], `h` [B, hidden_size], `c` [B, hidden_size] |
| - **Outputs**: `h_new` [B, hidden_size], `c_new` [B, hidden_size] |
|
|
| To model a **3-step LSTM** in nFlow: |
| ``` |
| input[0] → lstm_cell → h1, c1 |
|              ↑ h0, c0 |
| input[1] → lstm_cell → h2, c2 |
|              ↑ h1, c1 |
| input[2] → lstm_cell → h3, c3 |
|              ↑ h2, c2 |
| ``` |
| You wire 3 `nn.lstm_cell` nodes in sequence. The "loop" is explicit at the graph level. |
| This mirrors how `torch.nn.LSTMCell` is used: the caller steps the cell once per time |
| step, and the unrolled result is an ordinary DAG. |
|
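The wiring above can be sketched in plain Python. This is a minimal illustration, not nFlow's implementation: it assumes a scalar LSTM (input and hidden size of 1, no batch dimension), and the `lstm_cell` function and `weights` values are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_cell(x, h, c, w):
    """One LSTM time step for scalar x, h, c (input_size = hidden_size = 1).

    `w` maps each gate name to (w_x, w_h, bias). The values are illustrative.
    """
    def gate(name):
        w_x, w_h, b = w[name]
        return w_x * x + w_h * h + b

    i = sigmoid(gate("i"))    # input gate
    f = sigmoid(gate("f"))    # forget gate
    g = math.tanh(gate("g"))  # candidate cell state
    o = sigmoid(gate("o"))    # output gate
    c_new = f * c + i * g
    h_new = o * math.tanh(c_new)
    return h_new, c_new


weights = {"i": (0.5, 0.1, 0.0), "f": (0.3, 0.2, 1.0),
           "g": (0.8, -0.4, 0.0), "o": (0.6, 0.3, 0.0)}
inputs = [1.0, -0.5, 0.25]
h0, c0 = 0.0, 0.0

# Three cell applications wired in sequence: no loop construct, no cycle.
h1, c1 = lstm_cell(inputs[0], h0, c0, weights)
h2, c2 = lstm_cell(inputs[1], h1, c1, weights)
h3, c3 = lstm_cell(inputs[2], h2, c2, weights)
```

Each call consumes the state produced by the previous one, so the data dependencies form a chain that only points forward in time.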
|
| **When to use**: Fixed-length sequences known at model-design time. Clean, simple, |
| easy to debug. Good for teaching, small models, and architectures with specific step counts. |
|
|
| --- |
|
|
| ## Pattern 2: Scan / Fold Over Sequence (Recommended) |
|
|
| `nn.scan` applies a cell step by step over a variable-length sequence, threading state between steps: |
| - **Inputs**: `xs` [B, T, input_size], `h0` [B, hidden_size], `c0` [B, hidden_size] |
| - **Outputs**: `hs` [B, T, hidden_size], `h_T` [B, hidden_size], `c_T` [B, hidden_size] |
| - **Attrs**: `cell: str`, the uid of the cell composite to scan over (e.g. `"user.lstm_cell"`) |
|
|
| This is analogous to JAX's `lax.scan` and Haskell's `scanl`, and it is how modern |
| implementations of RWKV, Mamba, and linear RNNs are commonly expressed. |
|
|
| ``` |
| xs [B,T,D] → nn.scan(cell="user.lstm_cell") → hs [B,T,H] |
|                ↑ h0, c0                        h_T, c_T |
| ``` |
|
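The semantics of a scan node can be sketched as a small higher-order function. This is a hedged illustration rather than nFlow's evaluator: the `scan` signature follows the `(f, init, xs)` convention of `lax.scan`, and `leaky_cell` is a hypothetical toy cell standing in for an LSTM.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

S = TypeVar("S")  # carried state
X = TypeVar("X")  # per-step input
Y = TypeVar("Y")  # per-step output


def scan(cell: Callable[[S, X], Tuple[S, Y]],
         init: S,
         xs: Sequence[X]) -> Tuple[S, List[Y]]:
    """Apply `cell` to each element of `xs` in order, threading the state."""
    state = init
    ys: List[Y] = []
    for x in xs:
        state, y = cell(state, x)
        ys.append(y)
    return state, ys


def leaky_cell(h: float, x: float) -> Tuple[float, float]:
    """Toy recurrent cell: h_new = 0.5 * h + x; the output is the new state."""
    h_new = 0.5 * h + x
    return h_new, h_new


xs = [1.0, 2.0, 3.0, 4.0]
h_T, hs = scan(leaky_cell, 0.0, xs)
```

The sequence length is read from `xs` at run time, so the same single node covers any T, whereas the unrolled pattern fixes T when the graph is drawn.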
|
| One node represents the entire recurrence. This matches how `torch.nn.LSTM` and Mamba's |
| selective-scan kernel present a whole sequence pass as a single operation. The scan node exports to: |
| - **PyTorch**: `nn.LSTM(...)` or a manual for-loop (`torch.func` has no public `scan`) |
| - **Keras**: `keras.layers.RNN(cell, return_sequences=True)` |
| - **ONNX**: the `Loop` / `Scan` operators |
|
|
| **When to use**: Variable-length sequences, production models, Mamba/RWKV/S4/linear |
| attention. The scan node abstracts away the loop entirely. |
|
|
| --- |
|
|
| ## Pattern 3: Custom SSM (Mamba-style) |
|
|
| For Mamba, S4, RWKV, and similar **linear recurrences** (`h_t = A·h_{t-1} + B·x_t`): |
|
|
| Build a `UserComposite` for the SSM kernel: |
| 1. **Create** a `user.mamba_ssm` composite in the Custom Nodes panel |
| 2. **Wire** the inner graph: `dt, A, B, C, x → selective_scan → y` |
| 3. The selective scan is itself an `nn.scan` over a custom cell |
|
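The linear recurrence above has a useful property: each step is an affine map of the state, and affine maps compose associatively, which is what lets such models use a parallel scan. The sketch below checks this on a toy case. It assumes scalar, constant `a` and `b` (Mamba's parameters are input-dependent and higher-dimensional), and all function names are hypothetical.

```python
from itertools import accumulate


def sequential_scan(a, b, xs, h0=0.0):
    """Run h_t = a * h_{t-1} + b * x_t step by step; return every h_t."""
    hs, h = [], h0
    for x in xs:
        h = a * h + b * x
        hs.append(h)
    return hs


def combine(first, second):
    """Compose two affine maps h -> a*h + c, applying `first` then `second`."""
    a1, c1 = first
    a2, c2 = second
    return (a2 * a1, a2 * c1 + c2)


def associative_scan(a, b, xs, h0=0.0):
    """Same recurrence, written as a prefix composition of affine maps."""
    steps = [(a, b * x) for x in xs]       # step t is the map h -> a*h + b*x_t
    prefixes = accumulate(steps, combine)  # prefix t composes steps 1..t
    return [pa * h0 + pc for pa, pc in prefixes]


a, b = 0.5, 2.0
xs = [1.0, 0.0, -1.0, 3.0]
hs_seq = sequential_scan(a, b, xs)
hs_assoc = associative_scan(a, b, xs)
```

Here `accumulate` still runs left to right; the point is only that `combine` is associative, so a parallel implementation is free to regroup the compositions.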
|
| The `.nfl` file carries `paper_url: "https://arxiv.org/abs/2312.00752"`, so |
| collaborators can always trace the source architecture. |
|
|
| --- |
|
|
| ## Pattern 4: Bidirectional RNNs |
|
|
| Wire two scan nodes: |
| ``` |
| xs → scan_forward  → hs_fwd [B,T,H] ──┐ |
| xs → scan_backward → hs_bwd [B,T,H] ──┴→ concat → output |
| ``` |
| `scan_backward` can use the same cell with the attr `direction: "backward"`. |
|
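A backward scan is equivalent to reversing the sequence, scanning forward, and reversing the outputs again. The sketch below illustrates that with a toy scalar cell; `run_cell` and `bidirectional` are hypothetical names, and the tuple pairing stands in for concatenation along the feature axis.

```python
def run_cell(xs, h0=0.0, decay=0.5):
    """Toy scan: h_t = decay * h_{t-1} + x_t, returning every h_t."""
    hs, h = [], h0
    for x in xs:
        h = decay * h + x
        hs.append(h)
    return hs


def bidirectional(xs):
    """Run the same cell forward and backward, then pair outputs per time step."""
    hs_fwd = run_cell(xs)
    # Backward pass: reverse the input, scan, then reverse the outputs so that
    # index t in hs_bwd lines up with index t in hs_fwd.
    hs_bwd = run_cell(xs[::-1])[::-1]
    # "concat" along the feature axis: each step yields (forward, backward).
    return list(zip(hs_fwd, hs_bwd))


out = bidirectional([1.0, 2.0, 4.0])
```

At each position the forward half summarizes the inputs up to and including that step, and the backward half summarizes the inputs from that step to the end.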
|
| --- |
|
|
| ## Why Not Explicit Cycles? |
|
|
| Some graph editors allow "feedback edges" that create cycles. This is fundamentally |
| broken for ML: |
|
|
| 1. **No well-defined topo order**: can't run inference, can't compute shapes |
| 2. **Not differentiable** as written (need time-unrolling to get gradients) |
| 3. **ONNX, PyTorch JIT, TensorRT** don't support cyclic graphs, so you'd need a |
|    special lowering pass that nFlow avoids by design |
|
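The first point can be demonstrated with Python's standard-library `graphlib`. This is an illustrative sketch with hypothetical node names, not nFlow's scheduler: an unrolled recurrence sorts cleanly, while the same recurrence drawn with a feedback edge has no topological order at all.

```python
from graphlib import CycleError, TopologicalSorter


def execution_order(deps):
    """Return a valid execution order, or None if the graph has a cycle.

    `deps` maps each node to the set of nodes it depends on.
    """
    try:
        return list(TopologicalSorter(deps).static_order())
    except CycleError:
        return None


# Unrolled two-step recurrence: every edge points forward in time.
unrolled = {"cell_1": {"x_0", "h_0"}, "cell_2": {"x_1", "cell_1"}}

# The same recurrence drawn with a feedback edge: cell reads state,
# and state is produced by cell.
feedback = {"cell": {"x", "state"}, "state": {"cell"}}

order = execution_order(unrolled)
cyclic_order = execution_order(feedback)
```

Shape inference and code generation both walk the graph in this order, so a graph without one cannot be processed without first being rewritten into an unrolled or scan form.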
|
| The DAG-based approach (explicit unrolling + scan) is what the major production ML |
| frameworks and compilers use. nFlow follows this convention. |
|
|
| --- |
|
|
| ## Quick Reference |
|
|
| | Architecture | nFlow Pattern | |
| |---|---| |
| | LSTM (fixed T steps) | T × `nn.lstm_cell` wired in sequence | |
| | LSTM (variable length) | 1 × `nn.scan(cell="user.lstm_cell")` | |
| | GRU | `nn.gru_cell` or `nn.scan(cell="user.gru_cell")` | |
| | Mamba SSM | `user.mamba_ssm` composite with `nn.scan` inside | |
| | RWKV | `user.rwkv_block` composite with `nn.scan` inside | |
| | Transformer (no recurrence) | Standard DAG: MHA + FFN blocks | |
| | Bidirectional | Two parallel `nn.scan` nodes + concat | |
| | Temporal Conv (TCN) | Standard conv nodes, fully DAG-compatible | |
|
|
| --- |
|
|
| ## Implementing `nn.scan` in nFlow |
|
|
| The `nn.scan` primitive is registered in `nflow-ops/src/prims.rs`. It: |
| - Declares attrs: `cell: str` (the composite uid to scan), `axis: int = 1` |
| - Shape inference: reads the cell's declared `ports_out`, infers sequence output shape |
| - Codegen: emits `nn.LSTM` / `nn.GRU` for built-in cells; a manual loop for custom |
|   cells (`torch.func` has no public `scan`) |
| - Execution: the CPU evaluator calls `infer_user_composite` iteratively for each step |
|
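The shape-inference step can be sketched as follows. This is an assumption-laden illustration in Python rather than the Rust code in `prims.rs`: the function name, the `_seq` / `_final` output naming, and the choice to stack every output port are hypothetical, and the real pass may stack only some ports.

```python
from typing import Dict, List

Shape = List[object]  # ints for static dims, strings for symbolic ones


def infer_scan_output_shapes(xs_shape: Shape,
                             cell_ports_out: Dict[str, Shape],
                             axis: int = 1) -> Dict[str, Shape]:
    """Infer scan outputs from the sequence input and the cell's per-step outputs.

    Each per-step output (e.g. [B, H]) yields a stacked output with the
    sequence length inserted at `axis`, plus a final-step output that keeps
    the per-step shape.
    """
    seq_len = xs_shape[axis]
    shapes: Dict[str, Shape] = {}
    for name, step_shape in cell_ports_out.items():
        stacked = list(step_shape)
        stacked.insert(axis, seq_len)
        shapes[f"{name}_seq"] = stacked
        shapes[f"{name}_final"] = list(step_shape)
    return shapes


shapes = infer_scan_output_shapes(
    ["B", "T", 64],
    {"h": ["B", 128], "c": ["B", 128]},
)
```

With the default `axis=1`, a per-step `h` of shape `[B, 128]` becomes a sequence output of shape `[B, T, 128]`, matching the `hs` port described for `nn.scan`.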
|
| This is the recommended way to add any sequential model to nFlow. |
|
|