Stiles Seymens
Thank you for taking the time to reply and for clarifying the MoE/on-policy part—that helps a lot.
What you said (and how I'm reading it):
You confirmed that forcing old_log_prob = log_prob.detach() does not resolve the on-policy issue: the evaluation uses the current policy π(a|s), but the sampling distribution can still differ because expert selection at rollout time can differ from the training forward pass. Forcing the importance ratio to 1 in the PPO formula therefore doesn't fix the underlying problem—the data can still be effectively off-policy. I agree with that distinction.
You clarified the debugging sequence: root cause was unknown at first; you hypothesized inference–training inconsistency; importance sampling (rollout correction) didn't fix it; then you tried forcing the detach as a hypothesis test to see whether the MoE log-prob mismatch was the cause. That makes it clear the detach was diagnostic, not the actual fix.
You noted that expert router replay wasn't available in verl when you were debugging. You've since tested it (Stabilizing MoE RL, arxiv 2510.11370) and found that expert routing was not the root cause; the root cause is the attention sink. That fits the story. The real mismatch is at the attention layer (training on FA2 without sink support vs. inference on sink-aware kernels), so token-level log-probs diverge between rollout and training. It also explains why rollout correction alone didn't fix things (it reweights whole sequences but does not remove the mismatch inside the attention computation) and why freezing attention gave similar reward curves (attention wasn't learning correctly without the sink backward). I'm satisfied on the on-policy/MoE part.
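The first point above (forcing the ratio to 1 does not make the data on-policy) can be checked with a toy, exact-expectation example. All names and numbers here are hypothetical, not from the article: a three-action softmax policy π whose samples come from a slightly different distribution μ, standing in for the rollout engine. Forcing every per-sample weight to 1, which is what old_log_prob = log_prob.detach() does, leaves a biased gradient whenever μ ≠ π, while weighting by π/μ recovers the on-policy gradient.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def policy_grad(logits, rewards, behavior, correct=False):
    """Expected score-function gradient of J = E_pi[R] w.r.t. the policy logits.

    Actions are drawn from `behavior` (mu, the hypothetical rollout distribution).
    correct=False forces each weight to 1 (like the detach trick);
    correct=True uses the importance weight pi(a) / mu(a).
    """
    pi = softmax(logits)
    grad = np.zeros_like(logits)
    for a in range(len(logits)):
        score = -pi.copy()          # d log pi(a) / d logits = onehot(a) - pi
        score[a] += 1.0
        w = pi[a] / behavior[a] if correct else 1.0
        grad += behavior[a] * w * rewards[a] * score
    return grad

logits = np.array([0.5, 0.0, -0.5])
rewards = np.array([1.0, 0.0, 0.5])
pi = softmax(logits)
mu = softmax(logits + np.array([0.0, 0.3, 0.0]))   # rollout samples from a shifted distribution

true_grad = pi * (rewards - pi @ rewards)          # exact on-policy gradient
forced = policy_grad(logits, rewards, mu, correct=False)
corrected = policy_grad(logits, rewards, mu, correct=True)
print(true_grad, forced, corrected)
```

This only shows the bias in expectation; it says nothing about which component (routing or attention) produced μ ≠ π in the actual runs.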
What wasn't addressed: your reply covered the MoE/on-policy part, but not the following points from my original comment. If you have bandwidth, they would help with implementation and reproducibility:
- Attention sink backward (Section 4)—notation and derivation
The article gives the general form
∂L/∂S_h = -Σ_i P_{i,h} (∂L/∂S_{i,h} - Σ_j P_{ij} ∂L/∂S_{ij}),
then states that ∂L/∂S_{i,h} = 0 because "the sink is computed but not used in the output," yielding
∂L/∂S_h = -Σ_i P_{i,h} Σ_j P_{ij} ∂L/∂S_{ij}.
For implementers, two things would help:
Definition of ∂L/∂S_{i,h}: In the extended softmax, the logits are (S_{i,1}, …, S_{i,N_k}, S_h). The output is O_i = Σ_j P_{ij} V_j with no V_sink term, so the sink affects O_i only through the normalization Z_i = Σ_{j'} exp(S_{ij'}) + exp(S_h). The direct contribution of the sink to the output (e.g. a P_{i,h} V_sink term) is therefore zero, so ∂L/∂S_{i,h} in that sense is 0. If ∂L/∂S_{i,h} is defined differently in your derivation, a one-sentence definition in the article would remove ambiguity.
Chain rule through softmax: A short derivation showing how ∂L/∂S_h is obtained from ∂L/∂P_{ij} (and possibly ∂L/∂O_i) via ∂P_{ij}/∂S_h and ∂P_{i,h}/∂S_h would let readers verify the sign and aggregation over query positions i and match it to the FlashAttention backward API.
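One way to fill in the chain-rule step, sketched under the reading that the ∂L/∂S_{ij} and ∂L/∂S_{i,h} inside the quoted sums denote gradients with respect to the probabilities P_{ij} and P_{i,h} (dP in FlashAttention notation), not the logits. Under that reading the derivation reproduces the simplified form with its minus sign, and the per-row term D_i is the same rowsum(dO ∘ O) quantity the standard FlashAttention-2 backward already computes. Note that substituting ∂L/∂S_{i,h} = 0 into the general form exactly as quoted above gives +Σ_i P_{i,h} Σ_j P_{ij} ∂L/∂S_{ij}, the opposite sign, so the leading minus there is worth double-checking.

```latex
% Extended softmax for query i (S_h is shared by all query positions i of a head):
Z_i = \sum_{j=1}^{N_k} e^{S_{ij}} + e^{S_h}, \quad
P_{ij} = \frac{e^{S_{ij}}}{Z_i}, \quad
P_{i,h} = \frac{e^{S_h}}{Z_i}, \quad
O_i = \sum_{j=1}^{N_k} P_{ij} V_j

% Partial derivatives with respect to the sink logit:
\frac{\partial P_{ij}}{\partial S_h} = -P_{ij} P_{i,h} \;\; (j \le N_k), \qquad
\frac{\partial P_{i,h}}{\partial S_h} = P_{i,h}\,(1 - P_{i,h})

% Chain rule, summed over every query position i:
\frac{\partial L}{\partial S_h}
  = \sum_i \Big[ \sum_{j=1}^{N_k} \frac{\partial L}{\partial P_{ij}} \frac{\partial P_{ij}}{\partial S_h}
    + \frac{\partial L}{\partial P_{i,h}} \frac{\partial P_{i,h}}{\partial S_h} \Big]
  = \sum_i P_{i,h} \Big[ \frac{\partial L}{\partial P_{i,h}}
    - \sum_{j \in \{1,\dots,N_k,h\}} P_{ij} \frac{\partial L}{\partial P_{ij}} \Big]

% The sink has no value vector, so \partial L / \partial P_{i,h} = 0, and
% \partial L / \partial P_{ij} = dO_i \cdot V_j for j \le N_k:
\frac{\partial L}{\partial S_h}
  = -\sum_i P_{i,h} \sum_{j=1}^{N_k} P_{ij} \frac{\partial L}{\partial P_{ij}}
  = -\sum_i P_{i,h} D_i, \qquad
D_i = \sum_{j=1}^{N_k} P_{ij}\,(dO_i \cdot V_j) = dO_i \cdot O_i
```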
- Experimental setup and ablations
To reproduce and extend the results, it would be helpful to have:
Ablation: Was the FA3 + sink backward sufficient for stable training (with MoE materialization and sequence parallelism only for correctness/memory), or did you still rely on rollout correction or the on-policy detach in the final runs? A small table (e.g. FA2 vs FA3+sink, with/without rollout correction) would make the relative impact clear.
Rollout engine: The article mentions both vLLM and SGLang; Figure 4 references SGLang with Triton kernels for the sink forward. Which engine was used for rollouts in the final experiments (Figures 5–7), and was it the same across GSM8K, VerifyIf, and ReTool? This matters for training–inference consistency.
Rollout correction vs sink fix: A one-sentence summary of what "rollout correction" (sequence-level importance sampling) does in your pipeline and whether it is still enabled in the FA3+sink runs would clarify how the two mechanisms interact.
GPT-OSS-120B: You mention the sink fix also works for 120B; any learning curves or metrics (even high-level) would be valuable for scaling.
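For reference, a minimal sketch of what sequence-level truncated importance sampling usually computes: the product over response tokens of π_train/π_rollout, clipped at a cap. The function name, the cap of 2.0, and the choice of truncation (rather than, say, masking outlier sequences) are assumptions; the article does not confirm that verl's rollout correction uses exactly this form.

```python
import numpy as np

def sequence_is_weights(train_logp, rollout_logp, mask, cap=2.0):
    """Truncated sequence-level importance weights, one per response (hypothetical helper).

    train_logp, rollout_logp: (batch, seq_len) log-probs of the sampled tokens under
    the training forward pass and under the rollout engine.
    mask: (batch, seq_len), 1 for response tokens and 0 for prompt/padding.
    cap: assumed truncation threshold.
    Returns min(prod_t pi_train / pi_rollout, cap), computed in log space.
    """
    log_ratio = ((train_logp - rollout_logp) * mask).sum(axis=1)
    return np.exp(np.minimum(log_ratio, np.log(cap)))   # clip before exp to avoid overflow

train = np.array([[-1.0, -2.0, -0.5], [-0.3, -0.3, 0.0], [-0.1, -0.1, -0.1]])
rollout = np.array([[-1.0, -2.0, -0.5], [-0.5, -0.6, -9.0], [-3.0, -3.0, -3.0]])
mask = np.array([[1, 1, 1], [1, 1, 0], [1, 1, 1]])
w = sequence_is_weights(train, rollout, mask)
print(w)  # identical log-probs -> 1.0; small gap -> exp(0.5); large gap -> capped at 2.0
```

In a PPO-style loss, such weights would be treated as constants and multiplied into each sequence's policy-gradient term. Because the weight is a product over tokens, even a small systematic per-token gap compounds with response length and pushes many weights to the cap.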
- Implementation and release
For verification and adoption, three things would help the community reproduce the results and align training with inference: any update on the release timeline for the sink backward implementation; whether you plan to contribute it to the main FlashAttention repo or keep it in a fork; and whether there are interim options (e.g., a reference PyTorch implementation or a patch description).
Thanks again for the clarification; the narrative (hypothesis testing → rollout correction → attention sink as root cause) is much clearer, and the work is clearly valuable for agentic RL on GPT-OSS.
Thank you for this detailed retrospective on enabling agentic RL training for GPT-OSS. The article clearly documents important challenges and practical solutions. I have a few questions and observations that might help clarify some aspects.
Mathematical Clarification on Attention Sink Gradient
Regarding the attention sink backward pass derivation (Section 4), I'd appreciate clarification on the notation and justification. The article presents the general formula ∂L/∂S_h = -Σ_i P_{i,h} (∂L/∂S_{i,h} - Σ_{j∈{1,...,N_k}} P_{ij} ∂L/∂S_{ij}). It then states that ∂L/∂S_{i,h} = 0 because "the sink is computed but not used in the output," leading to the simplified form ∂L/∂S_h = -Σ_i P_{i,h} (Σ_j P_{ij} ∂L/∂S_{ij}).
I believe the final formula is correct, but the notation ∂L/∂S_{i,h} is unclear. If this refers to the gradient of the loss with respect to the sink's contribution to the output at position i (i.e., ∂L/∂(P_{i,h} V_sink)), then indeed it would be zero since there's no V_sink term. However, the sink parameter S_h itself does affect the output through softmax normalization: P_{ij} = exp(S_{ij}) / (Σ_{j'} exp(S_{ij'}) + exp(S_h)). When S_h changes, all P_{ij} change, affecting O_i = Σ_j P_{ij} V_j. The gradient ∂L/∂S_h is non-zero and flows through this normalization.
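A small NumPy sketch of that normalization effect, assuming a single head, no causal mask, and one scalar sink logit per head (toy shapes; the function name is hypothetical and this is not the FA2/FA3 or SGLang kernel code). Because the sink adds exp(S_h) to Z_i but contributes no value vector, each output row is exactly the sink-free output scaled by (1 - P_{i,h}). A kernel without sink support therefore produces different outputs, and hence different token log-probs, from a sink-aware one.

```python
import numpy as np

def attention(q, k, v, sink=None):
    """Single-head softmax attention over n_k keys (no causal mask).

    If `sink` (a scalar logit S_h) is given, exp(S_h) joins the normalizer Z_i
    but has no value vector. Returns the output O and the sink probabilities P_{i,h}.
    """
    s = q @ k.T / np.sqrt(q.shape[-1])          # logits S_ij
    m = s.max(axis=1, keepdims=True)
    if sink is not None:
        m = np.maximum(m, sink)                 # include the sink in the max for stability
    e = np.exp(s - m)
    z = e.sum(axis=1, keepdims=True)            # sum_j exp(S_ij), shifted by m
    p_sink = np.zeros((q.shape[0], 1))
    if sink is not None:
        e_sink = np.exp(sink - m)
        z = z + e_sink                          # Z_i gains exp(S_h); no V term is added
        p_sink = e_sink / z
    return (e / z) @ v, p_sink[:, 0]

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(6, 8)), rng.normal(size=(6, 8))
o_sink, p_sink = attention(q, k, v, sink=1.0)   # what a sink-aware kernel computes
o_plain, _ = attention(q, k, v)                  # what a kernel without sink support computes

# Each row of the sink-aware output is the plain output scaled by (1 - P_{i,h}).
print(np.abs(o_sink - o_plain).max())
```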
Could you clarify what ∂L/∂S_{i,h} represents in your notation? Also, could you provide the full derivation showing how you arrive at the final form from the chain rule through the softmax operation? A step-by-step derivation would help readers verify the implementation and understand how the gradient flows through the attention mechanism.
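Until a reference implementation is available, the simplified formula can be checked numerically. The sketch below (toy sizes, single head, no causal mask, hypothetical function names) compares the analytic ∂L/∂S_h = -Σ_i P_{i,h} D_i, with D_i = rowsum(dO_i ∘ O_i), against a central finite difference. It assumes the reading that the inner ∂L/∂S_{ij} in the article means ∂L/∂P_{ij}.

```python
import numpy as np

def sink_attention(q, k, v, sink):
    """Single-head attention with a scalar sink logit S_h (no value vector).

    Returns O (n_q, d_v), P (n_q, n_k) over real keys, and P_h (n_q,).
    """
    s = q @ k.T / np.sqrt(q.shape[-1])          # S_ij
    m = np.maximum(s.max(axis=1), sink)         # per-row max, including the sink
    e = np.exp(s - m[:, None])
    e_sink = np.exp(sink - m)
    z = e.sum(axis=1) + e_sink                  # Z_i = sum_j exp(S_ij) + exp(S_h)
    p = e / z[:, None]
    return p @ v, p, e_sink / z

def sink_grad(o, d_o, p_sink):
    """Analytic dL/dS_h = -sum_i P_{i,h} * D_i, with D_i = rowsum(dO_i * O_i)."""
    d = (d_o * o).sum(axis=1)
    return -(p_sink * d).sum()

rng = np.random.default_rng(1)
q, k, v = rng.normal(size=(5, 8)), rng.normal(size=(7, 8)), rng.normal(size=(7, 4))
g = rng.normal(size=(5, 4))                     # upstream gradient: L = sum(O * g), so dL/dO = g
s_h = 0.3

def loss(sink):
    return (sink_attention(q, k, v, sink)[0] * g).sum()

o, p, p_sink = sink_attention(q, k, v, s_h)
analytic = sink_grad(o, g, p_sink)
eps = 1e-6
numeric = (loss(s_h + eps) - loss(s_h - eps)) / (2 * eps)   # central difference
print(analytic, numeric)
```

In a batched kernel the same reduction would also run over the batch and all query positions of a head, since S_h is shared across them.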
On-Policy Fix: Clarification Needed
The solution for the MoE log-probability mismatch (Section 2) sets old_log_prob = log_prob.detach() when on_policy=True. While this mathematically forces the importance sampling ratio to 1, I'm curious about the implications. If the policy parameters change between rollout collection and the training step (even slightly), this fix would still assume ratio = 1. How do you ensure true on-policy conditions? Do you collect rollouts and train in the same step, or is there a mechanism to detect policy drift?
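As one generic illustration of drift detection (not something the article or verl is claimed to do), the per-token log-probs recorded at rollout time can be compared with the current policy's log-probs, e.g., with the k3 KL estimator from John Schulman's note on approximating KL divergence. The helper name, the data, and the MAX_KL threshold below are hypothetical.

```python
import numpy as np

MAX_KL = 0.01  # hypothetical drift threshold, not taken from the article

def drift_stats(logp_now, logp_rollout, mask):
    """Mean ratio and k3 estimate of KL(rollout || current) over response tokens.

    logp_now: per-token log-probs of the sampled tokens under the current policy.
    logp_rollout: the same tokens' log-probs recorded at rollout time.
    mask: 1 for response tokens, 0 for prompt/padding.
    """
    log_r = (logp_now - logp_rollout)[mask > 0]
    r = np.exp(log_r)
    k3 = np.mean((r - 1.0) - log_r)   # each term >= 0 and equals 0 only when r == 1
    return float(r.mean()), float(k3)

logp_rollout = np.array([[-1.2, -0.4, -2.0], [-0.7, -0.9, 0.0]])
mask = np.array([[1, 1, 1], [1, 1, 0]])
logp_now = logp_rollout + np.array([[0.05, -0.02, 0.10], [-0.08, 0.03, 0.0]])

ratio, kl = drift_stats(logp_now, logp_rollout, mask)
# With old_log_prob = log_prob.detach() the ratio would read exactly 1 here,
# while the stored rollout log-probs reveal the drift.
print(ratio, kl, kl > MAX_KL)
```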
Have you considered storing the actual routing decisions during rollout to ensure deterministic replay, or using deterministic routing during training? This might address the root cause rather than working around it. Also, the fix only applies when on_policy=True. How do you handle cases where the policy has drifted? Is there a threshold or mechanism to detect when to switch to off-policy mode?
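A toy sketch of the routing-replay idea raised here (and studied in arxiv 2510.11370): record the top-k expert indices at rollout time and reuse them in the training forward pass, recomputing only the gate weights. The function, the logits, and the softmax-over-selected-experts gating are assumptions for illustration, not GPT-OSS's or verl's actual router code.

```python
import numpy as np

def route(router_logits, k, replay_idx=None):
    """Top-k MoE routing for one token (hypothetical helper).

    Without replay, experts are chosen by the top-k router logits; with
    replay_idx (indices recorded during rollout), the same experts are reused
    and only the gate weights are recomputed from the current logits.
    """
    if replay_idx is None:
        idx = np.argsort(-router_logits)[:k]
    else:
        idx = np.asarray(replay_idx)
    sel = router_logits[idx]
    w = np.exp(sel - sel.max())
    return idx, w / w.sum()

# Two nearly tied experts: a 2e-7 numerical difference between the rollout
# engine and the training forward pass is enough to flip the top-2 choice.
rollout_logits = np.array([2.0, 1.0000001, 1.0, -1.0])
train_logits = np.array([2.0, 0.9999999, 1.0, -1.0])

idx_rollout, _ = route(rollout_logits, k=2)                       # experts [0, 1]
idx_train, _ = route(train_logits, k=2)                           # experts [0, 2]
idx_replay, _ = route(train_logits, k=2, replay_idx=idx_rollout)  # experts [0, 1]
print(idx_rollout, idx_train, idx_replay)
```

Replay removes the discrete routing flip but not continuous differences elsewhere in the forward pass, such as the attention sink mismatch.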
Experimental Details
Several aspects would benefit from additional detail. Which fix had the largest impact? Are all three fixes necessary, or would fixing attention sinks alone have been sufficient? A table showing results with/without each fix would be very helpful. You mention the fixes work for GPT-OSS-120B, but no experimental results are shown. Could you share learning curves or metrics for the 120B model?
The article mentions both vLLM and SGLang. Which engine is used for rollouts in the final experiments? If different engines are used for different experiments, how does this affect the training-inference mismatch analysis? You also reference "rollout correction (sequence-level importance sampling)" and mention "this verl blog," but the details aren't explained. Could you clarify what this correction does and how it relates to the attention sink fix?
Implementation and Reproducibility
The article mentions that "The implementation will be released following the internal review process." For reproducibility and verification, it would be valuable to know an approximate timeline for release, whether the backward pass implementation will be contributed to the main FlashAttention repository, and if there are any interim workarounds or partial implementations that could be shared.
Positive Aspects
Despite these questions, the article makes valuable contributions. It clearly identifies real-world challenges with MoE + RL training, documents the engineering journey transparently, shows substantial improvements in training stability, and raises awareness of training-inference mismatch issues. The fixes appear to work well for your use case, and the results demonstrate clear improvements. My questions are aimed at understanding the details better and ensuring the solutions are robust and reproducible.
Thank you again for sharing this work with the community!