
Daily Papers

by AK and the research community

Jun 3

DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts

Modern RL post-training methods such as GRPO and DAPO train on N response sequences of R tokens sampled from a shared prompt of P tokens, but standard FlashAttention replicates all P prompt tokens N times across both forward and backward passes, duplicating compute and memory on identical hidden states. In large-rollout, long-context RL training (N ≥ 16, P ≥ 8K), this redundancy dominates the policy update cost. We observe that in decoder-only models, causal masking makes prompt representations invariant across sequences at every layer, so all per-token operations (norms, projections, MLP) and attention can process the prompt once, a property not yet exploited at the kernel level for training. We propose DualKV, the first FlashAttention kernel variant that eliminates shared-prompt replication during RL training, via (1) fused CUDA forward and backward kernels that iterate over two disjoint KV regions (shared context and per-sequence response) in a single kernel launch, and (2) a data-pipeline redesign in veRL that repacks N(P+R) tokens into P+NR tokens per micro-batch, extending the token reduction from attention to the entire model by a factor ρ = N(P+R)/(P+NR). DualKV is mathematically equivalent to standard attention and introduces no approximation. On Qwen3-8B GRPO training with 8×H100 GPUs (N = 32, 8K context), DualKV achieves a 1.63–2.09× policy-update speedup, enables 2× larger micro-batches, and raises MFU from 36% to 76%. Similar gains hold for DAPO (2.47× speedup, 77% MFU). At 30B MoE scale on 16×H100, DualKV achieves 3.82× policy-update and 3.38× end-to-end step speedups over FlashAttention (which requires 4-way Ulysses sequence parallelism to avoid OOM).
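The reduction factor ρ = N(P+R)/(P+NR) is plain token arithmetic, so a quick sketch helps show where it lands. The helper name `token_counts` and the response length R = 1024 below are illustrative assumptions, not values taken from the paper. Note that ρ counts tokens processed per micro-batch; it is not a prediction of wall-clock speedup.

```python
from typing import Tuple


def token_counts(n: int, p: int, r: int) -> Tuple[int, int, float]:
    """Hypothetical helper: token counts for N rollouts sharing one prompt.

    Returns (replicated, shared, rho), where
      replicated = N * (P + R)  # standard layout: prompt copied into every row
      shared     = P + N * R    # repacked layout: prompt once, then N responses
      rho        = replicated / shared
    """
    replicated = n * (p + r)
    shared = p + n * r
    return replicated, shared, replicated / shared


if __name__ == "__main__":
    # N = 32 and P = 8192 echo the abstract's setting; R = 1024 is assumed.
    rep, sh, rho = token_counts(32, 8192, 1024)
    print(rep, sh, rho)  # 294912 40960 7.2
```

Two limits are useful for intuition: with a single rollout (N = 1) there is nothing to share and ρ = 1, while as R shrinks relative to P, ρ approaches N.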
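The equivalence claim (a response token attending to a shared prompt region plus its own response region gives the same result as causal attention over a replicated [prompt; response] sequence) can be checked numerically. The following is a minimal pure-Python sketch under stated assumptions, not the authors' CUDA kernel: a single attention head, random vectors standing in for projected Q/K/V, no tiling, and no backward pass. All function names are hypothetical. Merging the two regions uses the online-softmax rescaling that FlashAttention applies across KV tiles.

```python
import math
import random


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend_regions(q, regions, scale):
    """One query over several KV regions, visited in order (hypothetical).

    A running max `m`, normalizer `l`, and output `acc` are rescaled whenever
    a larger score appears, so the result equals one softmax over the union.
    """
    m, l, acc = -math.inf, 0.0, [0.0] * len(q)
    for region in regions:
        for k, v in region:
            s = _dot(q, k) * scale
            m_new = max(m, s)
            corr = math.exp(m - m_new)  # 0.0 on the first key, since m = -inf
            p = math.exp(s - m_new)
            l = l * corr + p
            acc = [a * corr + p * vj for a, vj in zip(acc, v)]
            m = m_new
    return [a / l for a in acc]


def causal_attention(q_seq, k_seq, v_seq, scale):
    """Reference: plain causal softmax attention over one full sequence."""
    out = []
    for t, q in enumerate(q_seq):
        scores = [_dot(q, k) * scale for k in k_seq[: t + 1]]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        dim = len(v_seq[0])
        out.append([sum(wi * v[j] for wi, v in zip(w, v_seq[: t + 1])) / z
                    for j in range(dim)])
    return out


def dual_kv_response(q_resp, prompt_kv, k_resp, v_resp, scale):
    """Response token t sees region 1 (whole shared prompt, unmasked) and
    region 2 (its own response tokens 0..t, causal)."""
    return [attend_regions(q, [prompt_kv, list(zip(k_resp[: t + 1], v_resp[: t + 1]))], scale)
            for t, q in enumerate(q_resp)]


def max_abs_error(n=4, p=6, r=3, d=8, seed=0):
    """Max deviation between the shared-prompt layout and N replicated rows."""
    rng = random.Random(seed)

    def rows(count):
        return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(count)]

    scale = 1.0 / math.sqrt(d)
    q_p, k_p, v_p = rows(p), rows(p), rows(p)
    # Prompt rows are computed once: under causal masking no response token
    # can influence them, so every replicated copy would be identical.
    prompt_out = causal_attention(q_p, k_p, v_p, scale)
    prompt_kv = list(zip(k_p, v_p))
    err = 0.0
    for _ in range(n):
        q_r, k_r, v_r = rows(r), rows(r), rows(r)
        # Standard layout: this sequence carries its own copy of the prompt.
        ref = causal_attention(q_p + q_r, k_p + k_r, v_p + v_r, scale)
        shared = prompt_out + dual_kv_response(q_r, prompt_kv, k_r, v_r, scale)
        for row_ref, row_new in zip(ref, shared):
            err = max(err, max(abs(a - b) for a, b in zip(row_ref, row_new)))
    return err


if __name__ == "__main__":
    print(f"max |standard - dual-KV| = {max_abs_error():.2e}")
```

The deviation should sit at floating-point rounding level, which is the sense in which a two-region kernel introduces no approximation. A real kernel would also tile both regions and fuse the backward pass; this sketch only illustrates the forward math.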

5 authors · May 26