# Sparse Transformer v10: Metal-backed active-row Linear backward benchmark

This bundle is the first empirical optimization test rather than only a correctness prototype.
It compares dense Transformer training against sparse active-row backward variants and reports
validation loss plus wall-clock metrics (`step_ms`, `tokens_per_s`).
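As an orientation aid, here is a minimal pure-Python reference for what an active-row `Linear` backward could compute for `Y = X W^T + b`. It assumes that `W` uses the `torch.nn.Linear` layout `(out_features, in_features)` and that the "active rows" are rows of `W` (output features). The helper name `active_row_backward` is hypothetical and is not taken from `sparse_transformer_v10.py`. Under this assumption, the `sparse_dX` flag mirrors the difference between the `sparse_dW_full_dX` and `sparse_dW_sparse_dX` backward modes.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def active_row_backward(
    x: Matrix,          # (tokens, in_features) layer input
    w: Matrix,          # (out_features, in_features) weight, torch.nn.Linear layout
    dy: Matrix,         # (tokens, out_features) upstream gradient dL/dY
    active: List[int],  # output rows (assumed meaning of "active rows") that get dW/db
    sparse_dx: bool,    # True: dX uses active rows only; False: dX uses every row
) -> Tuple[Matrix, List[float], Matrix]:
    tokens, in_f = len(x), len(x[0])
    out_f = len(w)

    # dW[r, :] = sum_t dY[t, r] * X[t, :] and db[r] = sum_t dY[t, r],
    # computed only for active rows; inactive rows stay zero.
    dw = [[0.0] * in_f for _ in range(out_f)]
    db = [0.0] * out_f
    for r in active:
        for t in range(tokens):
            g = dy[t][r]
            db[r] += g
            for c in range(in_f):
                dw[r][c] += g * x[t][c]

    # dX = dY @ W, summed over all rows (exact) or only the active rows (approximate).
    rows = active if sparse_dx else range(out_f)
    dx = [[0.0] * in_f for _ in range(tokens)]
    for t in range(tokens):
        for r in rows:
            g = dy[t][r]
            for c in range(in_f):
                dx[t][c] += g * w[r][c]
    return dw, db, dx
```

Under this reading, `sparse_dW_full_dX` still propagates the exact input gradient to earlier layers. `sparse_dW_sparse_dX` also drops the contribution of inactive rows to `dX`, which saves more work but makes upstream gradients approximate.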

## Files

- `sparse_transformer_v10.py` — training + benchmark harness.
- `sparse_linear.metal` — Metal kernels for active-row `dW/db` and sparse `dX`.
- `sparse_linear_ops.mm` — PyTorch/MPS extension glue.
- `setup.py` — builds the extension and compiles the `.metallib`.

## Install the Metal extension

From this directory on macOS with PyTorch MPS available:

```bash
python3 -m pip install -e .
```

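This uses `xcrun metal` and `xcrun metallib`, so the Metal compiler toolchain is required. It ships with the full Xcode app rather than the standalone Command Line Tools (recent Xcode releases may install it as a separate component), and `xcode-select` must point at that Xcode.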
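The `.metal` to `.metallib` step is usually a two-stage compile: source to `.air` with `metal`, then `.air` to a library with `metallib`. The sketch below shows that pipeline plus a toolchain check. The exact flags `setup.py` passes are not shown in this README, so the helper names (`metallib_commands`, `metal_toolchain_available`, `build_metallib`) and the output file names are hypothetical.

```python
import shutil
import subprocess
from typing import List


def metallib_commands(source: str, air: str, lib: str) -> List[List[str]]:
    """Return the two xcrun invocations: .metal -> .air -> .metallib."""
    return [
        ["xcrun", "-sdk", "macosx", "metal", "-c", source, "-o", air],
        ["xcrun", "-sdk", "macosx", "metallib", air, "-o", lib],
    ]


def metal_toolchain_available() -> bool:
    """True if xcrun exists and can locate the `metal` compiler."""
    if shutil.which("xcrun") is None:
        return False  # not macOS, or no developer tools installed
    result = subprocess.run(
        ["xcrun", "-sdk", "macosx", "--find", "metal"],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0


def build_metallib(source: str, air: str, lib: str) -> None:
    """Compile `source` into `lib`; raises CalledProcessError if a step fails."""
    if not metal_toolchain_available():
        raise RuntimeError("`metal` not found; install Xcode and select it with xcode-select")
    for cmd in metallib_commands(source, air, lib):
        subprocess.run(cmd, check=True)
```

For example, `build_metallib("sparse_linear.metal", "sparse_linear.air", "sparse_linear.metallib")` would compile this bundle's kernel source on a machine where `metal_toolchain_available()` returns `True`.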

## Correctness/sanity run with PyTorch fallback

```bash
python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend torch \
  --benchmark_sync
```
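The `--active_fractions` and `--policies` flags control how many rows are active and how they are chosen. The following sketch shows one plausible selection rule: `predicted_magnitude` keeps the top-k rows by a per-row score, and `random` draws the same k uniformly as a control. How the harness actually produces the predicted scores is not described here, so `scores` is simply an input, and both function names are hypothetical.

```python
import random
from typing import List, Sequence


def num_active_rows(out_features: int, active_fraction: float) -> int:
    # At least one row, never more rows than the layer has.
    return min(out_features, max(1, round(active_fraction * out_features)))


def select_active_rows(
    scores: Sequence[float],   # hypothetical per-row predicted gradient magnitudes
    active_fraction: float,    # e.g. 0.05 or 0.02, as in --active_fractions
    policy: str,               # "predicted_magnitude" or "random"
    rng: random.Random,
) -> List[int]:
    k = num_active_rows(len(scores), active_fraction)
    if policy == "predicted_magnitude":
        # Keep the k rows with the largest absolute predicted score.
        order = sorted(range(len(scores)), key=lambda r: abs(scores[r]), reverse=True)
        return sorted(order[:k])
    if policy == "random":
        # Control: same budget k, rows drawn uniformly without replacement.
        return sorted(rng.sample(range(len(scores)), k))
    raise ValueError(f"unknown policy: {policy}")
```

Giving both policies the same `k` is what makes the `predicted_magnitude` versus `random` comparison at a fixed active fraction meaningful.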

## Metal benchmark run

```bash
python3 sparse_transformer_v10.py \
  --device mps \
  --steps 2000 \
  --active_fractions 0.05 0.02 \
  --warmup_steps_list 5 \
  --policies predicted_magnitude random \
  --backward_modes sparse_dW_full_dX sparse_dW_sparse_dX \
  --audit_every 0 \
  --kernel_backend metal \
  --benchmark_sync
```
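MPS executes work asynchronously, so a wall-clock timer read without synchronizing can measure only kernel launch time. Presumably `--benchmark_sync` makes each timed step wait for the GPU; the harness's own timing code is not shown here. This sketch shows how `step_ms` and `tokens_per_s` relate under that assumption. On MPS, `sync_fn` would be `torch.mps.synchronize`, and `tokens_per_step` is assumed to be batch size times sequence length. The function name `timed_step` is hypothetical.

```python
import time
from typing import Callable, Dict


def timed_step(
    step_fn: Callable[[], None],   # one forward + backward + optimizer step
    sync_fn: Callable[[], None],   # e.g. torch.mps.synchronize on MPS
    tokens_per_step: int,          # assumed: batch_size * seq_len
) -> Dict[str, float]:
    sync_fn()  # drain work queued by earlier steps so it is not billed to this one
    start = time.perf_counter()
    step_fn()
    sync_fn()  # wait until this step's GPU kernels have actually finished
    elapsed_s = max(time.perf_counter() - start, 1e-9)  # guard against division by zero
    return {
        "step_ms": elapsed_s * 1e3,
        "tokens_per_s": tokens_per_step / elapsed_s,
    }
```

With a fixed number of tokens per step, lower `step_ms` and higher `tokens_per_s` are the same measurement expressed two ways.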

## Empirical pass/fail criteria

A genuine optimization would show:

1. `predicted_magnitude` validation loss with `--kernel_backend metal` close to the PyTorch-fallback (v9/v10) result.
2. `predicted_magnitude` much better than `random` at the same active fraction.
3. `--kernel_backend metal` has lower `step_ms` or higher `tokens_per_s` than `--kernel_backend torch`, and ideally than the dense baseline.
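The three criteria can be written down as a mechanical check. This is a minimal sketch: the README does not quantify "close" or "much better", so the tolerances `loss_rel_tol` and `random_gap` are hypothetical placeholders, and `RunResult` / `check_criteria` are illustrative names rather than part of the harness.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class RunResult:
    val_loss: float
    tokens_per_s: float


def check_criteria(
    torch_pm: RunResult,               # predicted_magnitude, --kernel_backend torch
    metal_pm: RunResult,               # predicted_magnitude, --kernel_backend metal
    metal_random: RunResult,           # random policy, same active fraction, metal
    dense: Optional[RunResult] = None,
    loss_rel_tol: float = 0.02,        # hypothetical threshold for "close"
    random_gap: float = 0.05,          # hypothetical threshold for "much better"
) -> Dict[str, bool]:
    return {
        # 1. Metal kernels should not change the optimization result.
        "loss_matches_fallback":
            abs(metal_pm.val_loss - torch_pm.val_loss) <= loss_rel_tol * torch_pm.val_loss,
        # 2. The selection policy should beat a random row choice.
        "beats_random":
            metal_random.val_loss - metal_pm.val_loss >= random_gap * metal_pm.val_loss,
        # 3. The Metal path should be faster than the torch fallback (and ideally dense).
        "faster_than_torch": metal_pm.tokens_per_s > torch_pm.tokens_per_s,
        "faster_than_dense": dense is not None and metal_pm.tokens_per_s > dense.tokens_per_s,
    }
```

A run counts as a genuine optimization only if the first three entries are `True`; `faster_than_dense` is the stretch goal.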

The first Metal kernels are intentionally simple and fp32-only. They are meant to prove or disprove the acceleration path before investing in tiled half-precision kernels.