Loops, Recurrence, and Sequential Models in nFlow

The Core Principle: DAGs Only, No Cycles

nFlow's graph IR is a directed acyclic graph (DAG). This is not a limitation; it is the correct model for deep learning computation graphs. PyTorch, TensorFlow, ONNX, and JAX all use DAGs. A cycle in a computation graph has no well-defined execution order, so it can't be scheduled, differentiated, or compiled as written.

What looks like a "loop" in RNNs is always one of two things:

  1. Explicit unrolling: you wire T time steps manually (classical RNNs)
  2. Scan/fold operation: a single node that steps over a sequence (modern approach)

Pattern 1: LSTM Cell (Single Step)

nn.lstm_cell computes one time step of an LSTM:

  • Inputs: x [B, input_size], h [B, hidden_size], c [B, hidden_size]
  • Outputs: h_new [B, hidden_size], c_new [B, hidden_size]

To model a 3-step LSTM in nFlow:

input[0] → lstm_cell → h1, c1
             ↑ h0, c0
input[1] → lstm_cell → h2, c2
             ↑ h1, c1
input[2] → lstm_cell → h3, c3
             ↑ h2, c2

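You wire three nn.lstm_cell nodes in sequence, feeding each step's h and c into the next. The "loop" is explicit at the graph level. This mirrors how torch.nn.LSTMCell is used: the module computes a single step and the caller applies it once per time step, which is perfectly expressible as a DAG.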
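The three-step wiring above can be sketched in plain Python. This is a minimal illustration rather than nFlow's implementation: lstm_cell here is a hypothetical stand-alone function using the standard LSTM gate equations, it handles a single example (no batch axis B), and the constant weights are placeholders chosen only to keep the example deterministic.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_cell(x, h, c, W, U, b):
    """One LSTM step for a single example (hypothetical helper, no batch axis).

    x is a list of input_size floats; h and c are lists of hidden_size floats.
    W[g], U[g], b[g] hold input weights, recurrent weights, and bias for
    gate g in ("i", "f", "g", "o").
    """
    def affine(g):
        return [
            sum(wk * xk for wk, xk in zip(W[g][j], x))
            + sum(uk * hk for uk, hk in zip(U[g][j], h))
            + b[g][j]
            for j in range(len(h))
        ]

    i = [sigmoid(v) for v in affine("i")]    # input gate
    f = [sigmoid(v) for v in affine("f")]    # forget gate
    g = [math.tanh(v) for v in affine("g")]  # candidate cell state
    o = [sigmoid(v) for v in affine("o")]    # output gate
    c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
    h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
    return h_new, c_new


# Placeholder sizes and weights (assumptions for illustration only).
INPUT_SIZE, HIDDEN_SIZE = 2, 3
GATES = ("i", "f", "g", "o")
W = {g: [[0.1] * INPUT_SIZE for _ in range(HIDDEN_SIZE)] for g in GATES}
U = {g: [[0.1] * HIDDEN_SIZE for _ in range(HIDDEN_SIZE)] for g in GATES}
b = {g: [0.0] * HIDDEN_SIZE for g in GATES}

inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
h0 = [0.0] * HIDDEN_SIZE
c0 = [0.0] * HIDDEN_SIZE

# Explicit unrolling: three calls, each consuming the previous state.
h1, c1 = lstm_cell(inputs[0], h0, c0, W, U, b)
h2, c2 = lstm_cell(inputs[1], h1, c1, W, U, b)
h3, c3 = lstm_cell(inputs[2], h2, c2, W, U, b)
```

Each call depends only on values produced by earlier calls, so the dependency structure is a chain with no back edges.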

When to use: Fixed-length sequences known at model-design time. Clean, simple, easy to debug. Good for teaching, small models, and architectures with specific step counts.


Pattern 2: Scan / Fold Over Sequence (Recommended)

nn.scan applies a cell to each step of a variable-length sequence, threading the state from one step to the next:

  • Inputs: xs [B, T, input_size], h0 [B, hidden_size], c0 [B, hidden_size]
  • Outputs: hs [B, T, hidden_size], h_T [B, hidden_size], c_T [B, hidden_size]
  • Attrs: cell: str, the uid of the cell composite to scan over (e.g. "user.lstm_cell")

This mirrors JAX's lax.scan and Haskell's scanl (more precisely mapAccumL, since the final state is returned alongside the per-step outputs), and it is how modern implementations of RWKV, Mamba, and linear RNNs are expressed.

xs [B,T,D] → nn.scan(cell="user.lstm_cell") → hs [B,T,H]
              ↑ h0, c0                           h_T, c_T
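The semantics of a scan can be sketched in a few lines of plain Python. The scan function below follows the (carry, x) → (new_carry, y) convention of JAX's lax.scan; toy_cell is a hypothetical stand-in for a real cell, and neither name is part of nFlow's API.

```python
import math


def scan(cell, carry, xs):
    """Apply cell to each element of xs in order, threading the carry.

    cell(carry, x) must return (new_carry, y).
    Returns (final_carry, list_of_ys).
    """
    ys = []
    for x in xs:
        carry, y = cell(carry, x)
        ys.append(y)
    return carry, ys


def toy_cell(carry, x):
    """Hypothetical cell with an LSTM-like (h, c) carry; not a real LSTM."""
    h, c = carry
    c_new = 0.5 * c + x
    h_new = math.tanh(c_new)
    return (h_new, c_new), h_new


xs = [1.0, 2.0, 3.0]
(h_T, c_T), hs = scan(toy_cell, (0.0, 0.0), xs)

print(len(hs))  # 3
print(c_T)      # 4.25
```

The per-step outputs hs correspond to the hs port of nn.scan, and the final carry corresponds to h_T and c_T.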

One node represents the entire recurrence. This mirrors how torch.nn.LSTM and Mamba's selective scan are exposed: a single operation that runs the whole sequence internally. The scan node exports to:

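  • PyTorch: nn.LSTM(...), a scan primitive where the installed PyTorch provides one (torch.func.scan is not part of the stable torch.func API at the time of writing), or a manual for-loop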
  • Keras: keras.layers.RNN(cell, return_sequences=True)
  • ONNX: ONNX::Loop / ONNX::Scan operators

When to use: Variable-length sequences, production models, Mamba/RWKV/S4/linear attention. The scan node abstracts away the loop entirely.


Pattern 3: Custom SSM (Mamba-style)

For Mamba, S4, RWKV, and similar linear recurrences (h_t = A·h_{t-1} + B·x_t):

Build a UserComposite for the SSM kernel:

  1. Create a user.mamba_ssm composite in the Custom Nodes panel
  2. Wire the inner graph: dt, A, B, C, x → selective_scan → y
  3. The selective scan is itself an nn.scan over a custom cell
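The recurrence h_t = A·h_{t-1} + B·x_t can be illustrated with a small sequential loop. This is a deliberately simplified sketch: linear_ssm_scan is a hypothetical name, A is assumed diagonal (stored as a list), the parameters are fixed across time steps, and the input-dependent dt, B, C of Mamba's selective scan are left out. The readout y_t = C·h_t is likewise an assumption about how y is produced.

```python
def linear_ssm_scan(A, B, C, xs, h0=None):
    """Run h_t[n] = A[n] * h_{t-1}[n] + B[n] * x_t and y_t = sum_n C[n] * h_t[n].

    A, B, C are lists of length N (diagonal state matrix, input and output
    projections); xs is a list of scalar inputs.
    Returns (ys, h_T).
    """
    h = list(h0) if h0 is not None else [0.0] * len(A)
    ys = []
    for x in xs:
        h = [a * hn + b * x for a, hn, b in zip(A, h, B)]
        ys.append(sum(c * hn for c, hn in zip(C, h)))
    return ys, h


# Impulse response of a two-state system (placeholder parameters).
A = [0.5, 0.25]
B = [1.0, 1.0]
C = [1.0, 2.0]
ys, h_T = linear_ssm_scan(A, B, C, [1.0, 0.0, 0.0])

print(ys)   # [3.0, 1.0, 0.375]
print(h_T)  # [0.25, 0.0625]
```

The loop body is the "custom cell" and the loop itself plays the role of nn.scan.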

The .nfl file carries the paper_url: "https://arxiv.org/abs/2312.00752" so collaborators always know the source architecture.


Pattern 4: Bidirectional RNNs

Wire two scan nodes:

xs → scan_forward  → hs_fwd [B,T,H]  ──┐
xs → scan_backward → hs_bwd [B,T,H]  ───→ concat → output

scan_backward can reuse the same cell with the direction: "backward" attr set.
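A minimal sketch of the two-scan wiring in plain Python follows. It assumes that the backward direction means "process the sequence in reverse time order, then realign the outputs with the input positions" and that concat joins along the feature axis; both are common conventions but are assumptions about nFlow here. The names scan, bidirectional_scan, and running_sum_cell are hypothetical.

```python
def scan(cell, carry, xs):
    """Thread carry through cell over xs; returns (final_carry, ys)."""
    ys = []
    for x in xs:
        carry, y = cell(carry, x)
        ys.append(y)
    return carry, ys


def bidirectional_scan(cell, carry, xs):
    """Run the same cell forward and backward, then concat per time step."""
    _, hs_fwd = scan(cell, carry, xs)
    _, hs_bwd_reversed = scan(cell, carry, xs[::-1])
    hs_bwd = hs_bwd_reversed[::-1]  # realign so index t matches input t
    # List concatenation stands in for concat along the feature axis.
    return [f + b for f, b in zip(hs_fwd, hs_bwd)]


def running_sum_cell(total, x):
    """Toy cell: the state is a running sum, the output a 1-element feature list."""
    total = total + x
    return total, [total]


out = bidirectional_scan(running_sum_cell, 0, [1, 2, 3])
print(out)  # [[1, 6], [3, 5], [6, 3]]
```

At each position t the output holds a summary of everything up to t (forward) and everything from t onward (backward).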


Why Not Explicit Cycles?

Some graph editors allow "feedback edges" that create cycles. This is fundamentally broken for ML:

  1. No well-defined topo order → can't run inference, can't compute shapes
  2. Not differentiable as written (need time-unrolling to get gradients)
  3. ONNX, PyTorch JIT, TensorRT don't support cyclic graphs; you'd need a special lowering pass that nFlow avoids by design
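The first point can be checked with Python's standard-library graphlib module. The node names below are hypothetical, and schedule is an illustrative helper, not part of nFlow: an unrolled chain of cells yields a valid execution order, while a graph with a feedback edge raises CycleError.

```python
from graphlib import CycleError, TopologicalSorter


def schedule(deps):
    """Return a valid execution order, or None if the graph has a cycle.

    deps maps each node to the set of nodes it depends on.
    """
    try:
        return list(TopologicalSorter(deps).static_order())
    except CycleError:
        return None


# Unrolled recurrence: each cell depends on its input and the previous cell.
unrolled = {
    "cell_1": {"x_1", "h_0"},
    "cell_2": {"x_2", "cell_1"},
    "cell_3": {"x_3", "cell_2"},
}

# Feedback edge: the cell reads state h, and h is produced by the same cell.
feedback = {
    "cell": {"x", "h"},
    "h": {"cell"},
}

order = schedule(unrolled)
print(order is not None)   # True
print(schedule(feedback))  # None
```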

The DAG-based approach (explicit unrolling + scan) is used by the major production ML frameworks and compilers. nFlow follows this industry standard.


Quick Reference

| Architecture | nFlow Pattern |
|---|---|
| LSTM (fixed T steps) | T × nn.lstm_cell wired in sequence |
| LSTM (variable length) | 1 × nn.scan(cell="user.lstm_cell") |
| GRU | nn.gru_cell or nn.scan(cell="user.gru_cell") |
| Mamba SSM | user.mamba_ssm composite with nn.scan inside |
| RWKV | user.rwkv_block composite with nn.scan inside |
| Transformer (no recurrence) | Standard DAG: MHA + FFN blocks |
| Bidirectional | Two parallel nn.scan nodes + concat |
| Temporal Conv (TCN) | Standard conv nodes, fully DAG-compatible |

Implementing nn.scan in nFlow

The nn.scan primitive is registered in nflow-ops/src/prims.rs. It:

  • Declares attrs: cell: str (the composite uid to scan), axis: int = 1
  • Shape inference: reads the cell's declared ports_out, infers sequence output shape
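  • Codegen: emits nn.LSTM / nn.GRU for built-in cells; for custom cells, a scan primitive where the installed PyTorch provides one (torch.func.scan is not part of the stable torch.func API at the time of writing) or a manual loop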
  • Execution: the CPU evaluator calls infer_user_composite iteratively for each step
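As a rough sketch of the shape-inference rule listed above: given the sequence input's shape and the cell's per-step output shape, the sequence output gets the time dimension inserted at axis. The helper infer_scan_shapes is hypothetical and written in Python for illustration; it is not the function registered in prims.rs, and it assumes the cell's declared output shape omits the sequence axis.

```python
def infer_scan_shapes(xs_shape, cell_out_shape, axis=1):
    """Infer the stacked output shape of a scan (hypothetical helper).

    xs_shape:       shape of the sequence input, e.g. (B, T, D)
    cell_out_shape: per-step output shape declared by the cell, e.g. (B, H)
    axis:           position of the sequence dimension (default 1)
    """
    seq_len = xs_shape[axis]
    hs_shape = list(cell_out_shape)
    hs_shape.insert(axis, seq_len)
    return tuple(hs_shape)


# Concrete and symbolic dimensions both work, since dims are only copied.
print(infer_scan_shapes((8, 20, 32), (8, 64)))          # (8, 20, 64)
print(infer_scan_shapes(("B", "T", "D"), ("B", "H")))   # ('B', 'T', 'H')
```

The final-state outputs h_T and c_T keep the cell's per-step shape unchanged, since they are just the last carry.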

This is the recommended way to add any sequential model to nFlow.