\documentclass[11pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsmath,amssymb}
\usepackage{graphicx}
\usepackage{microtype}
\usepackage{booktabs}
\usepackage{tabularx}
\usepackage{hyperref}
\hypersetup{
colorlinks=true,
linkcolor=blue,
citecolor=blue,
urlcolor=blue,
}
\title{%
\textbf{Zero-Copy Sparse Backpropagation}: \\[0.3em]
\large Temporal Gradient Tracking for
Faster, Regularized LLM Training
}
\author{
Daniel Owen van Dommelen\\
\textit{Independent Research}\\
\texttt{theapemachine@gmail.com}
}
\date{\today}
\begin{document}
\maketitle
\begin{abstract}
We describe \emph{Predictive Chunked Sparsity}: fixed top-$k$ row-chunks for
sparse weight-gradient computation ($dW$), selected by an exponential moving
average (EMA) of past chunk-gradient norms, with contiguous strided views for
hardware-friendly GEMMs. We present controlled ablations on a single hardware
stack (NVIDIA T4/A10G) addressing three structural questions: (1)~whether
convergence is driven by the chunk-selection algorithm or by ``phantom
momentum'' in the optimizer, (2)~whether sparse training outperforms
compute-matched and capacity-matched dense baselines, and (3)~how accurately the
EMA predictor identifies high-gradient chunks compared to an oracle.
Our phantom momentum ablation (Table~\ref{tab:phantom}) shows that freezing
inactive Adam state changes validation loss by $<$0.05 ($<$1\%), indicating that
the EMA chunk selector, rather than optimizer side effects, drives convergence. The EMA
predictor achieves $\sim$85\% recall of oracle top-$k$ chunks
(Table~\ref{tab:predictor-multi}), far above the $\sim$10\% random
baseline. However, compute-matched dense baselines outperform sparse training at
500 steps on Tiny Shakespeare (Table~\ref{tab:compute-matched}), consistent with
the expectation that sparse methods need longer training horizons to
compensate for reduced per-step update volume. In isolated FFN microbenchmarks,
the sparse $dW$ computation is 3.7--6.8$\times$ faster than dense $dW$ for
$d_{\text{model}} \geq 512$, yielding a 1.31--1.35$\times$ per-layer speedup that
is bounded by Amdahl's law because the forward pass and $dX$ remain dense
(Table~\ref{tab:t4-ffn-micro}).
\end{abstract}
%% ================================================================
\section{Introduction}
%% ================================================================
Training large transformers is dominated by dense matrix multiplications in the
forward and backward passes. Prior work on dynamic sparsity
\cite{evci2020rigging,mocanu2018scalable} demonstrates that sparse connectivity
can match dense accuracy, but translating theoretical FLOP reductions to
wall-clock speedups remains difficult due to irregular memory access patterns
and host--device synchronization overhead.
We propose \emph{Predictive Chunked Sparsity}: a method that maintains
fixed-cardinality masks over contiguous row-chunks of weight matrices, enabling
the sparse backward pass to decompose into a small number of standard dense
GEMMs on strided views. An EMA-based scorer predicts which chunks carry the
largest gradient mass, and a cosine annealing schedule transitions from dense
warmup to the target sparsity level.
\paragraph{Contributions.}
\textbf{(1)}~The chunked sparse backward algorithm with EMA-based chunk
selection and optional KNN-based inactive-chunk score imputation.
\textbf{(2)}~A phantom momentum ablation indicating that the chunk selector, not
optimizer decay, drives convergence.
\textbf{(3)}~Compute-matched and capacity-matched dense baselines quantifying
the method's current limitations.
\textbf{(4)}~Multi-seed predictor accuracy measurements (Jaccard/Recall vs.\
oracle).
\textbf{(5)}~Fused Triton kernels for the sparse backward pass with
\texttt{block\_ptr} TMA-ready loads.
%% ================================================================
\section{Methodology}
%% ================================================================
\subsection{Chunked Sparse Backward}
A linear layer $W \in \mathbb{R}^{O \times I}$ is partitioned into $N = O/C$
contiguous row-chunks of size $C$. At each training step, a binary mask
$A \in \{0,1\}^N$ with exactly $k = \lfloor f \cdot N \rfloor$ active entries
determines which chunks receive gradient updates:
%
\begin{equation}
dW_{c} =
\begin{cases}
G_Y^{[c]\top} X & \text{if } A_c = 1 \\
0 & \text{otherwise}
\end{cases}
\end{equation}
%
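where $X \in \mathbb{R}^{BT \times I}$ and $G_Y \in \mathbb{R}^{BT \times O}$ are
the layer input and upstream gradient flattened over the $BT$ tokens of a batch
of $B$ sequences of length $T$, and $G_Y^{[c]} = G_Y[:, cC{:}(c{+}1)C]$ is the
slice of $G_Y$ corresponding to chunk $c \in \{0, \dots, N{-}1\}$. Because
chunks are contiguous row-blocks of $W$, each active $dW_c$ is a standard dense
GEMM on a strided view of $G_Y$ that writes a contiguous row-block of $dW$; no
gather or scatter is required.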
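
As a worked count, assume the usual convention that multiplying an $m \times n$
matrix by an $n \times p$ matrix costs $2mnp$ FLOPs, with $X \in
\mathbb{R}^{BT \times I}$ and $G_Y \in \mathbb{R}^{BT \times O}$ flattened over
the $BT$ tokens of a batch. Each active chunk then costs $2\,BT\,C\,I$ FLOPs,
and since $kC = (k/N)\,O$,
\begin{equation*}
\underbrace{k \cdot 2\,BT\,C\,I}_{\text{sparse } dW}
= \frac{k}{N} \cdot \underbrace{2\,BT\,O\,I}_{\text{dense } dW}
\approx f \cdot \mathrm{FLOPs}\bigl(dW_{\text{dense}}\bigr).
\end{equation*}
For example, an FFN up-projection at $d_{\text{model}} = 1024$ has $O = 4096$,
so chunk size $C = 64$ gives $N = 64$, and $f = 0.1$ gives
$k = \lfloor 6.4 \rfloor = 6$ active chunks, i.e.\ $6/64 \approx 9.4\%$ of the
dense $dW$ FLOPs. This count ignores per-GEMM launch overhead, which dominates
at small widths.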
\subsection{EMA Chunk Selection}
We maintain a running score per chunk:
\begin{equation}
M_c^{(t)} =
\begin{cases}
\beta \, M_c^{(t-1)} + (1{-}\beta) \| dW_c^{(t)} \|_2 & \text{if active} \\
M_c^{(t-1)} & \text{if inactive (frozen EMA)}
\end{cases}
\end{equation}
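The $k$ chunks with the largest scores $M^{(t-1)}$ form the active set at step
$t$, i.e.\ $A^{(t)}_c = 1$ exactly for those chunks. Inactive chunks retain
their last-observed score without decay. This mitigates the stale-EMA lockout
problem, in which decayed scores would permanently exclude potentially
important chunks: a frozen score can re-enter the top-$k$ as soon as the scores
of active chunks fall below it.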
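
A minimal Python sketch of the frozen-EMA rule and top-$k$ selection follows.
The function names \texttt{select\_active} and \texttt{update\_scores}, the
plain-list representation, and the default $\beta = 0.9$ are assumptions for
illustration; the document does not state the value of $\beta$ used. The
example shows a chunk with a frozen score re-entering the active set once the
scores of the active chunks drop.
\begin{verbatim}
import heapq

def select_active(scores, k):
    """Indices of the k chunks with the largest EMA score M^(t-1), ascending."""
    return sorted(heapq.nlargest(k, range(len(scores)), key=lambda c: scores[c]))

def update_scores(scores, active, chunk_norms, beta=0.9):
    """Frozen-EMA update (beta=0.9 is an assumed default).

    chunk_norms[c] is ||dW_c|| at the current step; it is read only for
    active chunks, since inactive chunks have no computed gradient.
    """
    new = list(scores)
    for c in active:
        new[c] = beta * scores[c] + (1.0 - beta) * chunk_norms[c]
    return new  # inactive entries keep their last-observed score (no decay)

# Example: chunks 1 and 3 start active; chunk 2 keeps its frozen score of 2.5.
scores = [1.0, 4.0, 2.5, 3.0]
active = select_active(scores, k=2)                               # -> [1, 3]
scores = update_scores(scores, active, [0.0, 0.0, 0.0, 1.4], beta=0.5)
# scores is now [1.0, 2.0, 2.5, 2.2] (up to rounding)
next_active = select_active(scores, k=2)                          # -> [2, 3]
\end{verbatim}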
\subsection{Cosine Sparsity Annealing}
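Training begins fully dense for $T_w$ warmup steps, then the active fraction
$f(t)$ anneals via a cosine schedule from 1.0 to the target $f_{\text{target}}$
over $T_a$ annealing steps and is held at $f_{\text{target}}$ thereafter (we
write $T_w$ and $T_a$ to avoid a clash with the weight $W$ and the mask $A$):
\begin{equation}
f(t) =
\begin{cases}
1 & t < T_w \\
f_{\text{target}} + \tfrac{1}{2}(1 - f_{\text{target}})
\bigl(1 + \cos(\pi \cdot (t - T_w) / T_a)\bigr) & T_w \leq t < T_w + T_a \\
f_{\text{target}} & t \geq T_w + T_a
\end{cases}
\end{equation}
The number of active chunks at step $t$ is then $k(t) = \lfloor f(t)\, N \rfloor$.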
\subsection{Phantom Momentum}
When using Adam, inactive chunks receive zero gradients. Standard Adam applies
moment decay ($m \leftarrow \beta_1 m$, $v \leftarrow \beta_2 v$) even on
zero-gradient steps, causing the optimizer to produce small weight updates from
historical momentum---an effect we term ``phantom momentum.'' Section~\ref{sec:phantom}
presents a controlled ablation isolating this effect.
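
The effect is easy to reproduce on a single scalar with textbook bias-corrected
Adam. The sketch below assumes $\beta_1 = 0.9$ and $\beta_2 = 0.999$ (PyTorch's
defaults, not values stated in this document) and uses the experiments' learning
rate of $3 \times 10^{-4}$; the helper \texttt{adam\_step} is hypothetical.
After ten steps with gradient 1, a zero-gradient step still moves the weight by
roughly 90\% of a full step in phantom mode, whereas frozen mode leaves both the
weight and the moments untouched. In PyTorch's \texttt{torch.optim.Adam}, a
parameter whose \texttt{.grad} is \texttt{None} is skipped, but zeroed rows
inside an existing gradient tensor are decayed like any other entry, which is
one way phantom momentum can arise when active and inactive chunks share a
single weight tensor.
\begin{verbatim}
import math

def adam_step(p, g, m, v, t, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One textbook Adam step with bias correction on a scalar parameter."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# Ten active steps with gradient 1.0 build up the Adam moments.
p, m, v = 0.0, 0.0, 0.0
for t in range(1, 11):
    p, m, v = adam_step(p, 1.0, m, v, t)

# Step 11: the chunk is inactive, so its gradient is zero.
# Phantom mode: moments decay and the weight still moves.
p_phantom, m_phantom, v_phantom = adam_step(p, 0.0, m, v, t=11)
# Frozen mode: weight and moments are left exactly as they were.
p_frozen, m_frozen, v_frozen = p, m, v

phantom_update = p - p_phantom  # about 0.9 * lr
\end{verbatim}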
%% ================================================================
\section{Systems Implementation}
%% ================================================================
\subsection{PyTorch Backend}
The sparse backward is implemented as a Python loop over active chunks, each
issuing a small dense GEMM via \texttt{gy\_flat[:, s:e].t() @ x\_flat}. Fixed
$k$ ensures no dynamic shape allocation. The optimizer (ChunkedAdam) restricts
weight updates to active chunks only.
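
The loop can be sketched in NumPy as below, with \texttt{.T} standing in for
PyTorch's \texttt{.t()}. The function name \texttt{sparse\_dw}, the zero-filled
full-size output buffer, and the layout convention $Y = XW^\top$ are assumptions
of this sketch rather than details of the codebase. Because $k$ is fixed, every
call issues exactly $k$ GEMMs of identical shape, and each
\texttt{gy\_flat[:, s:e]} slice is a view, not a copy.
\begin{verbatim}
import numpy as np

def sparse_dw(gy_flat, x_flat, active, chunk):
    """Chunked sparse weight gradient for a linear layer Y = X W^T, W of shape (O, I).

    gy_flat: (BT, O) upstream gradient; x_flat: (BT, I) layer input;
    active: the k active chunk indices (fixed length); chunk: rows per chunk C.
    Returns a full (O, I) dW in which only the active row-blocks are filled.
    """
    out_dim, in_dim = gy_flat.shape[1], x_flat.shape[1]
    dw = np.zeros((out_dim, in_dim), dtype=gy_flat.dtype)
    for c in active:
        s, e = c * chunk, (c + 1) * chunk
        # gy_flat[:, s:e] is a strided view (no copy); one small dense GEMM per chunk.
        dw[s:e] = gy_flat[:, s:e].T @ x_flat
    return dw

# Example: 32 tokens, O = 256 (N = 4 chunks of 64), I = 128, chunks 1 and 3 active.
rng = np.random.default_rng(0)
gy_flat = rng.standard_normal((32, 256))
x_flat = rng.standard_normal((32, 128))
dw = sparse_dw(gy_flat, x_flat, active=[1, 3], chunk=64)
dense = gy_flat.T @ x_flat  # reference dense dW
\end{verbatim}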
\subsection{Triton Backend}
We provide fused Triton kernels that process all active chunks in a single GPU
launch. The $dW$ kernel uses a 2D grid (active chunks $\times$ $d_{\text{in}}$
tiles) with \texttt{tl.make\_block\_ptr} for hardware-accelerated 2D tile
loads. Bias gradients are fused into the $dW$ kernel by accumulating column
sums of the $dY$ tiles already in registers, eliminating the uncoalesced
memory access pattern of a separate bias kernel. A sparse $dX$ kernel is also
provided for the aggressive mode where input gradients are also sparsified.
Correctness is verified against the PyTorch reference
(Table~\ref{tab:triton-correctness}); max absolute errors are below $5 \times
10^{-4}$ across all tested configurations.
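
The data flow of the fused kernel can be emulated in NumPy, with loops standing
in for the Triton program grid and tile loads. This is a behavioral sketch, not
the Triton code: the helper name \texttt{fused\_dw\_db}, the tile sizes, the
compact output that stores only active rows, and the choice that the first
$d_{\text{in}}$ tile of each chunk owns the bias reduction are all assumptions.
Each program reads a $dY$ tile once and uses it both for the $dW$ tile product
and, where it owns the bias, for the column sums that form $db$.
\begin{verbatim}
import numpy as np

def fused_dw_db(gy_flat, x_flat, active, chunk, block_m=16, block_i=32):
    """Emulate a fused dW/db kernel: one 'program' per (active chunk, d_in tile).

    Returns dW and db for the active chunks only, stacked in the order of
    `active` (a compact layout assumed for this sketch).
    """
    n_tok, in_dim = x_flat.shape
    k = len(active)
    dw = np.zeros((k * chunk, in_dim), dtype=gy_flat.dtype)
    db = np.zeros(k * chunk, dtype=gy_flat.dtype)
    for a, c in enumerate(active):                 # grid axis 0: active chunks
        rows = slice(a * chunk, (a + 1) * chunk)
        for i0 in range(0, in_dim, block_i):       # grid axis 1: d_in tiles
            cols = slice(i0, min(i0 + block_i, in_dim))
            acc = np.zeros((chunk, cols.stop - i0), dtype=gy_flat.dtype)
            bias_acc = np.zeros(chunk, dtype=gy_flat.dtype)
            owns_bias = i0 == 0                    # one tile per chunk writes db
            for m0 in range(0, n_tok, block_m):    # reduction over token tiles
                gy_tile = gy_flat[m0:m0 + block_m, c * chunk:(c + 1) * chunk]
                x_tile = x_flat[m0:m0 + block_m, cols]
                acc += gy_tile.T @ x_tile
                if owns_bias:
                    # Reuse the dY tile already loaded: its column sums give db.
                    bias_acc += gy_tile.sum(axis=0)
            dw[rows, cols] = acc
            if owns_bias:
                db[rows] = bias_acc
    return dw, db

# Example: 40 tokens, O = 128 (4 chunks of 32), I = 80, chunks 0 and 3 active.
rng = np.random.default_rng(1)
gy_flat = rng.standard_normal((40, 128)).astype(np.float32)
x_flat = rng.standard_normal((40, 80)).astype(np.float32)
dw, db = fused_dw_db(gy_flat, x_flat, active=[0, 3], chunk=32)
\end{verbatim}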
%% ================================================================
\section{Experiments}
%% ================================================================
All experiments in this section use a single hardware stack (NVIDIA T4, 16\,GB)
with GPT-2 BPE tokenization on Tiny Shakespeare (304K train tokens, 34K val
tokens). Model: 4 layers, 8 heads, $d_{\text{model}} = 1024$, chunk size 64,
10\% active fraction, batch 8, sequence length 256, learning rate $3 \times
10^{-4}$, cosine annealing with 50-step warmup and 200-step anneal. Results
report mean $\pm$ std over 2 seeds unless noted.
\subsection{Phantom Momentum Ablation}
\label{sec:phantom}
The central question: does convergence depend on the chunk-selection algorithm,
or on phantom momentum acting as implicit regularization? We compare two Adam
modes across multiple chunk-selection policies:
\begin{itemize}
\item \textbf{Phantom} (default): Adam moments decay on all chunks every step,
including inactive ones receiving zero gradients.
\item \textbf{Frozen}: Adam state ($m$, $v$) for inactive chunks is completely
preserved---no decay, no update.
\end{itemize}
\begin{table}[t]
\centering
\caption{Phantom momentum ablation. $d{=}1024$, 4 layers, 500 steps, 2 seeds.
The phantom$\to$frozen delta is small ($<$0.05 loss) for all policies,
indicating that phantom momentum does not drive convergence.}
\label{tab:phantom}
\begin{tabular}{l r @{\,$\pm$\,} l r}
\toprule
Method & \multicolumn{2}{c}{Val.\ Loss} & ms/step \\
\midrule
Dense (reference) & 5.4710 & 0.0119 & 363.2 \\
\midrule
EMA + phantom & 5.8750 & 0.2433 & 364.1 \\
EMA + frozen & 5.9170 & 0.2695 & 376.8 \\
\midrule
Random + phantom & 6.0688 & 0.1006 & 365.4 \\
Random + frozen & 6.0239 & 0.1318 & 376.5 \\
\bottomrule
\end{tabular}
\end{table}
\paragraph{Findings.}
The phantom-to-frozen delta is $+0.042$ for EMA and $-0.045$ for random; both
are well within the seed-to-seed standard deviations (0.10--0.27) and below 1\%
of the loss magnitude. Phantom momentum is therefore \emph{not} the
load-bearing mechanism. The EMA selector outperforms random by 0.11--0.19 in
mean loss under both momentum modes, indicating that the chunk-selection
algorithm does genuine predictive work, although with two seeds and EMA
standard deviations of $\sim$0.25 this gap is not statistically resolved.
The frozen mode is 11--13\,ms/step slower due to the per-chunk Adam loop
(vs.\ bulk tensor decay in phantom mode), a minor systems cost.
\subsection{Compute-Matched and Capacity-Matched Baselines}
\label{sec:compute}
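The critique that sparse training may simply act as ``the world's most
computationally expensive dropout'' requires controlled baselines. We compare
sparse EMA training against dense training for the same 500 steps, dense
training for 350 steps (matching the sparse run's FLOPs), and a smaller dense
model with a 1$\times$ FFN (Table~\ref{tab:compute-matched}).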
\begin{table}[t]
\centering
\caption{Compute-matched baselines. Same setup as Table~\ref{tab:phantom}.
At 10\% active, sparse does $\sim$70\% of dense FLOPs per step.}
\label{tab:compute-matched}
\begin{tabular}{l r r @{\,$\pm$\,} l r}
\toprule
Method & Params & \multicolumn{2}{c}{Val.\ Loss} & ms/step \\
\midrule
Sparse EMA (500 steps) & 153.6M & 5.8750 & 0.2433 & 363.1 \\
Dense (500 steps) & 153.6M & 5.4710 & 0.0119 & 364.5 \\
Dense (350 steps, FLOP-matched) & 153.6M & 5.6714 & 0.0002 & 364.2 \\
Dense small (ffn$\times$1, capacity-matched) & 128.4M & 5.6329 & 0.0127 & 284.2 \\
\bottomrule
\end{tabular}
\end{table}
\paragraph{Findings.}
At 500 steps on 304K tokens, sparse EMA (5.875) underperforms all dense
baselines. Even the FLOP-matched dense run at 350 steps (5.671) and the
capacity-matched small model (5.633, with 16\% fewer parameters and 22\% lower
step time) outperform it. This is expected: with 10\% active chunks, each
sparsified weight matrix receives updates on only $\sim$10\% of its rows per
step, so substantially more steps are needed to reach equivalent parameter
exposure.
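
One accounting that reproduces the $\sim$70\% figure in
Table~\ref{tab:compute-matched} and the 350-step FLOP-matched budget assumes
that the forward, $dX$ and $dW$ GEMMs of each sparsified layer cost the same
number of FLOPs and that only $dW$ is reduced, to the active fraction $f$. The
function name is hypothetical, and the equal-cost assumption ignores
attention-score and elementwise work.
\begin{verbatim}
def step_flop_fraction(f_active):
    """Sparse/dense FLOP ratio per step in full-G_X mode.

    Assumes forward, dX and dW GEMMs each cost one unit and that only dW
    is scaled by the active fraction.
    """
    return (1.0 + 1.0 + f_active) / 3.0

frac = step_flop_fraction(0.10)    # (2 + 0.1) / 3 = 0.7
matched_steps = round(500 * frac)  # dense steps with the same FLOP budget: 350
\end{verbatim}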
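This result does \emph{not} invalidate the method; it characterizes its
operating regime. At $d_{\text{model}} = 1024$ on the T4, the sparse run is also
no faster per step (363.1 vs.\ 364.5\,ms; Table~\ref{tab:compute-matched}), so
its FLOP reduction does not yet translate into wall-clock savings at this width.
At $d_{\text{model}} = 2048$ on MPS (Table~\ref{tab:mps-e2e}), sparse full-$G_X$
achieves a 1.18$\times$ wall-clock speedup with comparable loss (5.981 vs.\
6.026, single seed), suggesting that the speed/quality tradeoff becomes
favorable at larger widths, where each sparse step is proportionally cheaper.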
\subsection{Predictor Accuracy: EMA vs.\ Oracle}
\label{sec:predictor}
We measure how well the EMA scorer identifies the oracle top-$k$ chunks
(defined as the $k$ chunks with largest $\|dW_c\|_2$ from a dense gradient
computation on the same batch). Oracle overlap is computed every 25 steps after
the annealing schedule completes.
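
The two overlap metrics can be computed as below (a minimal sketch; the
function name \texttt{overlap} is hypothetical). Because the predicted and
oracle sets both contain exactly $k$ chunks, Jaccard $J$ and recall $r$ are
linked by $J = r/(2 - r)$, so a recall of 0.85 corresponds to a Jaccard index
of about 0.74 and a recall of 0.10 to about 0.05, close to the averages in
Table~\ref{tab:predictor-multi} (averages of per-step values need not satisfy
the identity exactly).
\begin{verbatim}
def overlap(predicted, oracle):
    """Jaccard index and recall of a predicted chunk set against the oracle top-k."""
    p, o = set(predicted), set(oracle)
    inter = len(p & o)
    jaccard = inter / len(p | o)
    recall = inter / len(o)
    return jaccard, recall

# k = 6 chunks, 5 of which agree with the oracle.
j, r = overlap([0, 4, 7, 9, 12, 20], [0, 4, 7, 9, 12, 31])  # j = 5/7, r = 5/6
\end{verbatim}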
\begin{table}[t]
\centering
\caption{EMA predictor overlap with oracle. $d{=}1024$, 4 layers, 500 steps,
2 seeds, measured post-anneal (step $\geq 125$). Random baseline included.}
\label{tab:predictor-multi}
\begin{tabular}{l c c r @{\,$\pm$\,} l}
\toprule
Policy & Jaccard (avg) & Recall (avg) & \multicolumn{2}{c}{Val.\ Loss} \\
\midrule
EMA & 0.73 & 0.85 & 5.9691 & 0.1502 \\
Random & 0.05 & 0.10 & 6.1281 & 0.0461 \\
\bottomrule
\end{tabular}
\end{table}
\paragraph{Findings.}
The EMA predictor achieves 85\% recall (73\% Jaccard) of the oracle top-$k$,
stable across training. Random selection achieves 10\% recall, matching the
chance level $k/N \approx f$ for a uniformly random subset and confirming that
the EMA is a meaningful predictor. The 15\% miss rate likely reflects chunks
whose gradient importance shifts between steps, which a predictor that sees
only past gradients cannot anticipate.
The corresponding losses differ by 0.16 (5.969 for EMA vs.\ 6.128 for random),
consistent with informed chunk selection improving convergence; with only two
policies and two seeds, however, this does not establish a quantitative
recall-to-loss relationship.
%% ================================================================
\section{Microbenchmarks and Triton Kernels}
%% ================================================================
\subsection{Per-Layer Amdahl Analysis (T4)}
\begin{table}[t]
\centering
\caption{T4: per-FFN-layer cost breakdown (ms). $B{=}8$, $T{=}256$,
chunk 64, 10\% active, fp32, 100 iters. Speedup is total dense / total
sparse (full-$G_X$ mode: dense $dX$, sparse $dW$).}
\label{tab:t4-ffn-micro}
\resizebox{\linewidth}{!}{%
\footnotesize
\begin{tabular}{r r r r r r r r r}
\toprule
$d$ & FFN & Fwd & $dX$ & $dW_{\text{dense}}$ &
$dW_{\text{sparse}}$ & Tot.\ dense & Tot.\ sparse & Speedup \\
\midrule
256 & 1024 & 0.27 & 0.21 & 0.27 & 0.26 & 0.75 & 0.74 & 1.02$\times$ \\
512 & 2048 & 1.00 & 1.01 & 0.97 & 0.26 & 2.99 & 2.28 & 1.31$\times$ \\
1024 & 4096 & 3.69 & 3.90 & 3.35 & 0.59 & 10.95 & 8.18 & 1.34$\times$ \\
2048 & 8192 & 14.76 & 15.57 & 13.19 & 1.93 & 43.51 & 32.26 & 1.35$\times$ \\
\bottomrule
\end{tabular}%
}
\end{table}
The sparse $dW$ component achieves a 3.7--6.8$\times$ speedup over dense $dW$
for $d \geq 512$. However, the forward pass and dense $dX$ are unchanged,
yielding an Amdahl ceiling of $\sim$1.45$\times$ for full-$G_X$ mode. The
measured per-layer speedup approaches $\sim$1.35$\times$ for $d \geq 512$. At
$d = 256$ the sparse $dW$ is barely faster than dense (0.26 vs.\ 0.27\,ms),
since per-chunk loop overhead outweighs the FLOP savings; the crossover to a
meaningful gain therefore lies between $d = 256$ and $d = 512$.
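
Both the realized speedup and the Amdahl ceiling can be recomputed from
Table~\ref{tab:t4-ffn-micro}; the ceiling sets the sparse $dW$ cost to zero
while keeping the forward pass and dense $dX$. This sketch uses the rounded
table entries, so its results can differ from the tabulated speedups in the
second decimal (e.g.\ 1.01 rather than 1.02 at $d{=}256$).
\begin{verbatim}
# Per-FFN-layer timings in ms from the T4 table: d -> (fwd, dX, dW_dense, dW_sparse)
timings = {
    256: (0.27, 0.21, 0.27, 0.26),
    512: (1.00, 1.01, 0.97, 0.26),
    1024: (3.69, 3.90, 3.35, 0.59),
    2048: (14.76, 15.57, 13.19, 1.93),
}

measured, ceiling = {}, {}
for d, (fwd, dx, dw_dense, dw_sparse) in timings.items():
    dense_total = fwd + dx + dw_dense
    measured[d] = dense_total / (fwd + dx + dw_sparse)  # realized per-layer speedup
    ceiling[d] = dense_total / (fwd + dx)               # Amdahl limit as dW cost -> 0
\end{verbatim}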
\subsection{Triton Kernel Performance}
\begin{table}[t]
\centering
\caption{Triton backward correctness (max abs error vs.\ PyTorch reference).}
\label{tab:triton-correctness}
\begin{tabular}{r r r r r r}
\toprule
$d_{\text{in}}$ & $d_{\text{out}}$ & chunk & $|dW|$ & $|db|$ & $|dX|$ \\
\midrule
512 & 2048 & 64 & 3.2e-4 & 2.3e-5 & 4.2e-5 \\
1024 & 4096 & 64 & 4.4e-4 & 2.1e-5 & 9.2e-5 \\
256 & 1024 & 32 & 2.8e-4 & 3.8e-5 & 1.9e-5 \\
\bottomrule
\end{tabular}
\end{table}
\begin{table}[t]
\centering
\caption{Isolated backward time (ms): Dense vs.\ PyLoop vs.\ Triton (T4).
Full-$G_X$ mode, 50 iters post-warmup. Act.\ is the number of active chunks;
Dense/Tri.\ is the speedup of Triton over dense.}
\label{tab:triton-backward}
\begin{tabular}{r r r r r r}
\toprule
$d$ & Act. & Dense & PyLoop & Triton & Dense/Tri. \\
\midrule
512 & 3 & 1.96 & 1.30 & 1.16 & 1.69$\times$ \\
1024 & 6 & 7.29 & 4.37 & 4.30 & 1.70$\times$ \\
2048 & 12 & 29.14 & 17.20 & 16.89 & 1.73$\times$ \\
\bottomrule
\end{tabular}
\end{table}
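In full-$G_X$ mode, Triton is 1.69--1.73$\times$ faster than the dense backward
but only marginally faster than the Python loop (e.g.\ 16.89 vs.\ 17.20\,ms at
$d{=}2048$), because the unchanged dense $dX$ dominates the remaining cost.
With both $dW$ and $dX$ sparse, Triton achieves 4.8--7.8$\times$ over dense in
the isolated backward harness, though this aggressive mode incurs quality
tradeoffs not yet fully characterized at scale.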
%% ================================================================
\section{Full Training Results}
%% ================================================================
\subsection{MPS End-to-End (Author Runs)}
\begin{table}[t]
\centering
\caption{MPS full training. 6 layers, $B{=}8$, $T{=}256$, chunk 64,
10\% active, 2000 steps, single seed.}
\label{tab:mps-e2e}
\begin{tabular}{l l r r r}
\toprule
$d$ & Run & Time (s) & ms/step & Val.\ Loss \\
\midrule
512 & dense & 74.77 & 99.70 & 5.3142 \\
512 & sparse full & 91.04 & 121.38 & 5.4141 \\
512 & sparse both & 93.33 & 124.44 & 5.5467 \\
\midrule
2048 & dense & 1035.84 & 591.91 & 6.0264 \\
2048 & sparse full & 875.51 & 500.29 & 5.9807 \\
2048 & sparse both & 847.22 & 484.13 & 6.0231 \\
\bottomrule
\end{tabular}
\end{table}
At $d{=}512$, sparse is 1.22$\times$ \emph{slower} than dense. At $d{=}2048$,
sparse achieves 1.18$\times$ speedup (full-$G_X$) with comparable loss (5.981
vs.\ 6.026). This crossover aligns with the Amdahl analysis: sparse $dW$
savings only dominate at widths where the $dW$ GEMM is a significant fraction
of total step cost.
\paragraph{Hardware note.} MPS results use Apple unified memory, which has
different bandwidth and kernel-launch characteristics than discrete CUDA GPUs.
The T4 microbenchmarks and ablations (Sections 4--5) provide the controlled
single-hardware comparison.
%% ================================================================
\section{Limitations and Future Work}
%% ================================================================
\paragraph{Dataset scale.} All full-training results use Tiny Shakespeare
(304K tokens). At 10\% active chunks, each sparsified weight matrix receives
updates on only $\sim$10\% of its rows per step, requiring proportionally more
steps to match dense parameter exposure. The compute-matched baselines
(Table~\ref{tab:compute-matched}) are consistent with this: sparse training
needs longer horizons to demonstrate its value, and the current dataset/step
budget is insufficient to show quality parity. Validation on larger corpora
(OpenWebText, RedPajama subsets) with 5--10$\times$ more steps is needed.
\paragraph{Aggressive $dX$ sparsity.} Sparsifying input gradients ($dX$) in
addition to $dW$ yields large isolated speedups (4.8--7.8$\times$) but degrades
loss in full training (Table~\ref{tab:mps-e2e}, sparse-both at $d{=}512$). The
gradient approximation error propagates through the residual stream. Principled
bounds on acceptable $dX$ sparsity remain open.
\paragraph{Predictor ceiling.} The EMA achieves $\sim$85\% recall of oracle
top-$k$. The 15\% miss rate reflects inter-step gradient volatility. KNN-based
imputation using chunk-similarity matrices (implemented in the v18 codebase) may
narrow this gap; initial single-seed results show comparable loss but slightly
lower oracle recall, suggesting the similarity signal is noisy at small scale.
\paragraph{Scaling.} The measured per-layer speedup of $\sim$1.35$\times$
(full-$G_X$), against an Amdahl ceiling of $\sim$1.45$\times$, is
hardware-dependent. With lower kernel-launch overhead (fused Triton kernels,
Hopper TMA), the crossover point may shift downward. End-to-end speedups at
$d \geq 2048$ with Triton on A100/H100 are the natural next experiment.
%% ================================================================
\section{Conclusion}
%% ================================================================
We presented Predictive Chunked Sparsity with three controlled ablations
addressing structural critiques of the method:
\begin{enumerate}
\item \textbf{Phantom momentum is not load-bearing.} Freezing optimizer state
for inactive chunks changes loss by $<$1\%. The EMA chunk selector drives
convergence.
\item \textbf{The EMA predictor works.} 85\% recall of oracle top-$k$,
vs.\ 10\% for random. This is ``good'' but not ``near-oracle.''
\item \textbf{Sparse needs more steps.} At 500 steps on 304K tokens,
compute-matched dense baselines win. The method's value proposition is
per-step wall-clock speedup at large $d_{\text{model}}$ (1.18$\times$
end-to-end on MPS at $d{=}2048$; up to 1.35$\times$ per FFN layer on T4), which
pays off only in training runs long enough for the sparse model to close the
quality gap.
\end{enumerate}
The engineering contribution, contiguous chunked views that turn the sparse
backward into dense GEMMs, is validated: sparse $dW$ is 3.7--6.8$\times$ faster
than dense $dW$ in the T4 microbenchmarks, and the fused Triton backward is
1.69--1.73$\times$ faster than the dense backward in full-$G_X$ mode. The ML
contribution requires larger-scale experiments to fully establish the
quality/speed Pareto frontier.
\paragraph{Reproducibility.} All code, experiment scripts, and raw results are
available at
\url{https://huggingface.co/theapemachine/sparse-transformer-experiments}.
\begin{thebibliography}{9}
\bibitem{evci2020rigging}
U.~Evci, T.~Gale, J.~Menick, P.~S. Castro, and E.~Elsen.
\newblock Rigging the lottery: Making all tickets winners.
\newblock In \emph{ICML}, 2020.
\bibitem{mocanu2018scalable}
D.~C. Mocanu, E.~Mocanu, P.~Stone, P.~H. Nguyen, M.~Gibescu, and A.~Liotta.
\newblock Scalable training of artificial neural networks with adaptive sparse
connectivity inspired by network science.
\newblock \emph{Nature Communications}, 9(1):2383, 2018.
\end{thebibliography}
\end{document}