Title: SageBwd: A Trainable Low-bit Attention
URL Source: https://arxiv.org/html/2603.02170
Jintao Zhang, Marco Chen, Haoxu Wang, Kai Jiang, Ion Stoica, Joseph E. Gonzalez, Jianfei Chen, Jun Zhu
Tsinghua University, UC Berkeley
{zhang-jt24@mails., jianfeic@, dcszj@}tsinghua.edu.cn
Abstract
Low-bit attention, such as SageAttention, has emerged as an effective approach for accelerating model inference, but its applicability to training remains poorly understood. In prior work, we introduced SageBwd, a trainable INT8 attention that quantizes six of seven attention matrix multiplications while preserving fine-tuning performance. However, SageBwd exhibited a persistent performance gap to full-precision attention (FPA) during pre-training. In this work, we investigate why this gap occurs and demonstrate that SageBwd matches full-precision attention during pretraining. Through experiments and theoretical analysis, we reach a few important insights and conclusions: (i) QK-norm is necessary for stable training at large tokens per step, (ii) quantization errors primarily arise from the backward-pass score gradient $\mathbf{dS}$, (iii) reducing tokens per step enables SageBwd to match FPA performance in pre-training, and (iv) K-smoothing remains essential for training stability, while Q-smoothing provides limited benefit during pre-training.
1 Introduction
Motivation.
The efficiency of attention (Vaswani, 2017) is critical for modern generative models, particularly as context lengths continue to grow and the quadratic complexity of scaled dot-product attention becomes a bottleneck (Vaswani, 2017; Jiang et al., 2024). Low-bit quantization offers a promising approach to reducing this cost by enabling the use of low-precision Tensor Cores on GPUs (Chen et al., 2020). Recent methods such as SageAttention (Zhang et al., 2025d; a; f) and FlashAttention3 (Shah et al., 2024) have shown that low-bit attention can be highly effective for inference; however, its applicability to training, particularly large-scale pre-training, remains less well understood.
Challenge.
Designing a trainable low-bit attention mechanism is challenging because the backward pass is substantially more sensitive to quantization error than the forward pass. In particular, computing gradients involves products of small-magnitude tensors and repeated error propagation through the chain rule, which can amplify quantization errors. Moreover, quantization error in the forward output $\mathbf{O}$ propagates directly through the backward computation, inducing deviations even when the backward-pass matrix multiplications (MatMuls) themselves are executed in higher precision.
Contributions and insights.
In prior work (Zhang et al., 2025c), we introduced SageBwd, a trainable low-bit attention mechanism that quantizes six of the seven attention matrix multiplications to INT8 while preserving fine-tuning performance. However, during pre-training, SageBwd exhibited a persistent performance gap relative to full-precision attention (FPA). In this work, we provide theoretical analyses and empirical observations regarding the sources of this gap and identify conditions under which SageBwd recovers FPA-level pre-training performance.
Key findings. First, we identify the dominant source of training deviation as the $\mathbf{dS}$ tensor in the backward pass, whose small magnitude makes it particularly vulnerable to upstream quantization error. Second, we show that QK-norm stabilizes pre-training by constraining query-key outliers. Third, we find that reducing the number of tokens per optimization step allows SageBwd to match FPA pre-training performance, suggesting that increased gradient noise can mitigate the impact of quantization error. Finally, through targeted ablations, we show that K-smoothing remains necessary for stable training, while Q-smoothing provides limited benefit in the pre-training setting.
2 Related Work
Hardware-efficient attention.
A line of recent work accelerates attention by optimizing GPU kernel implementations. FlashAttention (Dao et al., 2022) reduces memory I/O by tiling attention computation to on-chip SRAM, achieving significant speedups over standard attention. FlashAttention2 (Dao, 2024) further improves parallelism and warp partitioning, while FlashAttention3 (Shah et al., 2024) targets kernel-level optimizations on Hopper GPUs. Similarly, xFormers (Lefaudeux et al., 2022) provides a collection of custom CUDA kernels for efficient attention variants.
Low-bit and quantized attention.
Another line of work accelerates attention by leveraging low-precision tensor cores. SageAttention (Zhang et al., 2025d), SageAttention2 (Zhang et al., 2025a), and SageAttention2++ (Zhang et al., 2025f) combine INT8 quantization with outlier-smoothing techniques to enable efficient attention computation. FlashAttention3 (Shah et al., 2024) proposes an FP8 attention variant; however, it is not directly applicable to large generative models such as video diffusion in a plug-and-play manner (Zhang et al., 2025a). More broadly, these low-bit attention methods are primarily designed for inference and do not support training, limiting their applicability in pre-training and fine-tuning settings.
Trainable low-bit attention.
SageAttention3 (Zhang et al., 2025c) introduces two complementary advances: (i) an extension of SageAttention2++ that improves inference-side low-bit attention, and (ii) SageBwd, a trainable low-bit attention mechanism that quantizes most attention matrix multiplications while preserving fine-tuning performance. This work builds on the SageBwd component of SageAttention3 by analyzing the sources of training instability in low-bit attention and characterizing the conditions under which full-precision attention performance can be recovered during pre-training.
3 Preliminaries
FlashAttention.
Scaled dot-product attention computes $\mathbf{S}=\mathbf{Q}\mathbf{K}^{\top}$, $\mathbf{P}=\operatorname{softmax}(\mathbf{S})$, $\mathbf{O}=\mathbf{P}\mathbf{V}$, where $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times D}$ and $\mathbf{S},\mathbf{P}\in\mathbb{R}^{N\times N}$, with $N$ denoting the sequence length and $D$ the head dimension. FlashAttention tiles the sequence dimension by chunking $\mathbf{Q},\mathbf{K},\mathbf{V}$ into blocks $\{\mathbf{Q}_i\},\{\mathbf{K}_j\},\{\mathbf{V}_j\}$, where $\mathbf{Q}_i\in\mathbb{R}^{B_q\times D}$ and $\mathbf{K}_j,\mathbf{V}_j\in\mathbb{R}^{B_{kv}\times D}$. It then avoids the quadratic IO overhead of materializing $\mathbf{S}$ and $\mathbf{P}$ in global memory by using an online softmax and fusing all operations into a single kernel: $\mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^{\top}$, $\mathbf{P}_{ij}=\operatorname{OnlineSoftmax}(\mathbf{S}_{ij})$, $\mathbf{O}_i=\sum_j \mathbf{P}_{ij}\mathbf{V}_j$.
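The tiling scheme above can be sketched in a few lines of NumPy. This is an illustrative reference, not the fused Triton/CUDA kernel; the block sizes `Bq`, `Bkv` and the $1/\sqrt{D}$ scaling are chosen for exposition:

```python
import numpy as np

def tiled_attention(Q, K, V, Bq=32, Bkv=32):
    """FlashAttention-style tiling with an online softmax.

    K/V are streamed in blocks while a running row-max (m) and softmax
    denominator (l) let us rescale partial results, so the N x N matrices
    S and P are never materialized in full.
    """
    N, D = Q.shape
    O = np.zeros((N, D))
    for i0 in range(0, N, Bq):
        Qi = Q[i0:i0 + Bq]
        m = np.full(Qi.shape[0], -np.inf)   # running row max
        l = np.zeros(Qi.shape[0])           # running softmax denominator
        acc = np.zeros((Qi.shape[0], D))    # unnormalized output accumulator
        for j0 in range(0, N, Bkv):
            Sij = Qi @ K[j0:j0 + Bkv].T / np.sqrt(D)
            m_new = np.maximum(m, Sij.max(axis=1))
            scale = np.exp(m - m_new)       # rescales previously accumulated blocks
            Pij = np.exp(Sij - m_new[:, None])
            l = l * scale + Pij.sum(axis=1)
            acc = acc * scale[:, None] + Pij @ V[j0:j0 + Bkv]
            m = m_new
        O[i0:i0 + Bq] = acc / l[:, None]
    return O
```

The result matches $\operatorname{softmax}(\mathbf{Q}\mathbf{K}^{\top}/\sqrt{D})\mathbf{V}$ computed directly, up to floating-point rounding.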
Quantization.
Quantization accelerates matrix multiplication by representing high-precision matrices with low-bit numeric formats and floating-point scale factors. Given a high-precision matrix $\mathbf{X}\in\mathbb{R}^{m\times n}$, its INT8 quantization $\hat{\mathbf{X}}$ is defined as $\hat{\mathbf{X}}:=\operatorname{round}(\mathbf{X}/\delta_{\mathbf{X}})$, where $\delta_{\mathbf{X}}>0$ is a scale factor, typically computed as $\delta_{\mathbf{X}}=\max(|\mathbf{X}|)/127$ and stored in FP32.
Subsequently, given two FP16 matrices $\mathbf{A}$ and $\mathbf{B}$, their approximate matrix product under INT8 quantization is computed as

$$\mathbf{A}\mathbf{B}\approx\delta_{\mathbf{A}}\delta_{\mathbf{B}}\cdot\hat{\mathbf{A}}\hat{\mathbf{B}},$$

where the integer matrix multiplication $\hat{\mathbf{A}}\hat{\mathbf{B}}$ can be accelerated using INT8 tensor cores.
The granularity of quantization refers to the scope over which the scale factor Ξ΄\delta is computed. Common choices include per-tensor, per-channel, and per-block quantization. In per-block quantization, a single scale factor is shared by all elements within a block, e.g., a FlashAttention tile.
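A minimal NumPy sketch of the quantized MatMul above, using per-tensor granularity for brevity (SageBwd itself uses finer per-block and per-token scales):

```python
import numpy as np

def int8_quantize(X):
    """Per-tensor symmetric INT8 quantization: X is approximated by delta * X_q."""
    delta = np.abs(X).max() / 127.0          # FP32 scale factor
    X_q = np.clip(np.round(X / delta), -127, 127).astype(np.int8)
    return X_q, delta

def int8_matmul(A, B):
    """Approximate A @ B as delta_A * delta_B * (A_q @ B_q)."""
    A_q, dA = int8_quantize(A)
    B_q, dB = int8_quantize(B)
    # The integer MatMul is what would run on INT8 tensor cores;
    # accumulate in int32 to avoid overflow.
    return dA * dB * (A_q.astype(np.int32) @ B_q.astype(np.int32))
```

For well-conditioned Gaussian inputs, the relative error of this approximation is small; the analysis in Section 4 concerns the cases where it is not.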
Q and K Smoothing in SageAttention.
SageAttention (Zhang et al., 2025d; a; c) extends FlashAttention by quantizing $\mathbf{Q}$ and $\mathbf{K}$ to low precision for efficient inference. To mitigate the effect of channel-wise outliers prior to quantization, SageAttention introduced the preprocessing techniques of Q- and K-smoothing. Given query and key blocks $\mathbf{Q}_i\in\mathbb{R}^{B\times d}$ and $\mathbf{K}_j\in\mathbb{R}^{B\times d}$, SageAttention computes a block-wise mean for queries and a global mean for keys:
$$\boldsymbol{\mu}_{Q_i}=\operatorname{mean}_{\text{row}}(\mathbf{Q}_i),\qquad\boldsymbol{\mu}_{K}=\operatorname{mean}_{\text{row}}(\mathbf{K}),$$

where $\boldsymbol{\mu}_{Q_i},\boldsymbol{\mu}_{K}\in\mathbb{R}^{1\times d}$ and $\mathbf{K}$ denotes the full key tensor. The smoothed tensors are defined as

$$\mathbf{Q}_i^{\mathrm{sm}}=\mathbf{Q}_i-\boldsymbol{\mu}_{Q_i},\qquad\mathbf{K}_j^{\mathrm{sm}}=\mathbf{K}_j-\boldsymbol{\mu}_{K}.$$
The attention logits admit the decomposition
$$\mathbf{Q}_i\mathbf{K}_j^{\top}=\mathbf{Q}_i^{\mathrm{sm}}\mathbf{K}_j^{\mathrm{sm},\top}+\boldsymbol{\mu}_{Q_i}\mathbf{K}_j^{\mathrm{sm},\top}+\mathbf{Q}_i^{\mathrm{sm}}\boldsymbol{\mu}_{K}^{\top}+\boldsymbol{\mu}_{Q_i}\boldsymbol{\mu}_{K}^{\top}.$$
Since the softmax operation is invariant to adding a constant to each row, SageAttention applies low-bit quantization to the smoothed tensors $\mathbf{Q}_i^{\mathrm{sm}}$ and $\mathbf{K}_j^{\mathrm{sm}}$, computes the dominant term $\mathbf{Q}_i^{\mathrm{sm}}\mathbf{K}_j^{\mathrm{sm},\top}$ using low-bit tensor cores, and adds back the remaining low-rank bias terms to recover the logits. When only K-smoothing is applied, the dropped term is constant within each row and the additive bias vanishes under softmax.
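Both the four-term decomposition and the row-invariance argument can be checked numerically. The sketch below uses small synthetic blocks with artificial channel offsets; all shapes and offsets are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 4)) + 5.0      # offsets mimic channel-wise outliers
K = rng.standard_normal((8, 4)) - 3.0

mu_Q = Q.mean(axis=0, keepdims=True)       # (1, d) block-wise query mean
mu_K = K.mean(axis=0, keepdims=True)       # (1, d) global key mean
Q_sm, K_sm = Q - mu_Q, K - mu_K

# Exact four-term decomposition of the logits
S = Q_sm @ K_sm.T + mu_Q @ K_sm.T + Q_sm @ mu_K.T + mu_Q @ mu_K.T
assert np.allclose(S, Q @ K.T)

def softmax(S):
    e = np.exp(S - S.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# With K-smoothing only, the dropped term Q @ mu_K.T is constant within
# each row, so the softmax output is unchanged.
assert np.allclose(softmax(Q @ K_sm.T), softmax(Q @ K.T))
```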
SageBwd.
SageAttention3 (Zhang et al., 2025c) proposes SageBwd, a trainable low-bit attention mechanism. In the forward pass, SageBwd applies K-smoothing prior to per-block INT8 quantization for the $\mathbf{Q}\mathbf{K}^{\top}$ product, and uses a mixed per-token and per-block quantization scheme for the $\tilde{\mathbf{P}}\mathbf{V}$ product.
In the backward pass, attention gradients involve the following matrix multiplications:
$$\mathbf{S}=\mathbf{Q}\mathbf{K}^{\top},\quad\mathbf{dV}=\mathbf{P}^{\top}\mathbf{dO},\quad\mathbf{dP}=\mathbf{dO}\mathbf{V}^{\top},\quad\mathbf{dQ}=\mathbf{dS}\mathbf{K},\quad\mathbf{dK}=\mathbf{dS}^{\top}\mathbf{Q},$$
where
$$\mathbf{dS}=\mathbf{P}\circ(\mathbf{dP}-\boldsymbol{\delta}\mathbf{1}^{\top}),\qquad\boldsymbol{\delta}=\operatorname{rowsum}(\mathbf{dO}\circ\mathbf{O}).$$
SageBwd retains $\mathbf{dP}=\mathbf{dO}\mathbf{V}^{\top}$ in FP16 precision, while quantizing the remaining four matrix multiplications using per-block INT8. This design choice avoids the error amplification through $\mathbf{dS}$, $\mathbf{dQ}$, and $\mathbf{dK}$ that arises when $\mathbf{dP}$ is quantized. SageAttention3 demonstrates that this formulation preserves fine-tuning performance, while its behavior during full pre-training remains less well understood. A pseudocode description is provided in Appendix A.
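The backward-pass MatMuls above can be written out as an unquantized NumPy reference (a sketch, not the SageBwd kernel; the $1/\sqrt{d}$ scaling is omitted, matching the formulas in Section 3):

```python
import numpy as np

def attention_backward(Q, K, V, dO):
    """Unquantized reference for the attention backward MatMuls.

    Follows dV = P^T dO, dP = dO V^T, dS = P o (dP - delta 1^T),
    dQ = dS K, dK = dS^T Q.
    """
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    O = P @ V
    dV = P.T @ dO
    dP = dO @ V.T
    delta = (dO * O).sum(axis=1, keepdims=True)   # rowsum(dO o O)
    dS = P * (dP - delta)
    return dS @ K, dS.T @ Q, dV, dS
```

A useful property of this formulation: each row of $\mathbf{dS}$ sums to zero, which the K-smoothing analysis in Section 6 relies on.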
4 Analysis of SageBwd in Pretraining
In this section, we analyze which design choices in SageBwd are necessary to match full-precision attention (FPA) performance during pre-training. Our analysis focuses on four central aspects: (i) controlling query-key outliers via QK-norm, (ii) identifying the most sensitive tensor in the INT8 backward pass, (iii) understanding how tokens-per-step interacts with quantization noise, and (iv) characterizing the effect of the activation scale through controlled QK standard deviation experiments. Together, these analyses provide a mechanistic explanation for the observed behavior of SageBwd. In Section 5, we empirically validate the resulting conclusions via pre-training experiments.
4.1 Stabilizing Outliers with QK-Norm
QK-norm for logit stabilization.
In scaled dot-product attention, the logits $\mathbf{S}=\mathbf{Q}\mathbf{K}^{\top}/\sqrt{d}$ scale with the norms of $\mathbf{Q}$ and $\mathbf{K}$. During pre-training, these norms tend to increase, leading to large logits that can saturate the softmax or trigger numerical instabilities, particularly under low-precision arithmetic (Anson and Aitchison, 2025; Dehghani et al., 2023). QK-norm (Henry et al., 2020) addresses this issue by applying RMS normalization to each token in $\mathbf{Q}$ and $\mathbf{K}$, explicitly controlling their scale and keeping the logits within a numerically stable range throughout training.
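A minimal sketch of QK-norm as per-token RMS normalization. The learned per-channel scale `gamma` is assumed to have shape `(d,)`, and the epsilon value is illustrative:

```python
import numpy as np

def qk_norm(X, gamma=None, eps=1e-6):
    """RMS-normalize each token (row) of Q or K.

    gamma is the learned per-channel scale vector, applied after
    normalization; with gamma=None, each token ends up with RMS ~ 1.
    """
    rms = np.sqrt((X * X).mean(axis=-1, keepdims=True) + eps)
    X_n = X / rms
    return X_n if gamma is None else X_n * gamma
```

After normalization, every token has unit RMS regardless of how large the raw activations grow, which is what bounds the logit magnitudes during training.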
QK-norm for quantization.
Beyond stabilizing the softmax, QK-norm also improves the robustness of low-bit attention. Prior work, such as SageAttention, combines channel-wise smoothing of $\mathbf{Q}$ and $\mathbf{K}$ with fine-grained quantization to mitigate the effect of extreme outliers (Zhang et al., 2025d; a). QK-norm complements these techniques: by compressing the dynamic range of $\mathbf{Q}$ and $\mathbf{K}$, it reduces the effective quantization step size under uniform INT8 quantization, pulling outliers closer to the rest of the distribution and improving quantization accuracy. As shown in Subsection 5.3, this effect is particularly important during pre-training with SageBwd.
4.2 Sensitivity of $\mathbf{dS}$ in the Backward Pass
A central challenge in training low-bit attention is the accurate computation of the softmax-gradient tensor $\mathbf{dS}$. Empirically, Table 2 shows that the discrepancy between SageBwd and FPA peaks at $\mathbf{dS}$, with errors further propagating to $\mathbf{dQ}$ and $\mathbf{dK}$. This behavior indicates that $\mathbf{dS}$ likely constitutes the primary numerical bottleneck in the INT8 backward pass.
Why $\mathbf{dS}$ is intrinsically fragile.
The sensitivity of $\mathbf{dS}$ stems from its systematically small magnitude. Recall that

$$\mathbf{dS}=\mathbf{P}\circ(\mathbf{dP}-\boldsymbol{\delta}\mathbf{1}^{\top}),\qquad\boldsymbol{\delta}=\operatorname{rowsum}(\mathbf{dO}\circ\mathbf{O}),$$
where $\mathbf{P}$ is the softmax output. As shown in Appendix B, the RMS of $\mathbf{dS}$ admits the upper bound

$$\operatorname{RMS}(\mathbf{dS})\leq\frac{1}{\sqrt{N}}\max_{i}\left\|\mathbf{dP}_i-\boldsymbol{\delta}_i\mathbf{1}\right\|_{\infty},$$

where $N$ is the sequence length. This $1/\sqrt{N}$ scaling implies that $\mathbf{dS}$ becomes increasingly small for long sequences, even when upstream gradients are well behaved.
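The bound is easy to check numerically on synthetic inputs. The sketch below uses illustrative sizes and the scaled-logit convention $\mathbf{S}=\mathbf{Q}\mathbf{K}^{\top}/\sqrt{d}$:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 256, 32
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T
delta = (dO * O).sum(axis=1, keepdims=True)
dS = P * (dP - delta)

rms = lambda X: np.sqrt((X * X).mean())
# (1/sqrt(N)) * max_i ||dP_i - delta_i 1||_inf
bound = np.abs(dP - delta).max() / np.sqrt(N)
assert rms(dS) <= bound
```

Even on these synthetic inputs, the measured RMS sits well below the bound, consistent with the looseness discussed below.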
Implications for INT8 quantization.
INT8 quantization introduces approximately fixed absolute noise determined by the quantization step size (Jacob et al., 2018). For tensors with large magnitude, this noise is often tolerable; however, when the signal itself is small, the same absolute error translates into a large relative error. As a result, $\mathbf{dS}$ exhibits a much poorer effective signal-to-noise ratio under INT8 quantization than other intermediate tensors. This issue is exacerbated by the multiplicative structure of $\mathbf{dS}=\mathbf{P}\circ(\mathbf{dP}-\boldsymbol{\delta}\mathbf{1}^{\top})$, which combines quantization noise from both forward-pass tensors ($\mathbf{P}$, $\mathbf{O}$) and backward-pass tensors ($\mathbf{dP}$, $\mathbf{dO}$).
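A toy illustration of this signal-to-noise argument: the same fixed absolute noise floor applied to tensors of very different scales yields very different relative errors (all magnitudes below are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
noise = 0.01 * rng.standard_normal(10_000)      # fixed absolute noise floor
large = 10.0 * rng.standard_normal(10_000)      # a large-magnitude tensor (dP-like)
small = 1e-3 * rng.standard_normal(10_000)      # a tiny-magnitude tensor (dS-like)

rel_err = lambda x: np.linalg.norm(noise) / np.linalg.norm(x)
# Same absolute noise: relative error ~1e-3 on the large tensor,
# but ~10 (noise-dominated) on the small one.
assert rel_err(small) > 1000 * rel_err(large)
```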
Empirical scale of $\mathbf{dS}$.
We empirically verify this analysis by measuring the RMS values of $\mathbf{P}$, $\mathbf{dP}$, and $\mathbf{dS}$ from a representative layer and head of a QK-normed SageBwd checkpoint trained over 78B tokens with 2.1M tokens per step and a sequence length of $N=4096$:

$$\operatorname{RMS}(\mathbf{P})\approx 5\times 10^{-3},\quad\operatorname{RMS}(\mathbf{dP})\approx 5\times 10^{-5},\quad\operatorname{RMS}(\mathbf{dS})\approx 1\times 10^{-7}.$$

While the theoretical bound suggests that $\mathbf{dS}$ should be at most $1/\sqrt{4096}\approx 1/64$ times the scale of $\mathbf{dP}$, we observe a ratio closer to $500$ in practice. Although the bound is loose, this discrepancy only further highlights how tightly constrained the magnitude of $\mathbf{dS}$ is in realistic training settings.
Propagation to $\mathbf{dQ}$ and $\mathbf{dK}$.
Finally, since $\mathbf{dQ}$ and $\mathbf{dK}$ are obtained through matrix multiplications with $\mathbf{dS}$, quantization errors in $\mathbf{dS}$ propagate directly and are amplified by the norms of $\mathbf{Q}$ and $\mathbf{K}$. This amplification becomes more pronounced at longer sequence lengths, consistent with observations in prior work (Zhang et al., 2025c).
4.3 Effect of Tokens-per-Step
We define tokens-per-step (TPS) as the total number of tokens processed in a single optimizer update. In our experiments (Section 5), we fix the sequence length and vary the global batch size, making TPS directly proportional to batch size. We observe that TPS has a significant impact on pre-training behavior: at a large TPS of 2.1M, SageBwd consistently underperforms FPA, whereas at a smaller TPS of 260K, SageBwd matches FPA within noise.
Gradient noise and quantization error.
Under a fixed token budget, increasing TPS results in fewer, more deterministic gradient updates, while decreasing TPS yields more frequent updates with higher stochastic gradient noise (Smith et al., 2018). Prior work has shown that large-batch training reduces gradient noise and can alter optimization dynamics (Keskar et al., 2017). In the context of low-bit attention, we hypothesize that at large TPS, the reduced gradient noise makes systematic quantization error in the backward pass, particularly along the sensitive $\mathbf{dS}$ path (Subsection 4.2), more salient to the optimizer. This persistent, biased error may then influence the optimization trajectory, leading to convergence toward a stable but suboptimal solution.
At smaller TPS, the inherent stochasticity of gradient updates is higher. In this regime, INT8 quantization error likely acts as a small perturbation relative to gradient noise and therefore may not significantly alter the training trajectory, allowing SageBwd to recover FPA-level pre-training performance. While this explanation provides a plausible mechanism linking TPS and quantization error, we do not rule out the influence of other batch-size-dependent optimization effects.
Sequence length as a potential factor.
In this work, we vary TPS only through batch size while holding the sequence length fixed. However, since $\mathbf{dQ}$ and $\mathbf{dK}$ involve matrix multiplications over the sequence dimension $N$, longer sequences aggregate contributions from a larger number of $\mathbf{dS}$ entries. As a result, upstream errors may be further amplified at larger sequence lengths. A systematic study of how sequence length interacts with TPS and quantization error is left to future work.
4.4 Effect of QK Standard Deviation on Quantization Error
To isolate the effect of activation scale on quantization error, we evaluate SageBwd using synthetic Gaussian attention inputs in which the standard deviations of $\mathbf{Q}$ and $\mathbf{K}$ ($\sigma_Q$, $\sigma_K$) are varied, while $\sigma_V$ and $\sigma_{dO}$ are held fixed at 1. This controlled setting removes optimizer dynamics and directly probes the sensitivity of quantized attention to activation scale, simulating the typical growth of query and key norms observed during pre-training.
As shown in Table 1, the accuracy of SageBwd degrades sharply as $\sigma_Q$ and $\sigma_K$ increase. While the output $\mathbf{O}$ and gradient $\mathbf{dV}$ remain relatively accurate, the gradients $\mathbf{dQ}$ and $\mathbf{dK}$ exhibit severe error, with cosine similarity dropping below 0.79 and relative $\ell_2$ error exceeding 0.66 at $\sigma_{Q,K}=10$.
Table 1: SageBwd vs. FPA across random QKV with varying $\sigma_Q$ and $\sigma_K$.

| $\sigma_Q,\sigma_K$ | Output CosSim | Output Rel-$\ell_2$ | dQ CosSim | dQ Rel-$\ell_2$ | dK CosSim | dK Rel-$\ell_2$ | dV CosSim | dV Rel-$\ell_2$ |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.9999 | 0.0160 | 0.9998 | 0.0184 | 0.9998 | 0.0220 | 0.9999 | 0.0159 |
| 3 | 0.9992 | 0.0389 | 0.9971 | 0.0758 | 0.9970 | 0.0777 | 0.9992 | 0.0387 |
| 5 | 0.9982 | 0.0603 | 0.9798 | 0.2014 | 0.9799 | 0.2007 | 0.9982 | 0.0605 |
| 8 | 0.9953 | 0.0972 | 0.8900 | 0.4666 | 0.8886 | 0.4699 | 0.9953 | 0.0973 |
| 10 | 0.9933 | 0.1161 | 0.7823 | 0.6648 | 0.7820 | 0.6684 | 0.9933 | 0.1157 |
Intuitively, increasing $\sigma_Q$ and $\sigma_K$ inflates the dynamic range of tensors involved in quantized matrix multiplications, increasing the quantization step size under uniform INT8 quantization and thereby amplifying absolute quantization noise (Jacob et al., 2018). These errors are especially harmful when they feed into the softmax gradient computation, since $\mathbf{dS}$ has comparatively small magnitude (Subsection 4.2), resulting in a poor effective signal-to-noise ratio. Consequently, even moderate upstream absolute errors can translate into large relative errors in $\mathbf{dS}$ and propagate to $\mathbf{dQ}$ and $\mathbf{dK}$.
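The step-size mechanism can be demonstrated in isolation: under per-tensor INT8 quantize-dequantize, the absolute error in the logits grows roughly as $\sigma^2$ (errors from both $\mathbf{Q}$ and $\mathbf{K}$ scale with $\sigma$), while the softmax is sensitive to absolute logit perturbations. The inputs below are synthetic Gaussians with illustrative sizes:

```python
import numpy as np

def fake_quant(X, qmax=127):
    """Per-tensor INT8 quantize-dequantize; the step size scales with max|X|."""
    delta = np.abs(X).max() / qmax
    return np.round(X / delta).clip(-qmax, qmax) * delta

rng = np.random.default_rng(0)
N, d = 128, 64

def abs_logit_error(sigma):
    """Mean absolute error of the quantized logits at activation scale sigma."""
    Q = sigma * rng.standard_normal((N, d))
    K = sigma * rng.standard_normal((N, d))
    S = Q @ K.T / np.sqrt(d)
    S_q = fake_quant(Q) @ fake_quant(K).T / np.sqrt(d)
    return np.abs(S - S_q).mean()
```

The relative error of $\mathbf{Q}$ and $\mathbf{K}$ themselves is scale-invariant under per-tensor quantization, but the absolute logit error grows quadratically with $\sigma$, mirroring the degradation trend in Table 1.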
This analysis further clarifies the role of QK-norm in stabilizing pre-training, as discussed in Subsection 4.1. By normalizing $\mathbf{Q}$ and $\mathbf{K}$, QK-norm bounds their effective scale and reduces the dynamic range seen by INT8 quantization, yielding much higher quantization accuracy. However, since QK-norm includes a learned RMSNorm scale vector $\boldsymbol{\gamma}$, which tends to increase gradually during pre-training (Xiao et al., 2023), the effective $\sigma_Q$ and $\sigma_K$ may still grow over time. Once this growth surpasses a critical error threshold, the quantization noise along the $\mathbf{dS}$ path can become dominant, providing an explanation for the pre-training instability observed at large tokens-per-step in Section 5, even when QK-norm is applied.
5 Experiments
Core results.
At 260K tokens-per-step (TPS), SageBwd pre-trains on par with full-precision attention (FPA), with or without QK-norm. At 2.1M TPS, however, QK-norm is necessary to avoid loss explosion. Across settings, quantization error spikes at the intermediate tensor $\mathbf{dS}$.
5.1 Setup
We implement SageBwd using OpenAI Triton (Tillet et al., 2019) and conduct pre-training experiments with a 325M-parameter Llama model (Dubey et al., 2024) over 78B tokens of the OpenWebText dataset (Gokaslan et al., 2019). All runs use BF16 mixed precision and are trained on a single B200 or RTX4090 GPU node. Across all experiments, we use cosine learning-rate scheduling, a context length of 4096, a hidden dimension of 3072, the GPT2 tokenizer, a norm epsilon of 1e-6, and a learning rate of 3e-5. By default, all experiments in this section apply K-smoothing but not Q-smoothing.
(a) 2.1M Tokens/Step
(b) 260K Tokens/Step
Figure 1: Pretraining loss over 78B tokens under different tokens-per-step settings.
Table 2: Cosine similarity and relative $\ell_2$ error for intermediate tensors in SageBwd (vs. FPA).

| Metric | $\boldsymbol{\delta}$ | $\mathbf{P}$ | $\mathbf{dP}$ | $\mathbf{dS}$ | $\mathbf{O}$ | $\mathbf{dQ}$ | $\mathbf{dK}$ | $\mathbf{dV}$ |
|---|---|---|---|---|---|---|---|---|
| CosSim | 0.9973 | 0.9917 | 1.0000 | 0.9789 | 0.9969 | 0.9664 | 0.9537 | 0.9985 |
| Rel-$\ell_2$ | 0.0736 | 0.1293 | 0.0000 | 0.2045 | 0.0793 | 0.2579 | 0.3074 | 0.0540 |
5.2 Effect of TPS on SageBwd vs. full-precision attention pretraining
Figures 1(a) and 1(b) compare the pre-training performance of SageBwd and FPA at 2.1M and 260K TPS, respectively. At the larger TPS of 2.1M, SageBwd exhibits a clear gap relative to FPA: after 37.5k training steps with a global batch size of 512 (including 1k warmup steps), SageBwd reaches a loss of 2.640, whereas FPA attains 2.586. In contrast, at the smaller TPS of 260K, where training is performed for 300k steps with a global batch size of 64 (including 7.5k warmup steps), SageBwd matches FPA within noise, achieving a loss of 2.561 compared to 2.563 for FPA.
5.3 QK-norm Is Necessary at High TPS
As shown in Figure 1(a), at a large TPS of 2.1M, removing QK-norm leads to training instability and eventual divergence. This behavior is consistent with increased quantization error arising from unconstrained query and key magnitudes. In contrast, for the smaller-TPS runs in Figure 1(b), SageBwd matches FPA even without QK-norm.
Despite this apparent robustness at low TPS, intermediate-tensor analysis reveals a different picture. As reported in Appendix C, the non-normed runs exhibit notably larger relative $\ell_2$ error and lower cosine similarity than their QK-normed counterparts, even at 260K TPS. This observation aligns with our hypothesis in Subsection 4.3 that increased gradient noise at lower TPS can mask moderate quantization errors without eliminating them.
Combined with the controlled activation-scale analysis in Subsection 4.4, these results indicate that QK-norm is a critical component for robust low-bit attention training at scale.
5.4 Tracing intermediate tensor error
In FlashAttention-style kernels, intermediate attention tensors such as $\mathbf{P}$, $\mathbf{S}$, $\mathbf{dP}$, and $\mathbf{dS}$ are not explicitly materialized, making direct accuracy inspection difficult. To isolate quantization-induced error, we construct a pseudo-quantized FPA baseline: we extract full-precision $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, and $\mathbf{dO}$ from layer 11 of SageBwd with QK-norm in the 2.1M TPS run (the most error-prone layer identified in Figure 5(a) and Appendix C). We then apply the SageBwd INT8 quantize-dequantize scheme before each relevant matrix multiplication in a PyTorch attention implementation and compare all intermediate tensors against full-precision FPA using cosine similarity and relative $\ell_2$ error.
As shown in Table 2, most intermediates, including $\mathbf{O}$ and $\mathbf{dV}$, remain very close to FPA. In contrast, $\mathbf{dS}$ and its downstream gradients $\mathbf{dQ}$ and $\mathbf{dK}$ exhibit substantially larger deviations. This provides direct evidence that the $\mathbf{dS}$ computation constitutes the primary quantization bottleneck in the backward pass of SageBwd. In this analysis, the upstream gradient $\mathbf{dO}$ is treated as error-free; hence, $\mathbf{dP}$ appears perfectly accurate.
5.5 Kernel Performance
Figure 2 and Figure 3 report the end-to-end forward and backward kernel throughput of SageBwd compared to baseline attention implementations on an RTX4090. Across head dimensions $D=64$ and $D=128$, SageBwd consistently outperforms FlashAttention2, achieving up to a 1.67× speedup, and also exceeds the Triton- and xFormers-based FlashAttention2 implementations.
We note that our current implementation prioritizes correctness and stability over aggressive kernel fusion, and further speed improvements are likely achievable with additional optimizations.
Figure 2: Speed comparison between SageBwd and baselines (RTX4090, head dim = 128).
Figure 3: Speed comparison between SageBwd and baselines (RTX4090, head dim = 64).
6 Ablation Study
In the main experiments (Section 5), we apply K-smoothing by default. However, prior SageAttention work also employs Q-smoothing as a key technique for improving quantization accuracy (Zhang et al., 2025d; a). In this section, we ablate the effects of Q- and K-smoothing in SageBwd to clarify their respective roles during pre-training.
In Figure 4, we compare full-precision attention (FPA) with SageBwd under three settings: no smoothing, K-smoothing, and QK-smoothing, at both 2.1M and 260K tokens per step (TPS). Due to computational constraints, we do not evaluate Q-smoothing in isolation. All runs use QK-norm and identical training hyperparameters to those in Section 5.
(a) 2.1M Tokens/Step
(b) 260K Tokens/Step
Figure 4: Ablation of Q- and K-smoothing: pretraining loss over 78B tokens under different tokens-per-step settings.
Contrary to expectations from prior work, we find that while K-smoothing remains essential for stable pre-training, Q-smoothing provides no consistent benefit and can slightly degrade gradient accuracy.
K-smoothing is necessary for stable training.
Consistent with prior work, Figure 4 shows that K-smoothing, in which the token-wise mean of $\mathbf{K}$ is subtracted prior to quantization, is critical for maintaining pre-training stability. Even in the more noise-tolerant 260K TPS regime, K-smoothing is required to achieve FPA-level performance.
From an implementation perspective, K-smoothing requires no modification to the backward pass: smoothing can occur at kernel entry, and the smoothed $\mathbf{K}$ can be used without additional bias terms.
The gradient computation $\mathbf{dQ}=\mathbf{dS}\mathbf{K}$ remains valid even after $\mathbf{K}$ is smoothed: each row of $\mathbf{dS}$ sums to 0, so $\mathbf{dS}\left(\mathbf{1}\operatorname{mean}_{\text{row}}(\mathbf{K})\right)=\mathbf{0}$ and $\mathbf{dQ}=\mathbf{dS}\mathbf{K}=\mathbf{dS}\left(\mathbf{K}-\mathbf{1}\operatorname{mean}_{\text{row}}(\mathbf{K})\right)=\mathbf{dS}\mathbf{K}^{\mathrm{sm}}$, where $\mathbf{K}^{\mathrm{sm}}$ denotes the key matrix after smoothing.
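This identity is easy to verify numerically; the sketch below uses small synthetic inputs and unscaled logits, matching the backward formulas of Section 3:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 64, 16
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T
delta = (dO * O).sum(axis=1, keepdims=True)
dS = P * (dP - delta)

K_sm = K - K.mean(axis=0, keepdims=True)   # subtract the token-wise mean of K

# Rows of dS sum to zero, so the subtracted mean contributes nothing to dQ.
assert np.allclose(dS.sum(axis=1), 0.0)
assert np.allclose(dS @ K, dS @ K_sm)
```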
Q-smoothing shows limited benefit.
In contrast, we observe no consistent improvement in either pre-training loss or intermediate-tensor accuracy from applying Q-smoothing. In some cases, Q-smoothing slightly degrades gradient fidelity. As shown in our error analysis (Appendix C), $\mathbf{dQ}$ and $\mathbf{dK}$ exhibit marginally larger deviation from the FPA baseline when Q-smoothing is enabled.
One contributing factor is the gradient correction required by Q-smoothing. Rewriting the logits as

$$\mathbf{S}=(\mathbf{Q}-\mathbf{1}\boldsymbol{\mu}_Q)\mathbf{K}^{\top}+\mathbf{1}(\boldsymbol{\mu}_Q\mathbf{K}^{\top}),\qquad\boldsymbol{\mu}_Q=\operatorname{mean}_{\text{row}}(\mathbf{Q})\in\mathbb{R}^{1\times d},$$

preserves forward equivalence with $\mathbf{Q}\mathbf{K}^{\top}$, so the total gradient remains $\mathbf{dK}=\mathbf{dS}^{\top}\mathbf{Q}$. Consequently, $\mathbf{dK}$ cannot be computed from the centered branch alone, i.e., $\mathbf{dK}\neq\mathbf{dK}_{\text{center}}=\mathbf{dS}^{\top}\mathbf{Q}^{\mathrm{sm}}$ with $\mathbf{Q}^{\mathrm{sm}}=\operatorname{smooth}(\mathbf{Q})=\mathbf{Q}-\mathbf{1}\boldsymbol{\mu}_Q$. An additional bias term, $\mathbf{dK}_{\text{bias}}=(\mathbf{dS}^{\top}\mathbf{1})\boldsymbol{\mu}_Q$, must be added to recover the correct gradient:

$$\mathbf{dK}=\mathbf{dS}^{\top}\mathbf{Q}=\mathbf{dS}^{\top}(\mathbf{Q}^{\mathrm{sm}}+\mathbf{1}\boldsymbol{\mu}_Q)=\mathbf{dS}^{\top}\mathbf{Q}^{\mathrm{sm}}+(\mathbf{dS}^{\top}\mathbf{1})\boldsymbol{\mu}_Q=\mathbf{dK}_{\text{center}}+\mathbf{dK}_{\text{bias}}.$$
This additional correction introduces another pathway for quantization noise, which may partially offset the benefits of reduced activation range. We leave a deeper investigation of when Q-smoothing benefits training-time quantization to future work.
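The decomposition $\mathbf{dK}=\mathbf{dK}_{\text{center}}+\mathbf{dK}_{\text{bias}}$ can also be verified numerically. Unlike the K-smoothing case, the bias term is generically nonzero, since columns of $\mathbf{dS}$ do not sum to zero (synthetic unscaled attention, illustrative sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 64, 16
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))

S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T
delta = (dO * O).sum(axis=1, keepdims=True)
dS = P * (dP - delta)

mu_Q = Q.mean(axis=0, keepdims=True)          # (1, d) mean row of Q
Q_sm = Q - mu_Q

dK_center = dS.T @ Q_sm                       # centered branch only
dK_bias = (dS.T @ np.ones((N, 1))) @ mu_Q     # correction term
assert np.allclose(dS.T @ Q, dK_center + dK_bias)

# Dropping the bias branch would corrupt dK: it is not close to zero.
assert not np.allclose(dK_bias, 0.0)
```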
7 Conclusion and Future Work
Conclusion.
In this paper, we extend SageBwd, a trainable low-bit attention mechanism. We analyze when SageBwd can match full-precision scaled dot-product attention during pre-training and find two key factors: (i) controlling outliers in $\mathbf{Q}$ and $\mathbf{K}$ via QK-norm is necessary for stability at large tokens-per-step, and (ii) the dominant accuracy bottleneck is the low-magnitude softmax gradient $\mathbf{dS}$, which affects $\mathbf{dQ}$ and $\mathbf{dK}$. Empirically, smaller tokens-per-step make training more tolerant to this noise, while larger tokens-per-step expose it and yield a stable but suboptimal gap.
Limitations and future work.
While SageBwd achieves FPA-level performance under moderate tokens-per-step, its training stability degrades at very large batch sizes. A key direction for future work is therefore to develop methods that mitigate backward-pass quantization error, particularly along the $\mathbf{dS}$ path, without relying on reduced batch size or increased gradient noise. In addition, although SageBwd already delivers considerable speedups over existing baselines, further kernel-level optimizations remain an important avenue for future research.
References
- B. Anson and L. Aitchison (2025). Controlling changes to attention logits. arXiv preprint arXiv:2511.21377.
- J. Chen, Y. Gai, Z. Yao, M. W. Mahoney, and J. E. Gonzalez (2020). A statistical framework for low-bitwidth training of deep neural networks. Advances in Neural Information Processing Systems 33, pp. 883–894.
- T. Dao, D. Fu, S. Ermon, A. Rudra, and C. Ré (2022). FlashAttention: fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems 35, pp. 16344–16359.
- T. Dao (2024). FlashAttention-2: faster attention with better parallelism and work partitioning. In The Twelfth International Conference on Learning Representations.
- M. Dehghani, J. Djolonga, B. Mustafa, P. Padlewski, J. Heek, J. Gilmer, A. P. Steiner, M. Caron, R. Geirhos, I. Alabdulmohsin, et al. (2023). Scaling vision transformers to 22 billion parameters. In International Conference on Machine Learning, pp. 7480–7512.
- A. Dubey, A. Jauhri, A. Pandey, A. Kadian, A. Al-Dahle, A. Letman, A. Mathur, A. Schelten, A. Yang, A. Fan, et al. (2024). The Llama 3 herd of models. arXiv preprint arXiv:2407.21783.
- A. Gokaslan, V. Cohen, E. Pavlick, and S. Tellex (2019). OpenWebText corpus. http://Skylion007.github.io/OpenWebTextCorpus
- A. Henry, P. R. Dachapally, S. Pawar, and Y. Chen (2020). Query-key normalization for transformers. arXiv preprint arXiv:2010.04245.
- Y. Hu, W. Huang, Z. Liang, C. Chen, J. Zhang, J. Zhu, and J. Chen (2025). Identifying sensitive weights via post-quantization integral. arXiv preprint arXiv:2503.01901.
- Y. Hu, H. Singh, M. Maheswaran, H. Xi, C. Hooper, J. Zhang, A. Tomar, M. W. Mahoney, S. Min, M. Farajtabar, et al. (2026). Residual context diffusion language models. arXiv preprint arXiv:2601.22954.
- B. Jacob, S. Kligys, B. Chen, M. Zhu, M. Tang, A. Howard, H. Adam, and D. Kalenichenko (2018). Quantization and training of neural networks for efficient integer-arithmetic-only inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2704–2713.
- H. Jiang, Y. Li, C. Zhang, Q. Wu, X. Luo, S. Ahn, Z. Han, A. H. Abdi, D. Li, C. Lin, Y. Yang, and L. Qiu (2024). MInference 1.0: accelerating pre-filling for long-context LLMs via dynamic sparse attention. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- Y. Jiang, F. Fu, W. Zhao, S. Rabanser, N. D. Lane, and B. Yuan (2025a). Cascadia: a cascade serving system for large language models. arXiv preprint arXiv:2506.04203.
- Y. Jiang, W. Li, Y. Peng, J. Zhang, R. Yan, J. Chen, X. Han, F. Fu, and B. Yuan (2025b). HexGen-3: a fully disaggregated LLM serving framework with fine-grained heterogeneous resource autoscaling.
- N. S. Keskar, D. Mudigere, J. Nocedal, M. Smelyanskiy, and P. T. P. Tang (2017). On large-batch training for deep learning: generalization gap and sharp minima. arXiv preprint arXiv:1609.04836.
- B. Lefaudeux, F. Massa, D. Liskovich, W. Xiong, V. Caggiano, S. Naren, M. Xu, J. Hu, M. Tintore, S. Zhang, et al. (2022). xFormers: a modular and hackable transformer modelling library.
- J. Shah, G. Bikshandi, Y. Zhang, V. Thakkar, P. Ramani, and T. Dao (2024). FlashAttention-3: fast and accurate attention with asynchrony and low-precision. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
- S. L. Smith, P. Kindermans, C. Ying, and Q. V. Le (2018). Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489.
- P. Tillet, H. Kung, and D. Cox (2019). Triton: an intermediate language and compiler for tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages, pp. 10–19.
- A. Vaswani (2017). Attention is all you need. Advances in Neural Information Processing Systems.
- H. Xi, S. Yang, Y. Zhao, M. Li, H. Cai, X. Li, Y. Lin, Z. Zhang, J. Zhang, X. Li, et al. (2026). Quant VideoGen: auto-regressive long video generation via 2-bit KV-cache quantization. arXiv preprint arXiv:2602.02958.
- C. Xiang, J. Liu, J. Zhang, X. Yang, Z. Fang, S. Wang, Z. Wang, Y. Zou, H. Su, and J. Zhu (2026). Geometry-aware rotary position embedding for consistent video world model. arXiv preprint arXiv:2602.07854.
- G. Xiao, J. Lin, M. Seznec, H. Wu, J. Demouth, and S. Han (2023). SmoothQuant: accurate and efficient post-training quantization for large language models. In International Conference on Machine Learning, pp. 38087–38099.
- S. Yang, H. Xi, Y. Zhao, M. Li, J. Zhang, H. Cai, Y. Lin, X. Li, C. Xu, K. Peng, et al. (2025). Sparse VideoGen2: accelerate video generation with sparse attention via semantic-aware permutation. Advances in Neural Information Processing Systems (NeurIPS 2025).
- J. Zhang, H. Huang, P. Zhang, J. Wei, J. Zhu, and J. Chen (2025a). SageAttention2: efficient attention with thorough outlier smoothing and per-thread INT4 quantization. In International Conference on Machine Learning (ICML).
- J. Zhang, K. Jiang, C. Xiang, W. Feng, Y. Hu, H. Xi, J. Chen, and J. Zhu (2026a). SpargeAttention2: trainable sparse attention via hybrid top-k + top-p masking and distillation fine-tuning. arXiv preprint arXiv:2602.13515.
- J. Zhang, R. Su, C. Liu, J. Wei, Z. Wang, H. Wang, P. Zhang, H. Jiang, H. Huang, C. Xiang, et al. Efficient attention methods: hardware-efficient, sparse, compact, and linear attention.
- J. Zhang, H. Wang, K. Jiang, S. Yang, K. Zheng, H. Xi, Z. Wang, H. Zhu, M. Zhao, I. Stoica, et al. (2025b). SLA: beyond sparsity in diffusion transformers via fine-tunable sparse-linear attention. arXiv preprint arXiv:2509.24006.
- J. Zhang, H. Wang, K. Jiang, K. Zheng, Y. Jiang, I. Stoica, J. Chen, J. Zhu, and J. E. Gonzalez (2026b). SLA2: sparse-linear attention with learnable routing and QAT. arXiv preprint arXiv:2602.12675.
- J. Zhang, J. Wei, P. Zhang, X. Xu, H. Huang, H. Wang, K. Jiang, J. Zhu, and J. Chen (2025c). SageAttention3: microscaling FP4 attention for inference and an exploration of 8-bit training. arXiv preprint arXiv:2505.11594.
- J. Zhang, J. Wei, P. Zhang, J. Zhu, and J. Chen (2025d). SageAttention: accurate 8-bit attention for plug-and-play inference acceleration. In International Conference on Learning Representations (ICLR).
- J. Zhang, C. Xiang, H. Huang, H. Xi, J. Zhu, J. Chen, et al. (2025e). SpargeAttention: accurate and training-free sparse attention accelerating any model inference. In Forty-second International Conference on Machine Learning.
- J. Zhang, X. Xu, J. Wei, H. Huang, P. Zhang, C. Xiang, J. Zhu, and J. Chen (2025f). SageAttention2++: a more efficient implementation of SageAttention2. arXiv preprint arXiv:2505.21136.
- J. Zhang, K. Zheng, K. Jiang, H. Wang, I. Stoica, J. E. Gonzalez, J. Chen, and J. Zhu (2025g). TurboDiffusion: accelerating video diffusion models by 100–200 times. arXiv preprint arXiv:2512.16093.
- P. Zhang, J. Wei, J. Zhang, J. Zhu, and J. Chen (2025h). Accurate INT8 training through dynamic block-level fallback. arXiv preprint arXiv:2503.08040.
Appendix A SageBwd Algorithm
A.1 Forward Pass
1: Input: FP16 matrices $Q, K, V \in \mathbb{R}^{N\times d}$, and block sizes $B_q, B_{kv}$.
2: Divide $Q$ into $T_m = N/B_q$ blocks $\{\mathbf{Q}_i\}$; divide $K$ and $V$ into $T_n = N/B_{kv}$ blocks $\{\mathbf{K}_i\}$, $\{\mathbf{V}_i\}$;
3: Quantization: $\{\mathbf{s_Q}, \hat{\mathbf{Q}}_i\} = \psi(\mathbf{Q}_i)$, $\{\mathbf{s_K}, \hat{\mathbf{K}}_i\} = \psi(\mathbf{K}_i^\top)$, $\{\mathbf{s_V}, \hat{\mathbf{V}}_i\} = \psi(\mathbf{V}_i)$; // per-block quantization
4: for $i = 1$ to $T_m$ do
5:&nbsp;&nbsp; $\mathbf{O}_i \in \mathbb{R}^{B_q\times d} = (0)$, $\mathbf{L}_i \in \mathbb{R}^{B_q} = (0)$, $m_i \in \mathbb{R}^{B_q} = (0)$;
6:&nbsp;&nbsp; for $j$ in $[1, T_n]$ do
7:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{S}_{ij} = \texttt{MM}(\hat{\mathbf{Q}}_i, \hat{\mathbf{K}}_j) \times \mathbf{s_Q} \times \mathbf{s_K}$;
8:&nbsp;&nbsp;&nbsp;&nbsp; $m_{ij} = \max(m_{i,j-1}, \mathrm{rowmax}(\mathbf{S}_{ij}))$, $\widetilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - m_{ij})$, $l_{ij} = e^{m_{i,j-1}-m_{ij}}\, l_{i,j-1} + \mathrm{rowsum}(\widetilde{\mathbf{P}}_{ij})$;
9:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{s_P} = \exp(\mathrm{rowmax}(\mathbf{S}_{ij}) - m_{ij})/127$, $\hat{\mathbf{P}}_{ij} = \widetilde{\mathbf{P}}_{ij}/\mathbf{s_P}$; // per-token quantization
10:&nbsp;&nbsp;&nbsp; $\mathbf{O}_{ij} = \mathrm{diag}(e^{m_{i,j-1}-m_{ij}})^{-1}\, \mathbf{O}_{i,j-1} + \texttt{MM}(\hat{\mathbf{P}}_{ij}, \hat{\mathbf{V}}_j) \times \mathbf{s_P} \times \mathbf{s_V}$;
11:&nbsp;&nbsp; end for
12:&nbsp;&nbsp; $\mathbf{O}_i = \mathrm{diag}(l_{i,T_n})^{-1}\, \mathbf{O}_{i,T_n}$;
13:&nbsp;&nbsp; $\mathbf{L}_i = m_{i,T_n} + \log(l_{i,T_n})$;
14: end for
15: return $O = \{\mathbf{O}_i\}$, $L = \{\mathbf{L}_i\}$;

Algorithm 1: Forward pass of the 8-bit attention.
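Algorithm 1 can be emulated end-to-end in floating point to see the effect of quantization; a minimal NumPy sketch (the symmetric per-block quantizer, block sizes, the $-\infty$ initialization of the running max, and the $1/\sqrt{d}$ logit scaling are our simplifying assumptions; INT8 matmuls are simulated with rounded values):

```python
import numpy as np

def quant_int8(x):
    """Per-block symmetric INT8 quantization: a sketch of the psi operator."""
    s = np.abs(x).max() / 127.0 + 1e-12   # one scalar scale per block
    return s, np.clip(np.round(x / s), -127, 127)

def sage_forward(Q, K, V, Bq=16, Bkv=16):
    N, d = Q.shape
    O = np.zeros_like(Q)
    for i0 in range(0, N, Bq):
        sQ, Qh = quant_int8(Q[i0:i0 + Bq])
        m = np.full(Bq, -np.inf)          # running row max
        l = np.zeros(Bq)                  # running softmax denominator
        Oi = np.zeros((Bq, d))
        for j0 in range(0, N, Bkv):
            sK, Kh = quant_int8(K[j0:j0 + Bkv])
            sV, Vh = quant_int8(V[j0:j0 + Bkv])
            S = (Qh @ Kh.T) * sQ * sK / np.sqrt(d)   # INT8 matmul, FP rescale
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            l = np.exp(m - m_new) * l + P.sum(axis=1)
            # Per-token quantization of P: the row max of P is exp(rowmax(S) - m_new).
            sP = np.exp(S.max(axis=1) - m_new) / 127.0
            Ph = np.round(P / sP[:, None])
            Oi = np.exp(m - m_new)[:, None] * Oi + (Ph @ (Vh * sV)) * sP[:, None]
            m = m_new
        O[i0:i0 + Bq] = Oi / l[:, None]
    return O

# Compare against exact full-precision attention.
rng = np.random.default_rng(0)
N, d = 64, 16
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V
O_int8 = sage_forward(Q, K, V)
rel_err = np.linalg.norm(O_int8 - O_ref) / np.linalg.norm(O_ref)
assert rel_err < 0.05
```

The online-softmax bookkeeping (running max $m$, denominator $l$) mirrors the algorithm, while the quantize/dequantize round trips stand in for the INT8 Tensor Core matmuls.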
A.2 Backward Pass
1: Input: $\{\mathbf{s_Q}, \hat{\mathbf{Q}}_i\}$, $\{\mathbf{s_K}, \hat{\mathbf{K}}_i\}$, $\{\mathbf{s_V}, \hat{\mathbf{V}}_i\}$, $O$, $\{\mathbf{L}_i\}$ from the forward pass, $dO \in \mathbb{R}^{N\times d}$, and block sizes $B_q, B_{kv}$;
2: $D = \mathrm{rowsum}(dO \circ O)$; divide $D$ into $T_m = N/B_q$ blocks $\{\mathbf{D}_i\}$;
3: for $j = 1$ to $T_n$ do
4:&nbsp;&nbsp; for $i$ in $[1, T_m]$ do
5:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{S}_{ij} = \texttt{MM}(\hat{\mathbf{Q}}_i, \hat{\mathbf{K}}_j) \times \mathbf{s_Q} \times \mathbf{s_K}$; $\mathbf{P}_{ij} = \exp(\mathbf{S}_{ij} - \mathbf{L}_i)$;
6:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{s_P}, \hat{\mathbf{P}}_{ij} = \psi(\mathbf{P}_{ij})$, $\mathbf{s_{dO}}, \hat{\mathbf{dO}}_i = \psi(\mathbf{dO}_i)$; // INT8 per-block quantization
7:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{dV}_j \leftarrow \mathbf{dV}_j + \texttt{MM}(\hat{\mathbf{P}}_{ij}^\top, \hat{\mathbf{dO}}_i) \times \mathbf{s_P} \times \mathbf{s_{dO}}$;
8:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{dP}_{ij} = \texttt{MM}(\mathbf{dO}_i, \mathbf{V}_j^\top)$; // kept in FP16
9:&nbsp;&nbsp;&nbsp;&nbsp; $\mathbf{dS}_{ij} = \mathbf{P}_{ij} \circ (\mathbf{dP}_{ij} - \mathbf{D}_i)$; $\mathbf{s_{dS}}, \hat{\mathbf{dS}}_{ij} = \psi(\mathbf{dS}_{ij})$; // INT8 per-block quantization
10:&nbsp;&nbsp;&nbsp; $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \texttt{MM}(\hat{\mathbf{dS}}_{ij}, \hat{\mathbf{K}}_j) \times \mathbf{s_{dS}} \times \mathbf{s_K}$;
11:&nbsp;&nbsp;&nbsp; $\mathbf{dK}_j \leftarrow \mathbf{dK}_j + \texttt{MM}(\hat{\mathbf{dS}}_{ij}^\top, \hat{\mathbf{Q}}_i) \times \mathbf{s_{dS}} \times \mathbf{s_Q}$;
12:&nbsp;&nbsp; end for
13: end for
14: return $dQ, dK, dV$;

Algorithm 2: Backward pass of the 8-bit attention.
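The full-precision gradient formulas that Algorithm 2 quantizes can be sanity-checked against finite differences; a minimal NumPy sketch (the sizes and the scalar loss $L=\sum \mathbf{O}\circ\mathbf{dO}$ are illustrative choices):

```python
import numpy as np

def attn(Q, K, V):
    """Reference softmax attention; returns (P, O)."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P, P @ V

rng = np.random.default_rng(1)
N, d = 12, 8
Q, K, V, dO = (rng.normal(size=(N, d)) for _ in range(4))

P, O = attn(Q, K, V)
dP = dO @ V.T                               # gradient through O = P V
D = (dO * O).sum(axis=1, keepdims=True)     # rowsum(dO ∘ O)
dS = P * (dP - D)                           # softmax backward
dQ, dK, dV = dS @ K, dS.T @ Q, P.T @ dO

# Central finite-difference check of dQ for the scalar loss L = sum(O ∘ dO).
eps = 1e-6
i, k = 3, 2
Qp, Qm = Q.copy(), Q.copy()
Qp[i, k] += eps
Qm[i, k] -= eps
num = ((attn(Qp, K, V)[1] * dO).sum() - (attn(Qm, K, V)[1] * dO).sum()) / (2 * eps)
assert abs(num - dQ[i, k]) < 1e-5
```

This confirms the chain $\mathbf{dP}\to\mathbf{dS}\to\mathbf{dQ},\mathbf{dK},\mathbf{dV}$ used in the algorithm, before any quantization is applied.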
Appendix B dS Magnitude
In this section, we prove a simple upper bound on the RMS of $\mathbf{dS}$.
Proof.
Recall that
$$
\mathbf{dS}=\mathbf{P}\circ(\mathbf{dP}-\boldsymbol{\delta}\mathbf{1}^{\top})\in\mathbb{R}^{N\times N},\qquad\boldsymbol{\delta}=\operatorname{rowsum}(\mathbf{dO}\circ\mathbf{O})\in\mathbb{R}^{N},
$$
where $\circ$ denotes elementwise multiplication and $\mathbf{1}$ is the all-ones vector.
For row $i$ of $\mathbf{dS}$, we can write
$$
\mathbf{dS}_{i}=\mathbf{P}_{i}\circ(\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}),
$$
where $\mathbf{P}_{i},\mathbf{dP}_{i}\in\mathbb{R}^{N}$ and $\boldsymbol{\delta}_{i}$ is the $i$-th entry of $\boldsymbol{\delta}$. Its root-mean-square (RMS) value is
$$
\operatorname{RMS}(\mathbf{dS}_{i})=\sqrt{\frac{1}{N}\sum_{j=1}^{N}\mathbf{P}_{i,j}^{2}(\mathbf{dP}_{i,j}-\boldsymbol{\delta}_{i})^{2}}.
$$
Using the infinity norm ($\|\mathbf{x}\|_{\infty}=\max_{j}|\mathbf{x}_{j}|$), we have $|\mathbf{dP}_{i,j}-\boldsymbol{\delta}_{i}|\leq\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}$. Therefore,
$$
\operatorname{RMS}(\mathbf{dS}_{i})^{2}\leq\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}^{2}\cdot\frac{1}{N}\sum_{j=1}^{N}\mathbf{P}_{i,j}^{2}\qquad(1)
$$
$$
=\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}^{2}\,\operatorname{RMS}(\mathbf{P}_{i})^{2},\qquad(2)
$$
and hence
$$
\operatorname{RMS}(\mathbf{dS}_{i})\leq\operatorname{RMS}(\mathbf{P}_{i})\,\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}.\qquad(3)
$$
Because $\mathbf{P}$ is the output of a softmax operation, each row $\mathbf{P}_{i}$ is a probability vector: it sums to $1$ and has entries in $[0,1]$. Thus
$$
\operatorname{RMS}(\mathbf{P}_{i})=\sqrt{\frac{1}{N}\sum_{j=1}^{N}\mathbf{P}_{i,j}^{2}}\;\leq\;\sqrt{\frac{1}{N}\max_{j}\mathbf{P}_{i,j}\sum_{j=1}^{N}\mathbf{P}_{i,j}}\;\leq\;\frac{1}{\sqrt{N}}.\qquad(4)
$$
Combining (3) and (4) yields, for each row $i$,
$$
\operatorname{RMS}(\mathbf{dS}_{i})\leq\frac{1}{\sqrt{N}}\,\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}.
$$
Finally, the global RMS of $\mathbf{dS}$ satisfies
$$
\operatorname{RMS}(\mathbf{dS})\leq\frac{1}{\sqrt{N}}\max_{i}\|\mathbf{dP}_{i}-\boldsymbol{\delta}_{i}\mathbf{1}\|_{\infty}.\qquad(5)
$$
This upper bound can be interpreted as follows: the average magnitude of $\mathbf{dS}$ is at most the largest per-row gradient magnitude in $\mathbf{dP}$ (after subtracting $\boldsymbol{\delta}$), scaled by a factor of $1/\sqrt{N}$.
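The bound (5) is straightforward to verify numerically; a short NumPy check (the $1/\sqrt{d}$ logit scaling and the random inputs are illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(2)
N, d = 128, 32
Q, K, V, dO = (rng.normal(size=(N, d)) for _ in range(4))

S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T
delta = (dO * O).sum(axis=1)            # δ = rowsum(dO ∘ O)
dS = P * (dP - delta[:, None])

rms = np.sqrt((dS ** 2).mean())
# Right-hand side of (5): max over rows of the infinity norm, divided by sqrt(N).
bound = np.abs(dP - delta[:, None]).max() / np.sqrt(N)
assert rms <= bound
```

In practice the measured RMS sits well below the bound, consistent with the paper's observation that $\mathbf{dS}$ has low magnitude.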
Appendix C Cosine Similarity and Rel-L2 Error
Figure 5 and Figure 6 show the cosine similarity and relative $\ell_2$ error between SageBwd and FPA on inputs and gradients extracted from a single forward-backward pass of the pretrained 325M Llama model under various tokens-per-step (TPS) and architectural settings.
Figure 5: Cosine similarity between SageBwd and SDPA across layers under different settings (panels a–d).
Figure 6: Relative L2 error between SageBwd and SDPA across layers under different settings (panels a–d).
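For reference, the two metrics reported in these figures can be computed as follows; a minimal sketch (function names are ours):

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two tensors, flattened."""
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def rel_l2_error(approx, ref):
    """Relative L2 error of an approximation against a reference."""
    return float(np.linalg.norm(approx.ravel() - ref.ravel())
                 / np.linalg.norm(ref.ravel()))

x = np.array([1.0, 2.0, 3.0])
assert abs(cosine_similarity(x, 2 * x) - 1.0) < 1e-12  # scale-invariant
assert rel_l2_error(x, x) == 0.0                        # exact match
```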