Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
# Sparse Transformer v10: Metal-backed active-row Linear backward benchmark

This bundle is the first empirical optimization test rather than only a correctness prototype.
It compares dense Transformer training against sparse active-row backward variants and reports
validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
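As a rough illustration of how the two wall-clock metrics relate, the sketch below averages step time over a run and derives throughput from it. The helper name `time_steps`, its parameters, and the optional `sync_fn` hook are hypothetical and are not taken from `sparse_transformer_v10.py`; the harness may measure things differently.

```python
import time


def time_steps(step_fn, n_steps, tokens_per_step, sync_fn=None,
               clock=time.perf_counter):
    """Return (step_ms, tokens_per_s) averaged over n_steps calls of step_fn.

    sync_fn is an optional callable that blocks until queued device work
    has finished. Without it, an asynchronous backend can make steps look
    faster than they are because only the kernel launch is timed.
    """
    if n_steps <= 0:
        raise ValueError("n_steps must be positive")
    if sync_fn is not None:
        sync_fn()  # drain pending work before starting the clock
    start = clock()
    for _ in range(n_steps):
        step_fn()
    if sync_fn is not None:
        sync_fn()  # wait for the last step to actually complete
    elapsed = clock() - start
    step_ms = 1000.0 * elapsed / n_steps
    # Guard against a zero interval from a very coarse clock.
    tokens_per_s = (n_steps * tokens_per_step / elapsed
                    if elapsed > 0 else float("inf"))
    return step_ms, tokens_per_s
```

On an MPS device, passing `sync_fn=torch.mps.synchronize` is one way to get comparable numbers across backends; whether that matches what `--benchmark_sync` does is an assumption based on the flag name.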
## Files

- `sparse_transformer_v10.py`: training + benchmark harness.
- `sparse_linear.metal`: Metal kernels for active-row `dW/db` and sparse `dX`.
- `sparse_linear_ops.mm`: PyTorch/MPS extension glue.
- `setup.py`: builds the extension and compiles the `.metallib`.
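To make the terms "active-row `dW/db`" and "sparse `dX`" concrete, here is a small pure-Python reference for a Linear layer `y = x @ W.T + b`. It reflects one plausible reading of those terms: gradients are computed only for a chosen subset of output rows of `W`, and `dX` is optionally restricted to the same rows. The function name and signature are hypothetical, and the Metal kernels may define the operation differently.

```python
def active_row_backward(x, w, dy, active, sparse_dx=True):
    """Reference backward for y = x @ W.T + b restricted to active rows.

    x:  N x in   inputs (list of lists)
    w:  out x in weights
    dy: N x out  upstream gradient
    active: indices of the output rows whose gradients are computed

    Returns (dw, db, dx), where dw and db are dicts keyed by active row.
    With sparse_dx=False, dx uses every output row (the dense dX).
    """
    n = len(x)
    n_in = len(w[0])
    n_out = len(w)
    # dW[r, i] = sum_n dY[n, r] * X[n, i], only for active rows r.
    dw = {r: [sum(dy[k][r] * x[k][i] for k in range(n)) for i in range(n_in)]
          for r in active}
    # db[r] = sum_n dY[n, r], only for active rows r.
    db = {r: sum(dy[k][r] for k in range(n)) for r in active}
    # dX[n, i] = sum_r dY[n, r] * W[r, i], over active rows or all rows.
    rows = list(active) if sparse_dx else range(n_out)
    dx = [[sum(dy[k][r] * w[r][i] for r in rows) for i in range(n_in)]
          for k in range(n)]
    return dw, db, dx
```

Under this reading, the two values passed to `--backward_modes` would correspond to `sparse_dx=False` (`sparse_dW_full_dX`) and `sparse_dx=True` (`sparse_dW_sparse_dX`). A reference like this is mainly useful for checking a fast kernel against known-good values on tiny inputs.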
## Install the Metal extension

From this directory on macOS with PyTorch MPS available:

```bash
python3 -m pip install -e .
```

This uses `xcrun metal` and `xcrun metallib`, so Xcode command-line tools are required.
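For orientation, the usual two-step Metal toolchain flow compiles the `.metal` source to an intermediate `.air` file and then links it into a `.metallib`. The sketch below builds those two commands and runs them. It is an assumption about what the build step might look like, not a copy of this bundle's `setup.py`, and the helper names are hypothetical.

```python
import shutil
import subprocess
from pathlib import Path


def metallib_commands(source):
    """Return the two xcrun invocations that turn a .metal file into a .metallib."""
    src = Path(source)
    air = src.with_suffix(".air")       # intermediate representation
    lib = src.with_suffix(".metallib")  # library loaded at runtime
    return [
        ["xcrun", "-sdk", "macosx", "metal", "-c", str(src), "-o", str(air)],
        ["xcrun", "-sdk", "macosx", "metallib", str(air), "-o", str(lib)],
    ]


def build_metallib(source):
    """Run both steps, failing early if xcrun is not on PATH."""
    if shutil.which("xcrun") is None:
        raise RuntimeError("xcrun not found; install the Xcode tools first")
    for cmd in metallib_commands(source):
        subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
```

Running `build_metallib("sparse_linear.metal")` by hand can help separate shader compile errors from Python packaging errors when `pip install -e .` fails.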
## Correctness/sanity run with PyTorch fallback

```bash
python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend torch \
  --benchmark_sync
```
## Metal benchmark run

```bash
python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend metal \
  --benchmark_sync
```
## Empirical pass/fail criteria

A genuine optimization would show:

1. `predicted_magnitude` validation loss close to the PyTorch fallback v9/v10 result.
2. `predicted_magnitude` much better than `random` at the same active fraction.
3. `--kernel_backend metal` with lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch` and, ideally, than the dense baseline.
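The three criteria above can be turned into a mechanical check once the metrics for each run are collected. The sketch below is one way to do that. The `Run` container, the function name, and especially the thresholds `loss_tol` and `min_gap` are hypothetical placeholders, since the README does not quantify "close" or "much better".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Run:
    val_loss: float
    step_ms: float
    tokens_per_s: float


def check_criteria(metal_pm: Run, torch_pm: Run, metal_random: Run,
                   dense: Optional[Run] = None,
                   loss_tol: float = 0.05, min_gap: float = 0.1):
    """Evaluate the pass/fail criteria for one active fraction.

    metal_pm / torch_pm: predicted_magnitude runs on each kernel backend.
    metal_random: random-policy run at the same active fraction.
    loss_tol, min_gap: assumed thresholds; tune them to the task.
    """
    def faster(a: Run, b: Run) -> bool:
        # Criterion 3 accepts either metric as evidence of a speedup.
        return a.step_ms < b.step_ms or a.tokens_per_s > b.tokens_per_s

    return {
        "loss_close_to_fallback":
            abs(metal_pm.val_loss - torch_pm.val_loss) <= loss_tol,
        "beats_random":
            metal_random.val_loss - metal_pm.val_loss >= min_gap,
        "faster_than_torch": faster(metal_pm, torch_pm),
        # None means no dense baseline was supplied.
        "faster_than_dense": faster(metal_pm, dense) if dense else None,
    }
```

Returning a dictionary rather than a single boolean keeps the failing criterion visible, which matters here because a run can be fast but wrong, or correct but slow.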
The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.