DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts
Abstract
DualKV is a FlashAttention kernel variant that eliminates redundant computation in RL post-training by processing shared prompt tokens only once, achieving 1.63× to 3.82× policy-update speedups and up to 77% MFU in large-scale language model training.
Modern RL post-training methods such as GRPO and DAPO train on N response sequences of R tokens each, all sampled from a shared prompt of P tokens. Standard FlashAttention, however, replicates all P prompt tokens N times in both the forward and backward passes, duplicating compute and memory on identical hidden states. In large-rollout, long-context RL training (N ≥ 16, P ≥ 8K), this redundancy dominates the policy-update cost. We observe that in decoder-only models, causal masking makes prompt representations invariant across sequences at every layer, so all per-token operations (norms, projections, MLP) and attention can process the prompt once; this property has not yet been exploited at the kernel level for training. We propose DualKV, the first FlashAttention kernel variant that eliminates shared-prompt replication during RL training. It combines (1) fused CUDA forward and backward kernels that iterate over two disjoint KV regions, the shared context and the per-sequence response, in a single kernel launch, and (2) a data-pipeline redesign in veRL that repacks N(P + R) tokens into P + NR tokens per micro-batch, extending the token reduction from attention to the entire model by a factor ρ = N(P + R) / (P + NR). DualKV is mathematically equivalent to standard attention and introduces no approximation. On Qwen3-8B GRPO training with 8×H100 GPUs (N = 32, 8K context), DualKV achieves a 1.63× to 2.09× policy-update speedup, enables 2× larger micro-batches, and raises MFU from 36% to 76%. Similar gains hold for DAPO (2.47× speedup, 77% MFU). At 30B MoE scale on 16×H100 GPUs, DualKV achieves a 3.82× policy-update speedup and a 3.38× end-to-end step speedup over FlashAttention, which requires 4-way Ulysses sequence parallelism to avoid running out of memory (OOM).
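The two central claims above, the token-reduction factor ρ and the exact equivalence of two-region attention, can be checked on a toy problem. The pure-Python sketch below illustrates the math only. It is not the authors' CUDA kernels or veRL code. The function names (`token_reduction`, `dual_region_attention`, `standard_attention`, `max_abs_error`), the single-head unbatched layout, the random Q/K/V, and the example sizes are all hypothetical. Only response queries are shown, because under causal masking prompt queries attend only to the prompt and are therefore computed once. The backward pass is not covered.

```python
import math
import random


def token_reduction(n: int, p: int, r: int) -> float:
    """rho = N(P+R) / (P+NR): tokens per micro-batch with the prompt replicated
    N times, divided by tokens when the prompt is stored once (hypothetical helper)."""
    return n * (p + r) / (p + n * r)


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _region_partial(q, keys, values, scale):
    """Unnormalized attention of one query over one KV region, returned as the
    (max score m, softmax denominator l, weighted value sum acc) triple that an
    online-softmax, FlashAttention-style kernel keeps per query."""
    scores = [scale * _dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return m, sum(weights), acc


def dual_region_attention(q, prompt_k, prompt_v, resp_k, resp_v):
    """A response query attends to the shared prompt KV region and to the causally
    visible part of its own response KV region. The two partial results are merged
    by rescaling both to a common max before normalizing."""
    scale = 1.0 / math.sqrt(len(q))
    m1, l1, a1 = _region_partial(q, prompt_k, prompt_v, scale)
    m2, l2, a2 = _region_partial(q, resp_k, resp_v, scale)
    m = max(m1, m2)
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
    denom = c1 * l1 + c2 * l2
    return [(c1 * x + c2 * y) / denom for x, y in zip(a1, a2)]


def standard_attention(q, keys, values):
    """Reference: ordinary softmax attention over one concatenated KV list."""
    scale = 1.0 / math.sqrt(len(q))
    _, l, acc = _region_partial(q, keys, values, scale)
    return [x / l for x in acc]


def max_abs_error(n=4, p=6, r=3, d=8, seed=0):
    """Compare both paths for every response token of N sequences that share one
    prompt. The prompt K/V are generated once and reused by all N sequences."""
    rng = random.Random(seed)

    def vec():
        return [rng.gauss(0.0, 1.0) for _ in range(d)]

    prompt_k = [vec() for _ in range(p)]
    prompt_v = [vec() for _ in range(p)]
    worst = 0.0
    for _ in range(n):  # N response sequences sharing the prompt
        rq = [vec() for _ in range(r)]
        rk = [vec() for _ in range(r)]
        rv = [vec() for _ in range(r)]
        for j in range(r):  # causal mask: response token j sees response tokens 0..j
            dual = dual_region_attention(rq[j], prompt_k, prompt_v, rk[: j + 1], rv[: j + 1])
            ref = standard_attention(rq[j], prompt_k + rk[: j + 1], prompt_v + rv[: j + 1])
            worst = max(worst, max(abs(a - b) for a, b in zip(dual, ref)))
    return worst


print(f"max |dual - standard| = {max_abs_error():.2e}")
print(f"rho(N=32, P=4096, R=1024) = {token_reduction(32, 4096, 1024):.2f}")
```

With the illustrative (not paper-reported) sizes N = 32, P = 4096, and R = 1024, a repacked micro-batch holds P + NR = 36,864 tokens instead of N(P + R) = 163,840, so ρ = 40/9 ≈ 4.44. In general, ρ approaches N when P ≫ NR and approaches 1 when P ≪ R. Merging per-region partials under a shared maximum uses the same rescaling FlashAttention applies between KV tiles. This is why iterating over two disjoint KV regions changes results only at the level of floating-point rounding.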