\documentclass[11pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsmath,amssymb}
\usepackage{graphicx}
\usepackage{microtype}
\usepackage{booktabs}
\usepackage{tabularx}
\usepackage{hyperref}
\hypersetup{
  colorlinks=true,
  linkcolor=blue,
  citecolor=blue,
  urlcolor=blue,
}

\title{%
  \textbf{Zero-Copy Sparse Backpropagation}: \\[0.3em]
  \large Temporal Gradient Tracking for Faster, Regularized LLM Training
}
\author{
  Daniel Owen van Dommelen\\
  \textit{Independent Research}\\
  \texttt{theapemachine@gmail.com}
}
\date{\today}

\begin{document}
\maketitle

\begin{abstract}
We describe \emph{Predictive Chunked Sparsity}: fixed top-$k$ row-chunks for sparse weight-gradient computation ($dW$), selected by an exponential moving average (EMA) of past chunk-gradient norms, with contiguous strided views for hardware-friendly GEMMs.
We present controlled ablations on a single hardware stack (NVIDIA T4/A10G) addressing three structural questions: (1)~whether convergence is driven by the chunk-selection algorithm or by ``phantom momentum'' in the optimizer, (2)~whether sparse training outperforms compute-matched and capacity-matched dense baselines, and (3)~how accurately the EMA predictor identifies high-gradient chunks compared to an oracle.
Our phantom momentum ablation (Table~\ref{tab:phantom}) shows that freezing inactive Adam state changes validation loss by $<$0.05 ($<$1\%), indicating that the EMA chunk selector, not optimizer side effects, drives convergence.
The EMA predictor achieves $\sim$85\% recall of oracle top-$k$ chunks (Table~\ref{tab:predictor-multi}), far above the $\sim$10\% expected from random selection.
However, compute-matched dense baselines outperform sparse training at 500 steps on Tiny Shakespeare (Table~\ref{tab:compute-matched}), consistent with the expectation that sparse methods need longer training horizons to compensate for their reduced per-step update volume.
Isolated FFN microbenchmarks on T4 show that the sparse $dW$ computation is 3.7--6.8$\times$ faster than dense $dW$ for $d \geq 512$, yielding a 1.31--1.35$\times$ end-to-end per-layer speedup bounded by Amdahl's law (Table~\ref{tab:t4-ffn-micro}).
\end{abstract}

%% ================================================================
\section{Introduction}
%% ================================================================

Training large transformers is dominated by dense matrix multiplications in the forward and backward passes.
Prior work on dynamic sparsity \cite{evci2020rigging,mocanu2018scalable} demonstrates that sparse connectivity can match dense accuracy, but translating theoretical FLOP reductions into wall-clock speedups remains difficult because of irregular memory access patterns and host--device synchronization overhead.

We propose \emph{Predictive Chunked Sparsity}: a method that maintains fixed-cardinality masks over contiguous row-chunks of weight matrices, so that the sparse backward pass decomposes into a small number of standard dense GEMMs on strided views.
An EMA-based scorer predicts which chunks carry the largest gradient mass, and a cosine annealing schedule moves training from dense warmup to the target sparsity level.

\paragraph{Contributions.}
\textbf{(1)}~The chunked sparse backward algorithm with EMA-based chunk selection and optional KNN-based inactive-chunk score imputation.
\textbf{(2)}~A phantom momentum ablation indicating that the chunk selector, not optimizer decay, drives convergence.
\textbf{(3)}~Compute-matched and capacity-matched dense baselines quantifying the method's current limitations.
\textbf{(4)}~Multi-seed predictor accuracy measurements (Jaccard/recall vs.\ oracle).
\textbf{(5)}~Fused Triton kernels for the sparse backward pass with \texttt{block\_ptr} TMA-ready loads.
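\paragraph{Illustrative sketch.}
The NumPy sketch below is a minimal illustration of one training step under the scheme summarized above. It is not the author's implementation (the PyTorch and Triton backends are described in the Systems Implementation section). It assumes that $\|dW_c\|_2$ denotes the Frobenius norm of the chunk gradient and uses $\beta = 0.9$ as an illustrative EMA decay. The helper names (\texttt{active\_fraction}, \texttt{sparse\_dw}, \texttt{update\_scores}, \texttt{select\_active}) are hypothetical.

```python
import math

import numpy as np


def active_fraction(t, f_target, t_warm, t_anneal):
    """Cosine-annealed active fraction f(t): dense warmup, then anneal to f_target."""
    if t < t_warm:
        return 1.0
    if t >= t_warm + t_anneal:
        return f_target
    phase = math.pi * (t - t_warm) / t_anneal
    return f_target + 0.5 * (1.0 - f_target) * (1.0 + math.cos(phase))


def sparse_dw(gy, x, active, chunk):
    """dW for the active row-chunks only; inactive chunks keep a zero gradient.

    gy: upstream gradient (tokens, d_out); x: layer input (tokens, d_in).
    Chunk c owns rows [c*chunk, (c+1)*chunk) of W, which has shape (d_out, d_in).
    """
    dw = np.zeros((gy.shape[1], x.shape[1]), dtype=gy.dtype)
    for c in active:
        s, e = c * chunk, (c + 1) * chunk
        # gy[:, s:e] is a strided view (no copy); one small dense GEMM per chunk.
        dw[s:e] = gy[:, s:e].T @ x
    return dw


def update_scores(scores, dw, active, chunk, beta=0.9):
    """EMA of chunk-gradient norms; inactive chunks keep their last score.

    Assumptions: Frobenius norm for ||dW_c||_2, beta = 0.9 (illustrative only).
    """
    for c in active:
        norm = np.linalg.norm(dw[c * chunk:(c + 1) * chunk])
        scores[c] = beta * scores[c] + (1.0 - beta) * norm
    return scores


def select_active(scores, k):
    """Indices of the k highest-scoring chunks: the next step's active set."""
    return np.argsort(scores)[::-1][:k]


# Toy layer: d_out = 64 split into N = 8 chunks of 8 rows; target density 25%.
# The same batch (gy, x) is reused for both steps for brevity.
rng = np.random.default_rng(0)
tokens, d_in, d_out, chunk = 32, 16, 64, 8
n_chunks = d_out // chunk
gy = rng.standard_normal((tokens, d_out))
x = rng.standard_normal((tokens, d_in))
scores = np.zeros(n_chunks)

# Dense warmup step (f = 1): all chunks are active, so every chunk gets a score.
k0 = math.floor(active_fraction(0, 0.25, 50, 200) * n_chunks)
warm = select_active(scores, k0)
scores = update_scores(scores, sparse_dw(gy, x, warm, chunk), warm, chunk)

# Post-anneal step (f = f_target): only the top-k chunks by score are computed.
k = math.floor(active_fraction(300, 0.25, 50, 200) * n_chunks)
active = select_active(scores, k)
dw = sparse_dw(gy, x, active, chunk)
```

The dense warmup matters for the frozen EMA. In this sketch, a chunk that was never active would keep its initial score of zero and could never be selected. With $f_{\text{target}} = 0.1$ and chunk size 64, the rule $k = \lfloor f N \rfloor$ gives $k = 3, 6, 12$ for $d_{\text{out}} = 2048, 4096, 8192$. These values match the active-chunk counts in Table~\ref{tab:triton-backward}.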
%% ================================================================
\section{Methodology}
%% ================================================================

\subsection{Chunked Sparse Backward}
A linear layer $W \in \mathbb{R}^{O \times I}$ is partitioned into $N = O/C$ contiguous row-chunks of size $C$.
At each training step, a binary mask $A \in \{0,1\}^N$ with exactly $k = \lfloor f \cdot N \rfloor$ active entries determines which chunks receive gradient updates:
%
\begin{equation}
dW_{c} =
\begin{cases}
G_Y^{[c]\top} X & \text{if } A_c = 1 \\
0 & \text{otherwise}
\end{cases}
\end{equation}
%
where $X$ and $G_Y$ are the layer input and the upstream gradient, each with one row per token, and $G_Y^{[c]} = G_Y[:, cC{:}(c{+}1)C]$ is the slice of the upstream gradient corresponding to chunk $c$.
Because chunks are contiguous row-blocks of $W$, each active $dW_c$ is a standard dense GEMM on a strided view, and no gather or scatter is required.

\subsection{EMA Chunk Selection}
We maintain a running score per chunk:
\begin{equation}
M_c^{(t)} =
\begin{cases}
\beta \, M_c^{(t-1)} + (1{-}\beta) \| dW_c^{(t)} \|_2 & \text{if } A_c^{(t)} = 1 \\
M_c^{(t-1)} & \text{otherwise (frozen EMA)}
\end{cases}
\end{equation}
The $k$ chunks with the largest scores $M^{(t-1)}$ form the active set $A^{(t)}$.
Inactive chunks retain their last observed score without decay.
This avoids a stale-EMA lockout, in which decayed scores would permanently exclude potentially important chunks.

\subsection{Cosine Sparsity Annealing}
Training begins fully dense for $T_{\text{w}}$ warmup steps.
The active fraction $f(t)$ then anneals from 1.0 to the target $f_{\text{target}}$ over $T_{\text{a}}$ steps, following a cosine schedule:
\begin{equation}
f(t) =
\begin{cases}
1 & t < T_{\text{w}} \\
f_{\text{target}} + \tfrac{1}{2}(1 - f_{\text{target}}) \bigl(1 + \cos(\pi (t - T_{\text{w}}) / T_{\text{a}})\bigr) & T_{\text{w}} \leq t < T_{\text{w}} + T_{\text{a}} \\
f_{\text{target}} & t \geq T_{\text{w}} + T_{\text{a}}
\end{cases}
\end{equation}

\subsection{Phantom Momentum}
When using Adam, inactive chunks receive zero gradients.
Standard Adam still applies moment decay ($m \leftarrow \beta_1 m$, $v \leftarrow \beta_2 v$) on these zero-gradient steps.
Because the bias-corrected ratio $\hat m / (\sqrt{\hat v} + \epsilon)$ stays nonzero, the optimizer keeps producing small weight updates from historical momentum, an effect we term ``phantom momentum.''
Section~\ref{sec:phantom} presents a controlled ablation isolating this effect.

%% ================================================================
\section{Systems Implementation}
%% ================================================================

\subsection{PyTorch Backend}
The sparse backward is implemented as a Python loop over active chunks, each issuing a small dense GEMM via \texttt{gy\_flat[:, s:e].t() @ x\_flat}.
Fixed $k$ ensures that no dynamically shaped allocations are needed.
The optimizer (ChunkedAdam) restricts weight updates to active chunks only.

\subsection{Triton Backend}
We provide fused Triton kernels that process all active chunks in a single GPU launch.
The $dW$ kernel uses a 2D grid (active chunks $\times$ $d_{\text{in}}$ tiles) with \texttt{tl.make\_block\_ptr} for 2D block-pointer tile loads (TMA-ready on Hopper-class GPUs).
Bias gradients are fused into the $dW$ kernel by accumulating column sums of the $dY$ tiles already held in registers.
This eliminates the uncoalesced memory access pattern of a separate bias kernel.
A sparse $dX$ kernel is provided for the aggressive mode in which input gradients are sparsified as well.
Correctness is verified against the PyTorch reference (Table~\ref{tab:triton-correctness}); maximum absolute errors are below $5 \times 10^{-4}$ across all tested configurations.

%% ================================================================
\section{Experiments}
%% ================================================================

All experiments in this section use a single hardware stack (NVIDIA T4, 16\,GB) with GPT-2 BPE tokenization on Tiny Shakespeare (304K train tokens, 34K val tokens).
Model: 4 layers, 8 heads, $d_{\text{model}} = 1024$, chunk size 64, 10\% active fraction, batch size 8, sequence length 256, learning rate $3 \times 10^{-4}$, and cosine sparsity annealing with a 50-step dense warmup and a 200-step anneal.
Results report mean $\pm$ std over 2 seeds unless noted.

\subsection{Phantom Momentum Ablation}
\label{sec:phantom}
The central question is whether convergence depends on the chunk-selection algorithm or on phantom momentum acting as implicit regularization.
We compare two Adam modes across multiple chunk-selection policies:
\begin{itemize}
\item \textbf{Phantom} (default): Adam moments decay on all chunks every step, including inactive ones receiving zero gradients.
\item \textbf{Frozen}: Adam state ($m$, $v$) for inactive chunks is completely preserved (no decay, no update).
\end{itemize}

\begin{table}[t]
\centering
\caption{Phantom momentum ablation. $d{=}1024$, 4 layers, 500 steps, 2 seeds. The phantom$\to$frozen delta is small ($<$0.05 loss) for all policies, consistent with the chunk selector driving convergence.}
\label{tab:phantom}
\begin{tabular}{l r @{\,$\pm$\,} l r}
\toprule
Method & \multicolumn{2}{c}{Val.\ Loss} & ms/step \\
\midrule
Dense (reference) & 5.4710 & 0.0119 & 363.2 \\
\midrule
EMA + phantom & 5.8750 & 0.2433 & 364.1 \\
EMA + frozen & 5.9170 & 0.2695 & 376.8 \\
\midrule
Random + phantom & 6.0688 & 0.1006 & 365.4 \\
Random + frozen & 6.0239 & 0.1318 & 376.5 \\
\bottomrule
\end{tabular}
\end{table}

\paragraph{Findings.}
The phantom-to-frozen delta is $+0.042$ for EMA and $-0.045$ for random.
Both are within seed-to-seed noise and below 1\% of the loss magnitude.
Phantom momentum is therefore \emph{not} the load-bearing mechanism.
In mean loss, EMA selection beats random by 0.19 (phantom) and 0.11 (frozen), evidence that the chunk-selection algorithm does genuine predictive work.
With only 2 seeds and per-run standard deviations up to 0.27, however, this gap lies within the seed spread.
The frozen mode is 11--13\,ms/step slower because of its per-chunk Adam loop (versus bulk tensor decay in phantom mode), a minor systems cost.
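\paragraph{Sketch of the two modes.}
The difference between the two Adam modes can be seen on a toy two-row weight matrix, where each row stands in for one chunk. The minimal NumPy sketch below is our illustration and not the ChunkedAdam implementation. The function \texttt{adam\_step} and its \texttt{mode} argument are hypothetical, and the sketch assumes a single global step count for bias correction in both modes. Row 1 is active at step 1 and inactive (zero gradient) at step 2. In phantom mode it still moves at step 2; in frozen mode it is left untouched.

```python
import numpy as np


def adam_step(w, g, m, v, t, active_rows, mode, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One in-place Adam step on weight matrix w (rows = chunks).

    mode="phantom": standard Adam on every row, so rows with g = 0 still have
                    their moments decayed and still move by lr * m_hat / sqrt(v_hat).
    mode="frozen":  w, m and v of inactive rows are left exactly as they are.
    Assumption: one global step count t is used for bias correction.
    """
    rows = np.arange(w.shape[0]) if mode == "phantom" else active_rows
    m[rows] = b1 * m[rows] + (1 - b1) * g[rows]
    v[rows] = b2 * v[rows] + (1 - b2) * g[rows] ** 2
    m_hat = m[rows] / (1 - b1 ** t)
    v_hat = v[rows] / (1 - b2 ** t)
    w[rows] -= lr * m_hat / (np.sqrt(v_hat) + eps)


rng = np.random.default_rng(0)
g1 = rng.standard_normal((2, 4))                       # step 1: both rows active
g2 = np.vstack([rng.standard_normal(4), np.zeros(4)])  # step 2: row 1 inactive

drift = {}
for mode in ("phantom", "frozen"):
    w = np.zeros((2, 4))
    m, v = np.zeros_like(w), np.zeros_like(w)
    adam_step(w, g1, m, v, t=1, active_rows=np.array([0, 1]), mode=mode)
    w_row1 = w[1].copy()
    adam_step(w, g2, m, v, t=2, active_rows=np.array([0]), mode=mode)
    drift[mode] = np.abs(w[1] - w_row1).max()  # movement of the inactive row

print(drift)  # phantom: about 0.67 * lr; frozen: 0.0
```

After one zero-gradient step, $\hat m / \sqrt{\hat v} = (0.09/0.19)/\sqrt{0.000999/0.001999} \approx 0.67$ times the sign of the last gradient. The phantom update is therefore a sizeable fraction of a full Adam step, not a vanishing one.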
\subsection{Compute-Matched and Capacity-Matched Baselines}
\label{sec:compute}
The critique that sparse training may simply act as ``the world's most computationally expensive dropout'' requires controlled baselines (Table~\ref{tab:compute-matched}).

\begin{table}[t]
\centering
\caption{Compute-matched baselines. Same setup as Table~\ref{tab:phantom}. At 10\% active, sparse training performs $\sim$70\% of the dense FLOPs per step, so 350 dense steps match the FLOPs of 500 sparse steps.}
\label{tab:compute-matched}
\begin{tabular}{l r r @{\,$\pm$\,} l r}
\toprule
Method & Params & \multicolumn{2}{c}{Val.\ Loss} & ms/step \\
\midrule
Sparse EMA (500 steps) & 153.6M & 5.8750 & 0.2433 & 363.1 \\
Dense (500 steps) & 153.6M & 5.4710 & 0.0119 & 364.5 \\
Dense (350 steps, FLOP-matched) & 153.6M & 5.6714 & 0.0002 & 364.2 \\
Dense small (ffn$\times$1, capacity-matched) & 128.4M & 5.6329 & 0.0127 & 284.2 \\
\bottomrule
\end{tabular}
\end{table}

\paragraph{Findings.}
At 500 steps on 304K tokens, sparse EMA (5.875) underperforms all dense baselines.
Even the FLOP-matched dense run at 350 steps (5.671) and the capacity-matched small model (5.633, with 16\% fewer parameters and 22\% lower step time) outperform it.
At this width, the sparse run is also essentially no faster per step than dense (363.1 vs.\ 364.5\,ms/step).
This is expected: with 10\% of chunks active, each step computes only $\sim$10\% of the weight-gradient entries of the sparsified layers, so substantially more steps are needed to reach equivalent parameter exposure.
This result does \emph{not} invalidate the method; it characterizes its operating regime.
At $d_{\text{model}} = 2048$ on MPS (Table~\ref{tab:mps-e2e}; 6 layers, 2000 steps, single seed), sparse full-$G_X$ achieves a 1.18$\times$ wall-clock speedup with comparable loss (5.981 vs.\ 6.026).
This suggests that the speed/quality tradeoff becomes favorable at larger widths, where each sparse step is proportionally cheaper.

\subsection{Predictor Accuracy: EMA vs.\ Oracle}
\label{sec:predictor}
We measure how well the EMA scorer identifies the oracle top-$k$ chunks, defined as the $k$ chunks with the largest $\|dW_c\|_2$ in a dense gradient computation on the same batch.
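\paragraph{Overlap metrics.}
Let $P$ be the set of $k$ chunks selected by the EMA and $O$ the oracle top-$k$ set, with overlap $s = |P \cap O|$.
Because both sets have exactly $k$ elements, recall and Jaccard similarity determine each other:
\begin{equation}
r = \frac{s}{k}, \qquad
J = \frac{|P \cap O|}{|P \cup O|} = \frac{s}{2k - s} = \frac{r}{2 - r}.
\end{equation}
For $r = 0.85$ this gives $J \approx 0.74$, and for $r = 0.10$ it gives $J \approx 0.053$.
Both agree with the averaged Jaccard scores in Table~\ref{tab:predictor-multi} up to rounding and averaging.
A uniformly random $k$-subset contains each oracle chunk with probability $k/N$, so its expected recall is $k/N \approx f = 0.10$, the random baseline's measured recall.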
Oracle overlap is computed every 25 steps after the annealing schedule completes.

\begin{table}[t]
\centering
\caption{EMA predictor overlap with oracle. $d{=}1024$, 4 layers, 500 steps, 2 seeds, measured post-anneal (step $\geq 125$). Random baseline included.}
\label{tab:predictor-multi}
\begin{tabular}{l c c r @{\,$\pm$\,} l}
\toprule
Policy & Jaccard (avg) & Recall (avg) & \multicolumn{2}{c}{Val.\ Loss} \\
\midrule
EMA & 0.73 & 0.85 & 5.9691 & 0.1502 \\
Random & 0.05 & 0.10 & 6.1281 & 0.0461 \\
\bottomrule
\end{tabular}
\end{table}

\paragraph{Findings.}
The EMA predictor achieves 85\% recall (73\% Jaccard) of the oracle top-$k$, stable across training.
Random selection achieves 10\% recall, the value expected for a uniformly random choice of $k$ out of $N$ chunks, which confirms that the EMA is a meaningful predictor.
The 15\% miss rate corresponds to chunks whose gradient importance shifts between steps, which any purely history-based predictor will struggle to anticipate.
Recall and loss move together: EMA's 85\% recall yields a loss of 5.969, versus 6.128 for random's 10\% recall.
This gap of 0.16 gives a first estimate of the value of informed chunk selection.

%% ================================================================
\section{Microbenchmarks and Triton Kernels}
%% ================================================================

\subsection{Per-Layer Amdahl Analysis (T4)}

\begin{table}[t]
\centering
\caption{T4: per-FFN-layer cost breakdown (ms). $B{=}8$, $T{=}256$, chunk 64, 10\% active, fp32, 100 iters.
Speedup is total dense / total sparse (full-$G_X$ mode: dense $dX$, sparse $dW$).}
\label{tab:t4-ffn-micro}
\resizebox{\linewidth}{!}{%
\footnotesize
\begin{tabular}{r r r r r r r r r}
\toprule
$d$ & FFN & Fwd & $dX$ & $dW_{\text{dense}}$ & $dW_{\text{sparse}}$ & Tot.\ dense & Tot.\ sparse & Speedup \\
\midrule
256 & 1024 & 0.27 & 0.21 & 0.27 & 0.26 & 0.75 & 0.74 & 1.02$\times$ \\
512 & 2048 & 1.00 & 1.01 & 0.97 & 0.26 & 2.99 & 2.28 & 1.31$\times$ \\
1024 & 4096 & 3.69 & 3.90 & 3.35 & 0.59 & 10.95 & 8.18 & 1.34$\times$ \\
2048 & 8192 & 14.76 & 15.57 & 13.19 & 1.93 & 43.51 & 32.26 & 1.35$\times$ \\
\bottomrule
\end{tabular}%
}
\end{table}

For $d \geq 512$, the sparse $dW$ component is 3.7--6.8$\times$ faster than dense $dW$.
However, the forward pass and the dense $dX$ computation are unchanged.
Even a free $dW$ would therefore give at most $\sim$1.45$\times$ per layer in full-$G_X$ mode (the Amdahl ceiling).
The measured end-to-end per-layer speedup reaches 1.31$\times$ at $d{=}512$ and plateaus at $\sim$1.35$\times$ for $d \geq 1024$.
At $d{=}256$ sparse and dense are at parity (1.02$\times$).
We place the crossover at roughly $d \approx 384$, between the measured widths, where the FLOP savings first outweigh the sparse loop overhead.

\subsection{Triton Kernel Performance}

\begin{table}[t]
\centering
\caption{Triton backward correctness (max abs error vs.\ PyTorch reference).}
\label{tab:triton-correctness}
\begin{tabular}{r r r r r r}
\toprule
$d_{\text{in}}$ & $d_{\text{out}}$ & chunk & $|dW|$ & $|db|$ & $|dX|$ \\
\midrule
512 & 2048 & 64 & 3.2e-4 & 2.3e-5 & 4.2e-5 \\
1024 & 4096 & 64 & 4.4e-4 & 2.1e-5 & 9.2e-5 \\
256 & 1024 & 32 & 2.8e-4 & 3.8e-5 & 1.9e-5 \\
\bottomrule
\end{tabular}
\end{table}

\begin{table}[t]
\centering
\caption{Isolated backward time (ms): Dense vs.\ PyLoop vs.\ Triton (T4). Full-$G_X$ mode, 50 iters post-warmup. $k$ is the number of active chunks; speedup is Dense time / Triton time.}
\label{tab:triton-backward}
\begin{tabular}{r r r r r r}
\toprule
$d$ & $k$
& Dense & PyLoop & Triton & Speedup \\
\midrule
512 & 3 & 1.96 & 1.30 & 1.16 & 1.69$\times$ \\
1024 & 6 & 7.29 & 4.37 & 4.30 & 1.70$\times$ \\
2048 & 12 & 29.14 & 17.20 & 16.89 & 1.73$\times$ \\
\bottomrule
\end{tabular}
\end{table}

With both $dW$ and $dX$ sparse, Triton achieves 4.8--7.8$\times$ over dense in the isolated backward harness.
This aggressive mode, however, incurs quality costs that are not yet fully characterized at scale.

%% ================================================================
\section{Full Training Results}
%% ================================================================

\subsection{MPS End-to-End (Author Runs)}

\begin{table}[t]
\centering
\caption{MPS full training. 6 layers, $B{=}8$, $T{=}256$, chunk 64, 10\% active, 2000 steps, single seed. ``Sparse full'': sparse $dW$, dense $dX$ (full-$G_X$); ``sparse both'': sparse $dW$ and $dX$.}
\label{tab:mps-e2e}
\begin{tabular}{l l r r r}
\toprule
$d$ & Run & Time (s) & ms/step & Val.\ Loss \\
\midrule
512 & dense & 74.77 & 99.70 & 5.3142 \\
512 & sparse full & 91.04 & 121.38 & 5.4141 \\
512 & sparse both & 93.33 & 124.44 & 5.5467 \\
\midrule
2048 & dense & 1035.84 & 591.91 & 6.0264 \\
2048 & sparse full & 875.51 & 500.29 & 5.9807 \\
2048 & sparse both & 847.22 & 484.13 & 6.0231 \\
\bottomrule
\end{tabular}
\end{table}

At $d{=}512$, sparse training is 1.22$\times$ \emph{slower} than dense.
At $d{=}2048$, it achieves a 1.18$\times$ speedup (full-$G_X$) with comparable loss (5.981 vs.\ 6.026).
This crossover aligns with the Amdahl analysis: sparse $dW$ savings dominate only at widths where the $dW$ GEMM is a significant fraction of the total step cost.

\paragraph{Hardware note.}
MPS results use Apple unified memory, which has different bandwidth and kernel-launch characteristics from discrete CUDA GPUs.
The T4 ablations and microbenchmarks (Sections 4--5) provide the controlled single-hardware comparison.
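\paragraph{Amdahl bound.}
The crossover can be made explicit with the per-layer timings of Table~\ref{tab:t4-ffn-micro}.
Write the dense per-layer cost as $\tau = \tau_{\text{fwd}} + \tau_{dX} + \tau_{dW}$.
Suppose only the $dW$ term shrinks, to $f\,\tau_{dW}$, at the price of an added per-chunk overhead $\tau_{\text{ovh}}$ (a symbol introduced here for illustration).
The per-layer speedup is then
\begin{equation}
S(f) = \frac{\tau}{\tau_{\text{fwd}} + \tau_{dX} + f\,\tau_{dW} + \tau_{\text{ovh}}},
\qquad
S_{\max} = \frac{1}{1 - p}, \quad p = \frac{\tau_{dW}}{\tau},
\end{equation}
where $S_{\max}$ is the limit $f \to 0$ with $\tau_{\text{ovh}} = 0$.
At $d{=}2048$, $p = 13.19/43.51 \approx 0.30$ and $S_{\max} \approx 1.43$.
With $f = 0.1$ and no overhead, $S = 43.51/(14.76 + 15.57 + 1.32) \approx 1.37$, close to the measured 1.35$\times$.
The forward, $dX$, and $dW$ GEMMs of a linear layer each require the same number of multiply-adds.
A sparse step therefore performs about $(1 + 1 + 0.1)/3 = 0.70$ of the dense FLOPs, the ratio behind the 350-step FLOP-matched baseline in Table~\ref{tab:compute-matched}.
Conversely, consider a net slowdown ($S < 1$, with all terms taken over a full step) such as the $d{=}512$ MPS run.
It implies $\tau_{\text{ovh}} > (1 - f)\,\tau_{dW}$: the sparse bookkeeping costs more than the $dW$ time it saves.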
%% ================================================================
\section{Limitations and Future Work}
%% ================================================================

\paragraph{Dataset scale.}
All full-training results use Tiny Shakespeare (304K tokens).
At 10\% active chunks, the sparse model computes only $\sim$10\% of the weight gradients of its sparsified layers per step.
Proportionally more steps are therefore needed to match dense parameter exposure.
The compute-matched baselines (Table~\ref{tab:compute-matched}) are consistent with this: sparse training needs longer horizons to demonstrate its value, and the current dataset and step budget are insufficient to show quality parity.
Validation on larger corpora (OpenWebText, RedPajama subsets) with 5--10$\times$ more steps is needed.

\paragraph{Aggressive $dX$ sparsity.}
Sparsifying input gradients ($dX$) in addition to $dW$ yields large isolated speedups (4.8--7.8$\times$).
It also degrades loss in full training (Table~\ref{tab:mps-e2e}: 5.547 for sparse-both vs.\ 5.414 for sparse-full at $d{=}512$).
The gradient approximation error propagates through the residual stream.
Principled bounds on acceptable $dX$ sparsity remain open.

\paragraph{Predictor ceiling.}
The EMA achieves $\sim$85\% recall of the oracle top-$k$; the 15\% miss rate reflects inter-step gradient volatility.
KNN-based imputation using chunk-similarity matrices (implemented in the v18 codebase) may narrow this gap.
Initial single-seed results show comparable loss but slightly lower oracle recall, suggesting that the similarity signal is noisy at small scale.

\paragraph{Scaling.}
The measured per-layer speedup of $\sim$1.35$\times$ (full-$G_X$) and the $\sim$1.45$\times$ Amdahl ceiling behind it are both hardware-dependent.
On architectures with lower kernel-launch overhead (fused Triton, Hopper TMA), the crossover point may shift toward smaller widths.
End-to-end speedups at $d \geq 2048$ with Triton on A100/H100 are the natural next experiment.
%% ================================================================
\section{Conclusion}
%% ================================================================

We presented Predictive Chunked Sparsity together with three controlled ablations addressing structural critiques of the method:
\begin{enumerate}
\item \textbf{Phantom momentum is not load-bearing.} Freezing optimizer state for inactive chunks changes loss by $<$1\%, indicating that the EMA chunk selector drives convergence.
\item \textbf{The EMA predictor works.} It reaches 85\% recall of the oracle top-$k$, vs.\ 10\% for random: good, but not near-oracle.
\item \textbf{Sparse training needs more steps.} At 500 steps on 304K tokens, compute-matched dense baselines win. The method's value proposition is a per-step wall-clock speedup at large $d_{\text{model}}$, amortized over longer training runs in which the 1.2--1.35$\times$ per-step savings compound.
\end{enumerate}

The engineering contribution is the use of contiguous chunked views that turn the sparse backward pass into dense GEMMs.
It is supported by 3.7--6.8$\times$ sparse-$dW$ speedups in isolated T4 microbenchmarks (Table~\ref{tab:t4-ffn-micro}).
It is further supported by fused Triton kernels that run the full-$G_X$ backward 1.69--1.73$\times$ faster than dense (Table~\ref{tab:triton-backward}).
The ML contribution requires larger-scale experiments to fully establish the quality/speed Pareto frontier.

\paragraph{Reproducibility.}
All code, experiment scripts, and raw results are available at \url{https://huggingface.co/theapemachine/sparse-transformer-experiments}.

\begin{thebibliography}{9}
\bibitem{evci2020rigging}
U.~Evci, T.~Gale, J.~Menick, P.~S. Castro, and E.~Elsen.
\newblock Rigging the lottery: Making all tickets winners.
\newblock In \emph{ICML}, 2020.

\bibitem{mocanu2018scalable}
D.~C. Mocanu, E.~Mocanu, P.~Stone, P.~H. Nguyen, M.~Gibescu, and A.~Liotta.
\newblock Scalable training of artificial neural networks with adaptive sparse connectivity inspired by network science.
\newblock \emph{Nature Communications}, 9(1):2383, 2018.
\end{thebibliography}

\end{document}