# Collected Results

All measurements below are real numbers from actual runs; the GPU, settings, and seed are noted for each.

---

## Table 1: Your Original Results (MPS)

Config: 6 layers, chunk_size=64, B=8, T=256, 10% active, 2000 steps

| d_model | Run | Time (s) | ms/step | Val Loss |
|---------|-----|----------|---------|----------|
| 512 | dense_baseline | 74.77 | 99.70 | 5.3142 |
| 512 | sparse_full_dX | 91.04 | 121.38 | 5.4141 |
| 512 | sparse_sparse_dX | 93.33 | 124.44 | 5.5467 |
| 2048 | dense_baseline | 1035.84 | 591.91 | 6.0264 |
| 2048 | sparse_full_dX | 875.51 | 500.29 | 5.9807 |
| 2048 | sparse_sparse_dX | 847.22 | 484.13 | 6.0231 |

**Observation**: Sparse is slower at d=512 (1.22x overhead) and faster at d=2048 (1.18x speedup for full_dX, 1.22x for sparse_dX). Quality is comparable at d=2048 but worse at d=512.

---

## Table 2: Isolated Matmul Microbenchmark (T4, per single FFN layer)

Config: B=8, T=256 (M=2048), chunk_size=64, 10% active, fp32, 100 iterations

| d_model | FFN dim | Params | Fwd (ms) | dX (ms) | dW_dense (ms) | dW_sparse (ms) | Total_dense (ms) | Total_sparse_full_dX (ms) | Speedup |
|---------|---------|--------|----------|---------|---------------|----------------|-------------------|---------------------------|---------|
| 256 | 1024 | 0.3M | 0.27 | 0.21 | 0.27 | 0.26 | 0.75 | 0.74 | 1.02x |
| 384 | 1536 | 0.6M | 0.52 | 0.69 | 0.61 | 0.18 | 1.82 | 1.39 | 1.31x |
| 512 | 2048 | 1.0M | 1.00 | 1.01 | 0.97 | 0.26 | 2.99 | 2.28 | 1.31x |
| 768 | 3072 | 2.4M | 2.16 | 2.25 | 2.05 | 0.40 | 6.46 | 4.81 | 1.34x |
| 1024 | 4096 | 4.2M | 3.69 | 3.90 | 3.35 | 0.59 | 10.95 | 8.18 | 1.34x |
| 1536 | 6144 | 9.4M | 10.33 | 9.03 | 8.14 | 1.30 | 27.50 | 20.66 | 1.33x |
| 2048 | 8192 | 16.8M | 14.76 | 15.57 | 13.19 | 1.93 | 43.51 | 32.26 | 1.35x |

Amdahl ceiling (if dW were free): ~1.42–1.48x. Crossover point: d_model ≈ 384.
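
Since the total columns are just the sum of the per-phase timings, the speedup and the Amdahl ceiling follow by simple arithmetic; a quick check for the d=2048 row (numbers copied from the table above):

```python
# Derived columns of Table 2 for d_model = 2048, recomputed from the per-phase timings.
fwd, dx, dw_dense, dw_sparse = 14.76, 15.57, 13.19, 1.93  # ms per layer

total_dense = fwd + dx + dw_dense     # 43.52 ms (table: 43.51)
total_sparse = fwd + dx + dw_sparse   # 32.26 ms (table: 32.26)

speedup = total_dense / total_sparse  # ~1.35x, matching the table
ceiling = total_dense / (fwd + dx)    # ~1.43x if the dW matmul were entirely free
print(f"speedup={speedup:.2f}x  amdahl_ceiling={ceiling:.2f}x")
```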

---

## Table 3: Triton Kernel Correctness (T4)

| d_in | d_out | chunk_size | dW max_err | dBias max_err | dX max_err | Status |
|------|-------|------------|-----------|---------------|-----------|--------|
| 512 | 2048 | 64 | 0.000320 | 0.000023 | 0.000042 | ✓ |
| 1024 | 4096 | 64 | 0.000443 | 0.000021 | 0.000092 | ✓ |
| 256 | 1024 | 32 | 0.000275 | 0.000038 | 0.000019 | ✓ |
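
The reported errors are presumably maximum absolute differences between the Triton kernel's gradients and a dense reference backward; a minimal sketch of such a check (helper name and usage are illustrative, not the repo's API):

```python
import torch

def max_abs_err(candidate: torch.Tensor, reference: torch.Tensor) -> float:
    """Maximum absolute elementwise difference, as reported in the table above."""
    return (candidate - reference).abs().max().item()

# Illustrative usage: dW_triton from the Triton backward, dW_ref from a dense autograd pass.
# print(f"dW max_err = {max_abs_err(dW_triton, dW_ref):.6f}")
```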

---

## Table 4: Triton vs PyLoop vs Dense — Isolated Backward (T4)

Config: M=2048, chunk_size=64, 10% active, full_dX mode (dW sparse, dX dense), 50 iterations after warmup

| d_model | FFN dim | Active chunks | Dense (ms) | PyLoop (ms) | Triton (ms) | Triton speedup vs Dense | Triton speedup vs PyLoop |
|---------|---------|---------------|-----------|-------------|-------------|--------------|---------------|
| 256 | 1024 | 1 | 0.39 | 0.40 | 0.46 | 0.85x | 0.88x |
| 512 | 2048 | 3 | 1.96 | 1.30 | 1.16 | 1.69x | 1.12x |
| 768 | 3072 | 4 | 4.29 | 2.52 | 2.51 | 1.70x | 1.00x |
| 1024 | 4096 | 6 | 7.29 | 4.37 | 4.30 | 1.70x | 1.02x |
| 1536 | 6144 | 9 | 17.32 | 10.04 | 9.78 | 1.77x | 1.03x |
| 2048 | 8192 | 12 | 29.14 | 17.20 | 16.89 | 1.73x | 1.02x |

Triton with both dW and dX sparse:

| d_model | Dense (ms) | Triton_all (ms) | Speedup |
|---------|-----------|-----------------|---------|
| 512 | 1.96 | 0.41 | 4.83x |
| 1024 | 7.06 | 1.07 | 6.58x |
| 2048 | 29.00 | 3.71 | 7.81x |
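
For reference, the PyLoop backend measured above reduces to a short Python loop over the active chunks. Below is a minimal sketch of that backward, assuming the chunks partition the FFN hidden dimension (consistent with the active-chunk counts in the first table) and using hypothetical names rather than the repo's API:

```python
import torch

def ffn_backward_pyloop(x, d_hidden, w, active_chunks, chunk_size=64, full_dx=True):
    """Reference PyLoop-style backward for one FFN matmul (sketch, not the repo's API).

    x:             (M, d_in)     layer input
    d_hidden:      (M, d_ffn)    gradient w.r.t. the hidden activations
    w:             (d_in, d_ffn) weight
    active_chunks: indices of hidden-dimension chunks selected for this step
    """
    d_w = torch.zeros_like(w)
    # dW is computed only for the active chunks of the hidden dimension.
    for c in active_chunks:
        cols = slice(c * chunk_size, (c + 1) * chunk_size)
        d_w[:, cols] = x.t() @ d_hidden[:, cols]

    if full_dx:
        # full_dX mode: the input gradient remains a dense matmul.
        d_x = d_hidden @ w.t()
    else:
        # sparse_dX mode: only active chunks contribute to the input gradient.
        d_x = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        for c in active_chunks:
            cols = slice(c * chunk_size, (c + 1) * chunk_size)
            d_x += d_hidden[:, cols] @ w[:, cols].t()
    return d_w, d_x
```

The all-sparse variant in the second table additionally replaces the dense dX matmul with the chunked form, which is where the larger speedups come from.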

---

## Table 5: End-to-End Training (T4, 100 steps)

Config: 6 layers, 8 heads, B=8, T=256, chunk_size=64, 10% active, seed=42, AdamW lr=5e-4, full_dX mode

| d_model | Mode | ms/step | vs Dense | Val Loss |
|---------|------|---------|----------|----------|
| 512 | dense | 184.6 | 1.00x | 5.6954 |
| 512 | pyloop | 179.0 | 1.03x | 5.8683 |
| 512 | triton | 196.0 | 0.94x | 5.8683 |
| 1024 | dense | 451.5 | 1.00x | 5.5300 |
| 1024 | pyloop | 435.6 | 1.04x | 5.4803 |
| 1024 | triton | 441.0 | 1.02x | 5.4800 |

d=2048 does not fit on T4 (16GB). A10G results pending (job 69f3af45d2c8bd8662bd419d).

Note: Triton autotune overhead hurts at small scale. At d=512 with only 1 active chunk per layer, fused kernels lose to PyTorch's already-optimized single-kernel launches.

---

## Table 6: EMA Predictor Overlap (T4, 350 steps, seed=42)

Config: d=512, 6 layers, chunk_size=64, 10% active, measured every 25 steps after annealing (step ≥ 250)

| Step | Jaccard | Recall |
|------|---------|--------|
| 250 | 0.6000 | 0.7500 |
| 275 | 0.6552 | 0.7917 |
| 300 | 0.7778 | 0.8750 |
| 325 | 0.6000 | 0.7500 |

Single seed only; full 3-seed results at 2000 steps are pending from the A10G job.
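
Jaccard and recall here are simple set comparisons, presumably between the chunk set the EMA predictor selects and the set a dense oracle would select; a minimal sketch (hypothetical helper):

```python
def overlap_metrics(predicted: set, oracle: set) -> tuple:
    """Jaccard overlap and recall of the predicted active-chunk set vs. the oracle set."""
    inter = len(predicted & oracle)
    jaccard = inter / len(predicted | oracle)
    recall = inter / len(oracle)
    return jaccard, recall

# Example: 24 predicted chunks recovering 18 of the oracle's 24 -> Jaccard 0.60, recall 0.75,
# matching the step-250 row above.
print(overlap_metrics(set(range(24)), set(range(6, 30))))  # (0.6, 0.75)
```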

---

## Table 7: Chunk-Size vs Speed (T4, 50 steps, timing only)

Config: d=512, 6 layers, 10% active, seed=42. Loss is identical across chunk sizes (only 50 steps, all within warmup).

| Chunk Size | ms/step |
|------------|---------|
| 16 | 601.4 |
| 32 | 453.0 |
| 64 | 321.5 |
| 128 | 251.3 |
| 256 | 219.8 |

Larger chunks mean fewer Python loop iterations per layer and therefore less overhead (a rough count is sketched below). This is the PyLoop backend; Triton would show a different curve.
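
A back-of-the-envelope iteration count makes the trend concrete (assuming ffn_dim = 4×d_model as in Table 2; the repo's exact rounding of the 10% budget may differ):

```python
# Loop iterations per layer at d_model=512 (ffn_dim=2048) for each chunk size.
ffn_dim, active_frac = 2048, 0.10
for chunk_size in (16, 32, 64, 128, 256):
    n_chunks = ffn_dim // chunk_size
    n_active = max(1, int(n_chunks * active_frac))
    print(f"chunk_size={chunk_size:>3}: {n_chunks:>3} chunks, {n_active:>2} loop iterations/layer")
```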

---

## Pending Results (A10G jobs running)

| Job ID | Experiment | Status |
|--------|-----------|--------|
| 69f38371d70108f37ace1cae | Full 7-experiment suite (2000 steps, 3 seeds, all ablations) | Running |
| 69f395b3d70108f37ace1cee | Model-size scaling study (d=256→2048, 2000 steps, 2 seeds) | Running |
| 69f3af45d2c8bd8662bd419d | E2E training with Triton (d=512,1024,2048, 500 steps) | Running |

These will provide:
- Full baseline table: all 8 baselines with 3 seeds at 2000 steps (Dense, Random, EMA, EMA+sparse_dX, RigL, SET, TopK-SGD, Oracle)
- Compute-matched dense (same FLOPs) vs sparse
- Chunk-size ablation with loss numbers at 2000 steps
- Epsilon-greedy exploration sweep
- Attention sparsification results
- Sparsity level sweep (5%–100%)
- d=2048 end-to-end training with Triton