bc1b8eb Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
This bundle is the first empirical optimization test rather than only a correctness prototype.
It compares dense Transformer training against sparse active-row backward variants and reports
validation loss plus wall-clock metrics (step_ms, tokens_per_s).
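The two wall-clock metrics can be derived from one timed loop. The sketch below is not the harness's actual code: `benchmark_steps`, its parameters, and the stand-in workload are hypothetical names chosen for illustration. It assumes that `--benchmark_sync` exists to make the timer wait for queued device work, which is why an optional `sync` callable is called before and after the loop.

```python
import time

def benchmark_steps(step_fn, n_steps, tokens_per_step, sync=None):
    """Return (step_ms, tokens_per_s) averaged over n_steps calls of step_fn.

    sync is an optional zero-argument callable that blocks until queued
    device work has finished (on MPS, torch.mps.synchronize is one candidate).
    """
    if sync is not None:
        sync()  # drain pending work so it is not billed to the first step
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    if sync is not None:
        sync()  # wait for the last step to actually complete
    elapsed_s = time.perf_counter() - start
    step_ms = 1000.0 * elapsed_s / n_steps
    tokens_per_s = n_steps * tokens_per_step / elapsed_s
    return step_ms, tokens_per_s

# Stand-in workload; a real run would pass one training step instead.
step_ms, tokens_per_s = benchmark_steps(
    lambda: sum(range(10_000)), n_steps=50, tokens_per_step=4096
)
print(f"step_ms={step_ms:.3f} tokens_per_s={tokens_per_s:.0f}")
```

Without the synchronization calls, an asynchronous backend can return from `step_fn` before the kernels finish, which would make the sparse variants look faster than they are.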
Files
- sparse_transformer_v10.py: training + benchmark harness.
- sparse_linear.metal: Metal kernels for active-row dW/db and sparse dX.
- sparse_linear_ops.mm: PyTorch/MPS extension glue.
- setup.py: builds the extension and compiles the .metallib.
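As a reference for what "active-row dW/db and sparse dX" could mean, here is a minimal pure-Python sketch for a linear layer y = x @ w.T + b. It is an interpretation, not the Metal kernels: the function names are hypothetical, matrices are plain lists of rows, and it assumes "active rows" are a subset of the output rows of w, with sparse dX using only those rows.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]

def active_row_backward(x, w, dy, active, sparse_dx):
    """Backward of y = x @ w.T + b restricted to the output rows in `active`.

    x: (batch, in), w: (out, in), dy: (batch, out), active: list of row indices.
    Returns (dw_active, db_active, dx). dx uses only the active rows when
    sparse_dx is True, and all rows of w otherwise.
    """
    dy_active = [[row[j] for j in active] for row in dy]   # (batch, k)
    dw_active = matmul(transpose(dy_active), x)            # (k, in)
    db_active = [sum(col) for col in zip(*dy_active)]      # (k,)
    if sparse_dx:
        dx = matmul(dy_active, [w[j] for j in active])     # (batch, in), approximate
    else:
        dx = matmul(dy, w)                                 # (batch, in), exact
    return dw_active, db_active, dx

x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
dy = [[0.5, -1.0, 2.0], [1.0, 0.0, -0.5]]
dw, db, dx = active_row_backward(x, w, dy, active=[0, 2], sparse_dx=True)
print(dw, db, dx)
```

Under this reading, the two values passed to `--backward_modes` differ only in the `sparse_dx` branch: both compute dW and db for active rows only, while the full-dX mode keeps the input gradient exact.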
Install the Metal extension
From this directory on macOS with PyTorch MPS available:
python3 -m pip install -e .
This uses xcrun metal and xcrun metallib, so Xcode command-line tools are required.
Correctness/sanity run with PyTorch fallback
python3 sparse_transformer_v10.py \
--device mps \
--steps 2000 \
--active_fractions 0.05 0.02 \
--warmup_steps_list 5 \
--policies predicted_magnitude random \
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
--audit_every 0 \
--kernel_backend torch \
--benchmark_sync
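The `--policies` and `--active_fractions` flags suggest that each layer picks a fraction of rows to update. The sketch below shows one plausible form of that choice. It is an assumption throughout: `select_active_rows` is a hypothetical helper, and `scores` stands in for whatever per-row quantity the harness predicts.

```python
import random

def select_active_rows(scores, active_fraction, policy, rng=None):
    """Return sorted indices of the rows to update under the given policy."""
    n_rows = len(scores)
    # round() rather than ceil() so float noise (e.g. 0.07 * 100) cannot add a row
    k = max(1, round(active_fraction * n_rows))
    if policy == "predicted_magnitude":
        # keep the k rows with the largest absolute score
        ranked = sorted(range(n_rows), key=lambda i: abs(scores[i]), reverse=True)
        return sorted(ranked[:k])
    if policy == "random":
        # baseline: k distinct rows chosen uniformly at random
        rng = rng or random.Random()
        return sorted(rng.sample(range(n_rows), k))
    raise ValueError(f"unknown policy: {policy}")

scores = [0.1, -3.0, 0.5, 2.0, -0.2, 0.0, 1.0, -0.4, 0.3, 0.6]
top_rows = select_active_rows(scores, 0.2, "predicted_magnitude")
random_rows = select_active_rows(scores, 0.2, "random", rng=random.Random(0))
print(top_rows, random_rows)
```

Both policies return the same number of rows for a given fraction, so any gap in validation loss between them reflects which rows were chosen rather than how many.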
Metal benchmark run
python3 sparse_transformer_v10.py \
--device mps \
--steps 2000 \
--active_fractions 0.05 0.02 \
--warmup_steps_list 5 \
--policies predicted_magnitude random \
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
--audit_every 0 \
--kernel_backend metal \
--benchmark_sync
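Both commands pass lists to several flags. Assuming the harness sweeps the full Cartesian product of those lists (the source does not say so explicitly), the set of sparse configurations can be enumerated as follows. The dictionary keys are illustrative names, not the script's internals.

```python
from itertools import product

# Values copied from the commands above.
active_fractions = [0.05, 0.02]
warmup_steps_list = [5]
policies = ["predicted_magnitude", "random"]
backward_modes = ["sparse_dW_full_dX", "sparse_dW_sparse_dX"]

configs = [
    {"active_fraction": f, "warmup_steps": w, "policy": p, "backward_mode": m}
    for f, w, p, m in product(active_fractions, warmup_steps_list, policies, backward_modes)
]

for cfg in configs:
    print(cfg)
print(len(configs), "sparse configurations")
```

Under that assumption each command trains 2 * 1 * 2 * 2 = 8 sparse variants of 2000 steps each, in addition to any dense baseline.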
Empirical pass/fail criteria
A genuine optimization would show:
- predicted_magnitude validation loss close to the PyTorch fallback v9/v10 result.
- predicted_magnitude much better than random at the same active fraction.
- --kernel_backend metal has lower step_ms or higher tokens_per_s than --kernel_backend torch and ideally dense baseline.
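The criteria above can be turned into a mechanical check over the reported metrics. This is a sketch under stated assumptions: `passes_criteria`, the result-dictionary keys, and both thresholds (`loss_tol`, `min_gap`) are hypothetical, since the source gives no numeric meaning for "close to" or "much better". The numbers in the usage lines are placeholders, not measured results.

```python
def passes_criteria(metal_pm, torch_pm, metal_random, loss_tol=0.05, min_gap=0.1):
    """Evaluate the three criteria at one active fraction.

    Each argument is a dict with keys val_loss, step_ms, tokens_per_s:
    metal_pm     - predicted_magnitude with --kernel_backend metal
    torch_pm     - predicted_magnitude with --kernel_backend torch
    metal_random - random policy with --kernel_backend metal
    """
    loss_matches = abs(metal_pm["val_loss"] - torch_pm["val_loss"]) <= loss_tol
    beats_random = metal_random["val_loss"] - metal_pm["val_loss"] >= min_gap
    faster = (
        metal_pm["step_ms"] < torch_pm["step_ms"]
        or metal_pm["tokens_per_s"] > torch_pm["tokens_per_s"]
    )
    return {
        "loss_matches_fallback": loss_matches,
        "beats_random": beats_random,
        "faster_than_torch": faster,
    }

# Placeholder numbers for illustration only.
metal_pm = {"val_loss": 1.52, "step_ms": 40.0, "tokens_per_s": 100_000.0}
torch_pm = {"val_loss": 1.50, "step_ms": 55.0, "tokens_per_s": 74_000.0}
metal_random = {"val_loss": 2.00, "step_ms": 40.0, "tokens_per_s": 100_000.0}
report = passes_criteria(metal_pm, torch_pm, metal_random)
print(report)
```

Returning one flag per criterion rather than a single boolean keeps a quality regression distinguishable from a missing speedup. A comparison against the dense baseline could be added as a fourth entry in the same way.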
The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.