# Loops, Recurrence, and Sequential Models in nFlow

## The Core Principle: DAGs Only, No Cycles

nFlow's graph IR is a **directed acyclic graph (DAG)**. This is not a limitation; it is the standard model for deep learning computation graphs. PyTorch, TensorFlow, ONNX, and JAX all model computation as dataflow graphs in which loops appear as dedicated control-flow constructs (for example `lax.scan` or the ONNX `Loop` operator) rather than as free-form feedback edges. A graph with a cycle has no topological order, so it cannot be scheduled, shape-checked, differentiated, or compiled until the cycle is unrolled or lowered.

**What looks like a "loop" in RNNs is always one of two things:**

1. **Explicit unrolling**: you wire T time steps manually (classical RNNs)
2. **Scan/fold operation**: a single node that maps over a sequence (modern approach)

---

## Pattern 1: LSTM Cell (Single Step)

`nn.lstm_cell` computes one time step of an LSTM:

- **Inputs**: `x` [B, input_size], `h` [B, hidden_size], `c` [B, hidden_size]
- **Outputs**: `h_new` [B, hidden_size], `c_new` [B, hidden_size]

To model a **3-step LSTM** in nFlow:

```
input[0] → lstm_cell → h1, c1
              ↑
            h0, c0

input[1] → lstm_cell → h2, c2
              ↑
            h1, c1

input[2] → lstm_cell → h3, c3
              ↑
            h2, c2
```

You wire 3 `nn.lstm_cell` nodes in sequence, each consuming the `h` and `c` produced by the previous one. The "loop" is explicit at the graph level. This mirrors how `torch.nn.LSTMCell` is used: the module computes a single step and the caller chains the steps together, which is perfectly expressible as a DAG.

**When to use**: Fixed-length sequences known at model-design time. Clean, simple, easy to debug. Good for teaching, small models, and architectures with specific step counts.

---

## Pattern 2: Scan / Fold Over Sequence (Recommended)

`nn.scan` maps a cell over a variable-length sequence:

- **Inputs**: `xs` [B, T, input_size], `h0` [B, hidden_size], `c0` [B, hidden_size]
- **Outputs**: `hs` [B, T, hidden_size], `h_T` [B, hidden_size], `c_T` [B, hidden_size]
- **Attrs**: `cell: str`, the uid of the cell composite to scan over (e.g. `"user.lstm_cell"`)

This plays the same role as JAX's `lax.scan` or Haskell's `scanl`, and it is how modern implementations of RWKV, Mamba, and linear RNNs express their recurrence.

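The relationship between the two patterns can be sketched in plain Python. This is only an illustration of scan semantics, not nFlow's implementation: `scan` and `toy_cell` are hypothetical names, and the toy cell (`h_new = 0.5 * h + x`) stands in for a real LSTM cell to keep the arithmetic readable.

```python
from typing import Callable, List, Tuple, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")


def scan(cell: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: List[X]) -> Tuple[Carry, List[Y]]:
    """Thread a carry through `cell` once per element of `xs`.

    Returns the final carry and the list of per-step outputs,
    the same contract as a scan node's (h_T, hs) outputs.
    """
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = cell(carry, x)
        ys.append(y)
    return carry, ys


def toy_cell(h: float, x: float) -> Tuple[float, float]:
    """Hypothetical stand-in for an LSTM cell: h_new = 0.5 * h + x."""
    h_new = 0.5 * h + x
    return h_new, h_new  # (new carry, per-step output)


xs = [1.0, 2.0, 3.0]

# Pattern 2: a single scan call covers the whole sequence.
h_T, hs = scan(toy_cell, 0.0, xs)

# Pattern 1: the same computation wired as three explicit steps.
h1, _ = toy_cell(0.0, xs[0])
h2, _ = toy_cell(h1, xs[1])
h3, _ = toy_cell(h2, xs[2])

assert h_T == h3           # both patterns reach the same final state
assert hs == [h1, h2, h3]  # scan also collects every intermediate state
```

Both forms are acyclic: unrolling spells out each step as its own node, while scan hides the iteration inside one node whose sequence length can vary.
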
```
xs [B,T,D] → nn.scan(cell="user.lstm_cell") → hs [B,T,H]
                   ↑
                 h0, c0                         h_T, c_T
```

One node represents the entire recurrence. It is the same level of abstraction as `torch.nn.LSTM` or Mamba's selective scan, where a single call processes the whole sequence.

The scan node exports to:

- **PyTorch**: `nn.LSTM(...)` or a manual for-loop over time steps
- **Keras**: `keras.layers.RNN(cell, return_sequences=True)`
- **ONNX**: the `Loop` / `Scan` operators

**When to use**: Variable-length sequences, production models, Mamba/RWKV/S4/linear attention. The scan node abstracts away the loop entirely.

---

## Pattern 3: Custom SSM (Mamba-style)

For Mamba, S4, RWKV, and similar **linear recurrences** (`h_t = A·h_{t-1} + B·x_t`), build a `UserComposite` for the SSM kernel:

1. **Create** a `user.mamba_ssm` composite in the Custom Nodes panel
2. **Wire** the inner graph: `dt, A, B, C, x → selective_scan → y`
3. **Implement** the selective scan itself as an `nn.scan` over a custom cell

The `.nfl` file carries `paper_url: "https://arxiv.org/abs/2312.00752"`, so collaborators always know the source architecture.

---

## Pattern 4: Bidirectional RNNs

Wire two scan nodes:

```
xs → scan_forward  → hs_fwd [B,T,H] ──┐
xs → scan_backward → hs_bwd [B,T,H] ──┤→ concat → output
```

`scan_backward` can use the same cell with the attr `direction: "backward"`.

---

## Why Not Explicit Cycles?

Some graph editors allow "feedback edges" that create cycles. For ML graphs this breaks the guarantees that the rest of the toolchain relies on:

1. **No well-defined topological order**: the graph cannot be scheduled for inference, and shapes cannot be propagated
2. **Not differentiable** as written: gradients require unrolling the cycle over time first
3. **ONNX, PyTorch JIT, and TensorRT** don't support cyclic graphs, so a cycle would need a special lowering pass, which nFlow avoids by design

The DAG-based approach (explicit unrolling plus scan) is what mainstream production ML frameworks and compilers use. nFlow follows this industry standard.

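The first point, that a feedback edge leaves no valid execution order, can be checked with a small sketch. It uses Kahn's algorithm in plain Python; the graph encoding (a dict from node name to successor names) and the node names are assumptions made for illustration, not nFlow's IR.

```python
from collections import deque
from typing import Dict, List, Optional


def topo_order(edges: Dict[str, List[str]]) -> Optional[List[str]]:
    """Return a topological order of the graph, or None if it contains a cycle."""
    nodes = set(edges)
    for targets in edges.values():
        nodes.update(targets)

    # Count incoming edges for every node.
    indegree = {n: 0 for n in nodes}
    for targets in edges.values():
        for t in targets:
            indegree[t] += 1

    # Repeatedly emit nodes whose inputs are all available.
    ready = deque(sorted(n for n in nodes if indegree[n] == 0))
    order: List[str] = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for t in edges.get(n, []):
            indegree[t] -= 1
            if indegree[t] == 0:
                ready.append(t)

    # Nodes stuck on a cycle never reach indegree 0.
    return order if len(order) == len(nodes) else None


# Two-step recurrence, unrolled: every edge points forward in time.
unrolled = {
    "x0": ["cell0"], "h0": ["cell0"],
    "cell0": ["cell1"], "x1": ["cell1"],
    "cell1": ["out"],
}

# The same recurrence drawn with a feedback edge from the cell to itself.
feedback = {"x": ["cell"], "cell": ["cell", "out"]}

unrolled_order = topo_order(unrolled)
feedback_order = topo_order(feedback)
```

Here `unrolled_order` is a valid schedule in which `cell0` precedes `cell1`, while `feedback_order` is `None`: the self-edge means `cell` never has all of its inputs ready, so neither it nor anything downstream can be scheduled.
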

---

## Quick Reference

| Architecture | nFlow Pattern |
|---|---|
| LSTM (fixed T steps) | T × `nn.lstm_cell` wired in sequence |
| LSTM (variable length) | 1 × `nn.scan(cell="user.lstm_cell")` |
| GRU | `nn.gru_cell` or `nn.scan(cell="user.gru_cell")` |
| Mamba SSM | `user.mamba_ssm` composite with `nn.scan` inside |
| RWKV | `user.rwkv_block` composite with `nn.scan` inside |
| Transformer (no recurrence) | Standard DAG: MHA + FFN blocks |
| Bidirectional | Two parallel `nn.scan` nodes + concat |
| Temporal Conv (TCN) | Standard conv nodes, fully DAG-compatible |

---

## Implementing `nn.scan` in nFlow

The `nn.scan` primitive is registered in `nflow-ops/src/prims.rs`. It:

- **Declares attrs**: `cell: str` (the composite uid to scan), `axis: int = 1`
- **Shape inference**: reads the cell's declared `ports_out` and infers the sequence output shape by inserting the sequence length at `axis`
- **Codegen**: emits `nn.LSTM` / `nn.GRU` for built-in cells, and a manual loop over time steps for custom cells
- **Execution**: the CPU evaluator calls `infer_user_composite` once per time step, feeding each step's state into the next

This is the recommended way to add any sequential model to nFlow.
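
The shape-inference rule can be illustrated with a short Python sketch. It is an assumption about how such a rule might look, not the code in `prims.rs`: the function name `infer_scan_shapes`, the dict-of-tuples representation of `ports_out`, and the `_new` to `_T` renaming convention are all hypothetical.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def infer_scan_shapes(xs_shape: Shape,
                      cell_ports_out: Dict[str, Shape],
                      axis: int = 1) -> Dict[str, Shape]:
    """Derive scan output shapes from the cell's per-step output shapes.

    Assumes the cell exposes `h_new` (and optionally other `*_new` states),
    each shaped for a single time step, e.g. [B, hidden_size].
    """
    seq_len = xs_shape[axis]

    # hs stacks the per-step hidden state along the sequence axis.
    h_shape = cell_ports_out["h_new"]
    shapes: Dict[str, Shape] = {
        "hs": h_shape[:axis] + (seq_len,) + h_shape[axis:],
    }

    # Final states keep the per-step shape: h_new -> h_T, c_new -> c_T.
    for name, shape in cell_ports_out.items():
        base = name[:-len("_new")] if name.endswith("_new") else name
        shapes[base + "_T"] = shape
    return shapes


# Example: B=8, T=20, input_size=32, hidden_size=64.
lstm_shapes = infer_scan_shapes(
    xs_shape=(8, 20, 32),
    cell_ports_out={"h_new": (8, 64), "c_new": (8, 64)},
)
```

With these inputs, `lstm_shapes` maps `hs` to `(8, 20, 64)` and both `h_T` and `c_T` to `(8, 64)`, matching the `[B, T, hidden_size]` and `[B, hidden_size]` output ports listed for `nn.scan`.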