
Sparse Transformer v10: Metal-backed active-row Linear backward benchmark

This bundle is the first empirical optimization test, not just a correctness prototype. It compares dense Transformer training against sparse active-row backward variants and reports validation loss alongside two wall-clock metrics: step_ms and tokens_per_s.
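As a rough illustration of how the two wall-clock metrics relate, here is a minimal timing sketch. It is not the harness's actual measurement code: `time_steps`, `step_fn`, `tokens_per_step`, and `sync_fn` are hypothetical names. The optional `sync_fn` reflects one plausible reading of `--benchmark_sync`, namely waiting for queued device work (for example with `torch.mps.synchronize`) before reading the clock.

```python
import time
from typing import Callable, Optional


def time_steps(step_fn: Callable[[], None],
               n_steps: int,
               tokens_per_step: int,
               sync_fn: Optional[Callable[[], None]] = None) -> dict:
    """Return mean step_ms and tokens_per_s over n_steps calls to step_fn.

    All names here are illustrative assumptions, not the harness's API.
    """
    if sync_fn is not None:
        sync_fn()  # drain previously queued work before starting the clock
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    if sync_fn is not None:
        sync_fn()  # wait for the device to finish before stopping the clock
    elapsed_s = time.perf_counter() - start
    return {
        "step_ms": 1000.0 * elapsed_s / n_steps,
        "tokens_per_s": n_steps * tokens_per_step / elapsed_s,
    }
```

Without the synchronization calls, asynchronous backends can return from a step before the kernels have finished, which would make step_ms look smaller than it really is.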

Files

  • sparse_transformer_v10.py — training + benchmark harness.
  • sparse_linear.metal — Metal kernels for active-row dW/db and sparse dX.
  • sparse_linear_ops.mm — PyTorch/MPS extension glue.
  • setup.py — builds the extension and compiles the .metallib.
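For orientation, the following NumPy sketch shows what an active-row Linear backward could compute. It is a reference under stated assumptions, not a transcription of the Metal kernels: it assumes the layer is y = x @ W.T + b, that "active rows" means a subset of output rows of W, and that the `sparse_dX` variant builds dX from those rows only. The function name and the `sparse_dx` flag are hypothetical.

```python
import numpy as np


def active_row_backward(x, w, grad_y, active, sparse_dx=False):
    """Gradients of y = x @ w.T + b restricted to the `active` output rows.

    x: (n, d_in), w: (d_out, d_in), grad_y: (n, d_out),
    active: 1-D array of unique row indices into w.
    Assumed semantics only; the real kernels may differ.
    """
    g = grad_y[:, active]              # (n, k) upstream gradient for active rows
    grad_w = np.zeros_like(w)
    grad_w[active] = g.T @ x           # only k of d_out rows are computed
    grad_b = np.zeros(w.shape[0], dtype=w.dtype)
    grad_b[active] = g.sum(axis=0)
    if sparse_dx:
        grad_x = g @ w[active]         # approximate dX from active rows only
    else:
        grad_x = grad_y @ w            # exact dense dX
    return grad_w, grad_b, grad_x
```

Under this reading, `sparse_dx=False` would correspond to the `sparse_dW_full_dX` mode and `sparse_dx=True` to `sparse_dW_sparse_dX`. With k active rows out of d_out, the dW matmul shrinks by roughly a factor of k / d_out, which is where any speedup would have to come from.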

Install the Metal extension

From this directory on macOS with PyTorch MPS available:

python3 -m pip install -e .

This uses xcrun metal and xcrun metallib, so Xcode command-line tools are required.

Correctness/sanity run with PyTorch fallback

python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend torch \
  --benchmark_sync

Metal benchmark run

python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend metal \
  --benchmark_sync

Empirical pass/fail criteria

A genuine optimization would show:

  1. predicted_magnitude validation loss close to the PyTorch fallback v9/v10 result.
  2. predicted_magnitude much better than random at the same active fraction.
  3. --kernel_backend metal has lower step_ms or higher tokens_per_s than --kernel_backend torch and, ideally, than the dense baseline.
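The three criteria above can be expressed as a small check over collected results. This is only a sketch: `RunResult`, `passes_criteria`, and the default thresholds `loss_tol` and `min_gap` are hypothetical placeholders, since the README does not quantify "close" or "much better".

```python
from dataclasses import dataclass


@dataclass
class RunResult:
    val_loss: float   # final validation loss of the run
    step_ms: float    # mean wall-clock time per training step


def passes_criteria(metal_pm: RunResult,
                    torch_pm: RunResult,
                    metal_random: RunResult,
                    loss_tol: float = 0.05,
                    min_gap: float = 0.1) -> dict:
    """Evaluate the three pass/fail criteria at one active fraction.

    metal_pm / torch_pm: predicted_magnitude runs on each kernel backend.
    metal_random: random-policy run at the same active fraction.
    loss_tol and min_gap are arbitrary placeholder thresholds.
    """
    return {
        # 1. Metal loss stays close to the PyTorch fallback loss.
        "loss_matches_fallback":
            abs(metal_pm.val_loss - torch_pm.val_loss) <= loss_tol,
        # 2. predicted_magnitude clearly beats random selection.
        "beats_random":
            metal_random.val_loss - metal_pm.val_loss >= min_gap,
        # 3. Metal backend is faster per step than the torch backend.
        "metal_faster": metal_pm.step_ms < torch_pm.step_ms,
    }
```

Returning the three booleans separately, instead of a single pass/fail flag, makes it easier to see whether a failure comes from selection quality or from kernel speed.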

The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.