# Sparse Transformer v10: Metal-backed active-row Linear backward benchmark
This bundle is the first empirical optimization test, not just a correctness prototype.
It compares dense Transformer training against sparse active-row backward variants and reports
validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
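Before looking at timings, it helps to estimate how much Linear-layer backward matmul work each backward mode can skip. This is a back-of-the-envelope sketch with several assumptions: `sparse_dW_*` computes `dW` only for the active output rows, `*_sparse_dX` computes `dX` only from those rows, and `*_full_dX` computes the full `dX`. It ignores `db`, the forward pass, row selection, and memory traffic. The function name and the `"dense"` label are illustrative and not part of the harness.

```python
def backward_flop_ratio(active_fraction: float, mode: str) -> float:
    """Idealized share of dense Linear-backward matmul FLOPs that a mode performs.

    For y = x @ W.T the dense backward runs two GEMMs of equal cost,
    dW = dy.T @ x and dX = dy @ W, so each is counted as 1 unit here.
    """
    if not 0.0 < active_fraction <= 1.0:
        raise ValueError("active_fraction must be in (0, 1]")
    costs = {
        "dense": 1.0 + 1.0,                            # full dW + full dX (illustrative baseline label)
        "sparse_dW_full_dX": active_fraction + 1.0,    # active rows of dW, full dX
        "sparse_dW_sparse_dX": 2.0 * active_fraction,  # active rows of dW, dX from active rows only
    }
    if mode not in costs:
        raise ValueError(f"unknown mode: {mode}")
    return costs[mode] / costs["dense"]


for f in (0.05, 0.02):
    for mode in ("sparse_dW_full_dX", "sparse_dW_sparse_dX"):
        print(f, mode, round(backward_flop_ratio(f, mode), 3))
```

Under these assumptions, `sparse_dW_full_dX` can at best roughly halve Linear backward work, because the full `dX` GEMM remains. Only `sparse_dW_sparse_dX` scales with the active fraction.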
## Files
- `sparse_transformer_v10.py` — training + benchmark harness.
- `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
- `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
- `setup.py` — builds the extension and compiles the `.metallib`.
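For reference, the math that the active-row `dW/db` and sparse `dX` kernels compute can be sketched without any framework. This minimal sketch assumes that "active rows" means the output features (rows of `W`) selected by the policy, and that inactive rows receive a zero gradient. `active_row_linear_backward` is a hypothetical name, and the real Metal kernels may skip writing inactive rows rather than zeroing them.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def active_row_linear_backward(
    x: Matrix,                   # (N, d_in) layer input
    W: Matrix,                   # (d_out, d_in) weight, as in y = x @ W.T + b
    dy: Matrix,                  # (N, d_out) upstream gradient
    active_rows: Sequence[int],  # output rows (rows of W) chosen by the policy
    sparse_dx: bool,
) -> Tuple[Matrix, List[float], Matrix]:
    n, d_in, d_out = len(x), len(x[0]), len(W)
    dW = [[0.0] * d_in for _ in range(d_out)]  # inactive rows stay zero in this sketch
    db = [0.0] * d_out
    # dW[r] = sum_i dy[i][r] * x[i] and db[r] = sum_i dy[i][r], for active r only.
    for r in active_rows:
        for i in range(n):
            g = dy[i][r]
            db[r] += g
            for j in range(d_in):
                dW[r][j] += g * x[i][j]
    # dX = dy @ W, optionally restricted to the active rows of W.
    rows = active_rows if sparse_dx else range(d_out)
    dX = [[0.0] * d_in for _ in range(n)]
    for i in range(n):
        for r in rows:
            g = dy[i][r]
            for j in range(d_in):
                dX[i][j] += g * W[r][j]
    return dW, db, dX
```

Under the same assumption, `sparse_dx=False` corresponds to the `sparse_dW_full_dX` backward mode and `sparse_dx=True` corresponds to `sparse_dW_sparse_dX`.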
## Install the Metal extension
From this directory, on macOS with PyTorch's MPS backend available, run:
```bash
python3 -m pip install -e .
```
This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
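The Metal build has two steps: compile `.metal` to `.air`, then link the `.air` file into a `.metallib`. The sketch below shows how those steps could be invoked from Python. It is an assumption about what `setup.py` does, not a copy of it. The `-sdk macosx` flag, the output file names, and both helper names are illustrative.

```python
import subprocess
from pathlib import Path
from typing import List


def metal_build_commands(metal_src: str, out_dir: str = ".") -> List[List[str]]:
    """Return the xcrun argv lists for compiling and linking one .metal file (sketch)."""
    src = Path(metal_src)
    air = Path(out_dir) / (src.stem + ".air")
    lib = Path(out_dir) / (src.stem + ".metallib")
    return [
        # Step 1: compile Metal Shading Language source to intermediate .air
        ["xcrun", "-sdk", "macosx", "metal", "-c", str(src), "-o", str(air)],
        # Step 2: link the .air file into a loadable .metallib
        ["xcrun", "-sdk", "macosx", "metallib", str(air), "-o", str(lib)],
    ]


def compile_metallib(metal_src: str, out_dir: str = ".") -> None:
    """Run both steps; check=True raises if either command exits non-zero."""
    for cmd in metal_build_commands(metal_src, out_dir):
        subprocess.run(cmd, check=True)
```

On a Mac with the command-line tools installed, `compile_metallib("sparse_linear.metal")` would run both steps and raise if either one fails.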
## Correctness/sanity run with PyTorch fallback
```bash
python3 sparse_transformer_v10.py \
--device mps \
--steps 2000 \
--active_fractions 0.05 0.02 \
--warmup_steps_list 5 \
--policies predicted_magnitude random \
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
--audit_every 0 \
--kernel_backend torch \
--benchmark_sync
```
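Several flags take lists. If the harness runs every combination of the list-valued flags (the Cartesian product), the command above expands into 2 × 1 × 2 × 2 = 8 sparse configurations, plus whatever dense baseline the script adds. The parser below is a hypothetical reconstruction built from the flags shown here. It is not the script's actual argument definitions, types, or defaults.

```python
import argparse
import itertools
import shlex


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction: flag names come from the README, types are guesses.
    p = argparse.ArgumentParser(description="sparse_transformer_v10 sweep (sketch)")
    p.add_argument("--device")
    p.add_argument("--steps", type=int)
    p.add_argument("--active_fractions", type=float, nargs="+")
    p.add_argument("--warmup_steps_list", type=int, nargs="+")
    p.add_argument("--policies", nargs="+")
    p.add_argument("--backward_modes", nargs="+")
    p.add_argument("--audit_every", type=int)
    p.add_argument("--kernel_backend")
    p.add_argument("--benchmark_sync", action="store_true")
    return p


def expand_sweep(args: argparse.Namespace) -> list:
    """Every (active_fraction, warmup_steps, policy, backward_mode) combination."""
    return list(itertools.product(
        args.active_fractions, args.warmup_steps_list, args.policies, args.backward_modes
    ))


argv = shlex.split(
    "--device mps --steps 2000 --active_fractions 0.05 0.02 --warmup_steps_list 5 "
    "--policies predicted_magnitude random "
    "--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX "
    "--audit_every 0 --kernel_backend torch --benchmark_sync"
)
args = build_parser().parse_args(argv)
for cfg in expand_sweep(args):
    print(cfg)
```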
## Metal benchmark run
```bash
python3 sparse_transformer_v10.py \
--device mps \
--steps 2000 \
--active_fractions 0.05 0.02 \
--warmup_steps_list 5 \
--policies predicted_magnitude random \
--backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
--audit_every 0 \
--kernel_backend metal \
--benchmark_sync
```
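PyTorch dispatches MPS work asynchronously, so timing without synchronization mostly measures kernel submission rather than execution. `--benchmark_sync` presumably guards against this. The sketch below shows how `step_ms` and `tokens_per_s` could be measured with a sync after every step. The function name and the sync placement are assumptions. On MPS, the `sync` argument would be `torch.mps.synchronize`.

```python
import time
from typing import Callable, Optional, Tuple


def time_training_steps(
    step_fn: Callable[[], None],
    n_steps: int,
    tokens_per_step: int,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[float, float]:
    """Return (step_ms, tokens_per_s) averaged over n_steps (sketch)."""
    if n_steps <= 0:
        raise ValueError("n_steps must be positive")
    if sync is not None:
        sync()  # drain work queued before the timed region
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
        if sync is not None:
            sync()  # wait for this step's kernels before starting the next one
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero interval
    step_ms = 1000.0 * elapsed / n_steps
    tokens_per_s = tokens_per_step * n_steps / elapsed
    return step_ms, tokens_per_s


# Hypothetical usage on MPS (train_step, batch, seq_len are placeholders):
# time_training_steps(train_step, 50, batch * seq_len, sync=torch.mps.synchronize)
```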
## Empirical pass/fail criteria
A genuine optimization would show:
1. `predicted_magnitude` validation loss close to the PyTorch-fallback (v9/v10) result.
2. `predicted_magnitude` validation loss much lower than `random` at the same active fraction.
3. `--kernel_backend metal` achieving lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch`, and ideally than the dense baseline.
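The three criteria can be encoded as a small check over per-run results. This is a sketch. `RunResult`, `check_criteria`, and both tolerance values are hypothetical choices, because this README does not give numeric thresholds for "close" or "much lower". With a fixed token count per step, lower `step_ms` is equivalent to higher `tokens_per_s`, so only `step_ms` is compared.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class RunResult:
    val_loss: float
    step_ms: float


def check_criteria(
    torch_pm: RunResult,             # predicted_magnitude, --kernel_backend torch
    metal_pm: RunResult,             # predicted_magnitude, --kernel_backend metal
    metal_random: RunResult,         # random policy, same active fraction
    dense: RunResult,                # dense baseline
    rel_loss_tol: float = 0.02,      # hypothetical: "close" = within 2% of the fallback loss
    min_gap_vs_random: float = 0.1,  # hypothetical: "much lower" = at least 0.1 lower loss
) -> Dict[str, bool]:
    return {
        # Criterion 1: Metal kernels should not change the training outcome.
        "loss_matches_fallback": abs(metal_pm.val_loss - torch_pm.val_loss)
        <= rel_loss_tol * torch_pm.val_loss,
        # Criterion 2: the selection policy must beat random row selection.
        "beats_random": metal_random.val_loss - metal_pm.val_loss >= min_gap_vs_random,
        # Criterion 3: Metal must be faster than the PyTorch fallback...
        "metal_faster_than_torch": metal_pm.step_ms < torch_pm.step_ms,
        # ...and ideally faster than dense training.
        "faster_than_dense": metal_pm.step_ms < dense.step_ms,
    }
```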
The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.