This view is limited to 50 files because it contains too many changes. See the raw diff here.
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .gitignore +0 -21
  3. BENCHMARK_REPORT.md +0 -439
  4. HACKATHON_SUMMARY.md +0 -179
  5. README.md +12 -129
  6. backend/main.py +0 -177
  7. backend/requirements.txt +0 -3
  8. backend/terminal_manager.py +0 -342
  9. client.py +0 -77
  10. frontend/.gitignore +0 -24
  11. frontend/README.md +0 -16
  12. frontend/eslint.config.js +0 -29
  13. frontend/index.html +0 -13
  14. frontend/package-lock.json +0 -0
  15. frontend/package.json +0 -27
  16. frontend/public/vite.svg +0 -1
  17. frontend/src/App.jsx +0 -195
  18. frontend/src/api/terminal.js +0 -57
  19. frontend/src/assets/react.svg +0 -1
  20. frontend/src/components/TerminalPane.jsx +0 -147
  21. frontend/src/hooks/useTerminalSession.js +0 -166
  22. frontend/src/index.css +0 -486
  23. frontend/src/main.jsx +0 -10
  24. frontend/vite.config.js +0 -7
  25. index.html +317 -0
  26. models.py +0 -27
  27. openenv.yaml +0 -15
  28. requirements.txt +0 -8
  29. scripts/__init__.py +0 -1
  30. scripts/benchmark_qwen_05b_runtime.py +0 -158
  31. scripts/benchmark_runtime_references.py +0 -222
  32. scripts/benchmark_surrogate_vs_triton_autotune.py +0 -361
  33. scripts/build_benchmark_splits.py +0 -81
  34. scripts/build_qwen_05b_splits.py +0 -56
  35. scripts/check_torch_cuda.py +0 -18
  36. scripts/collect_measurements.py +0 -398
  37. scripts/collect_multifamily_measurements.py +0 -659
  38. scripts/collect_qwen_05b_measurements.py +0 -449
  39. scripts/demo_compare.py +0 -200
  40. scripts/eval_generalization.py +0 -147
  41. scripts/qwen_05b_spec.py +0 -115
  42. scripts/report_task_hardness.py +0 -42
  43. scripts/run_full_pipeline.sh +0 -84
  44. scripts/run_qwen_05b_pipeline.sh +0 -90
  45. scripts/run_random_baseline.py +0 -187
  46. scripts/run_surrogate_baseline.py +0 -217
  47. scripts/smoke_test_client.py +0 -28
  48. server/__init__.py +0 -1
  49. server/app.py +0 -72
  50. server/softmax_surrogate_environment.py +0 -570
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,21 +0,0 @@
1
- __pycache__/
2
- *.pyc
3
- *.pyo
4
- *.log
5
-
6
-
7
- outputs/
8
- data/softmax_measurements.csv
9
- backend/build/
10
- backend/dist/
11
- backend/venv/
12
- backend/__pycache__/
13
- backend/*.pyc
14
- backend/*.pyo
15
- backend/*.log
16
- backend/outputs/
17
- backend/data/
18
- backend/data/softmax_measurements.csv
19
- backend/backend/build/
20
- backend/backend/dist/
21
- backend/backend/__pycache__/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BENCHMARK_REPORT.md DELETED
@@ -1,439 +0,0 @@
1
- # GPU Autotuning Benchmark Report
2
-
3
- ## Project Summary
4
-
5
- This project implements a discrete-action, measured-oracle autotuning benchmark for GPU kernels on an H200-class machine. The optimizer is a surrogate-guided search policy that selects kernel configurations under a short evaluation budget. The benchmark is intentionally structured to answer two different questions:
6
-
7
- 1. How much more search-efficient is a surrogate-guided policy than a random baseline?
8
- 2. Do the kernel configurations found by the search process translate into materially better runtime than strong PyTorch baselines such as eager execution and `torch.compile`?
9
-
10
- The current benchmark answers both questions with real measurements.
11
-
12
- ## What Was Built
13
-
14
- The current system includes:
15
-
16
- - A shared multi-family measurement cache:
17
- - `softmax`
18
- - `layernorm`
19
- - `grouped_gemm`
20
- - A discrete-action environment with:
21
- - hidden oracle table
22
- - short-budget search episodes
23
- - regret-based metrics
24
- - train-task priors for cross-task generalization
25
- - Split-based evaluation:
26
- - held-out shapes
27
- - held-out family
28
- - Runtime reference benchmarking against:
29
- - eager PyTorch
30
- - `torch.compile`
31
- - live reruns of the best Triton configs found by `random`
32
- - live reruns of the best Triton configs found by `surrogate`
33
-
34
- The key design choice is that search evaluation and runtime evaluation are separated:
35
-
36
- - Search benchmark: measures how quickly each method finds good configs
37
- - Runtime benchmark: measures how fast the final chosen kernels actually run
38
-
39
- ## Benchmark Structure
40
-
41
- ### Families
42
-
43
- - Family A:
44
- - `softmax`
45
- - `layernorm`
46
- - Family B:
47
- - `grouped_gemm`
48
-
49
- ### Shapes
50
-
51
- Current task shapes use:
52
-
53
- - fixed `M = 4096`
54
- - `N ∈ {256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192}`
55
-
56
- This yields:
57
-
58
- - `9` softmax tasks
59
- - `9` layernorm tasks
60
- - `9` grouped GEMM tasks
61
- - `27` tasks total
62
-
63
- ### Search Space
64
-
65
- Current kernel config axes:
66
-
67
- - `block_size`
68
- - `num_warps`
69
- - `num_stages`
70
-
71
- Search is over discrete `config_id`s, not projected continuous actions. That makes the random baseline a true uniform baseline over legal configs.
72
-
73
- ### Evaluation Splits
74
-
75
- The benchmark currently uses:
76
-
77
- - `shape_generalization`
78
- - train on lower/intermediate shapes from each family
79
- - test on held-out larger shapes within the same families
80
- - `family_holdout`
81
- - train on `softmax` + `layernorm`
82
- - test on all `grouped_gemm` tasks
83
-
84
- At the time of the main run:
85
-
86
- - shape-generalization test tasks: `6`
87
- - family-holdout test tasks: `9`
88
- - unique held-out runtime-reference tasks: `13`
89
-
90
- ## Metrics
91
-
92
- ### Search Metrics
93
-
94
- The main optimization metrics are:
95
-
96
- - `regret@k`
97
- - best-so-far latency relative to oracle best after `k` search steps
98
- - best-so-far AUC
99
- - average regret across the whole search trajectory
100
- - lower is better
101
- - final oracle hit rate
102
- - how often the optimizer reaches the oracle-best config by the end of the episode
103
-
104
- These metrics intentionally emphasize search efficiency, not just the final endpoint.
105
-
106
- ### Runtime Metrics
107
-
108
- The runtime-reference benchmark records:
109
-
110
- - eager PyTorch steady-state latency
111
- - `torch.compile` compile-plus-first-call time
112
- - `torch.compile` steady-state latency
113
- - live rerun latency of best Triton config found by `random`
114
- - live rerun latency of best Triton config found by `surrogate`
115
- - Triton speedups vs eager and vs compiled
116
-
117
- ## Main Search Results
118
-
119
- The main search results came from the held-out-shape / held-out-family evaluation in `run13.log`.
120
-
121
- ### Held-out Shapes
122
-
123
- Random:
124
-
125
- - `mean_regret_at`:
126
- - `@1 = 0.31341859698295593`
127
- - `@3 = 0.13305269181728363`
128
- - `@5 = 0.1070360466837883`
129
- - `@6 = 0.06344400346279144`
130
- - `mean_best_so_far_auc = 0.1483089178800583`
131
- - `mean_oracle_hit_rate_final = 0.6749999523162842`
132
-
133
- Surrogate:
134
-
135
- - `mean_regret_at`:
136
- - `@1 = 0.002321675419807434`
137
- - `@3 = 0.002293013734742999`
138
- - `@5 = 0.002293013734742999`
139
- - `@6 = 0.002293013734742999`
140
- - `mean_best_so_far_auc = 0.0023013732861727476`
141
- - `mean_oracle_hit_rate_final = 0.8333333134651184`
142
-
143
- Interpretation:
144
-
145
- - Surrogate reduced best-so-far AUC by about `98.4%` versus random.
146
- - Surrogate reduced final `regret@6` by about `96.4%`.
147
- - Surrogate improved final oracle-hit rate by about `15.8` percentage points.
148
-
149
- This is a very strong within-family / held-out-shape result.
150
-
151
- ### Held-out Family (`grouped_gemm`)
152
-
153
- Random:
154
-
155
- - `mean_regret_at`:
156
- - `@1 = 2.341181755065918`
157
- - `@3 = 0.8532703518867493`
158
- - `@5 = 0.3116174638271332`
159
- - `@6 = 0.21012252569198608`
160
- - `mean_best_so_far_auc = 0.9102223515510559`
161
- - `mean_oracle_hit_rate_final = 0.17777778208255768`
162
-
163
- Surrogate:
164
-
165
- - `mean_regret_at`:
166
- - `@1 = 0.4722703695297241`
167
- - `@3 = 0.29785311222076416`
168
- - `@5 = 0.20862582325935364`
169
- - `@6 = 0.17804712057113647`
170
- - `mean_best_so_far_auc = 0.3014116585254669`
171
- - `mean_oracle_hit_rate_final = 0.3444444239139557`
172
-
173
- Interpretation:
174
-
175
- - Surrogate reduced best-so-far AUC by about `66.9%` versus random.
176
- - Surrogate reduced final `regret@6` by about `15.3%`.
177
- - Surrogate improved final oracle-hit rate by about `16.7` percentage points.
178
-
179
- This is a good cross-family transfer result. The gap is smaller than in the held-out-shape setting, which is expected.
180
-
181
- ## Main Runtime Results
182
-
183
- The runtime-reference benchmark compares eager PyTorch, `torch.compile`, and the live reruns of the best Triton configs found by `random` and `surrogate`.
184
-
185
- ### Summary by Family
186
-
187
- #### Softmax
188
-
189
- Average eager latency:
190
-
191
- - `0.1103919968008995 ms`
192
-
193
- Average `torch.compile` steady-state latency:
194
-
195
- - `0.11152799427509308 ms`
196
-
197
- Average compile-plus-first-call time:
198
-
199
- - `529.9687260048813 ms`
200
-
201
- Average Triton speedup vs eager:
202
-
203
- - random-best: `3.362561387683493x`
204
- - surrogate-best: `3.286588301595338x`
205
-
206
- Average Triton speedup vs compiled:
207
-
208
- - random-best: `3.3985671575178635x`
209
- - surrogate-best: `3.321742054891467x`
210
-
211
- Interpretation:
212
-
213
- - `torch.compile` is effectively flat vs eager on this softmax set.
214
- - Tuned Triton is substantially faster than both.
215
- - Surrogate-best and random-best final kernels are very close in absolute runtime.
216
-
217
- #### LayerNorm
218
-
219
- Average eager latency:
220
-
221
- - `0.05939200147986412 ms`
222
-
223
- Average `torch.compile` steady-state latency:
224
-
225
- - `0.09503999352455139 ms`
226
-
227
- Average compile-plus-first-call time:
228
-
229
- - `440.1235789991915 ms`
230
-
231
- Average Triton speedup vs eager:
232
-
233
- - random-best: `1.8776593781360051x`
234
- - surrogate-best: `1.8364378273209185x`
235
-
236
- Average Triton speedup vs compiled:
237
-
238
- - random-best: `2.927484944635789x`
239
- - surrogate-best: `2.862647103483093x`
240
-
241
- Interpretation:
242
-
243
- - `torch.compile` is slower than eager on this LayerNorm set.
244
- - Tuned Triton is materially faster than both eager and compiled.
245
- - Again, surrogate-best and random-best final kernels are close in endpoint runtime.
246
-
247
- #### Grouped GEMM
248
-
249
- Average eager latency:
250
-
251
- - `0.05589688859052128 ms`
252
-
253
- Average `torch.compile` steady-state latency:
254
-
255
- - `0.101806221736802 ms`
256
-
257
- Average compile-plus-first-call time:
258
-
259
- - `102.45987688863858 ms`
260
-
261
- Average Triton speedup vs eager:
262
-
263
- - random-best: `1.2771213149737215x`
264
- - surrogate-best: `1.2644549628354071x`
265
-
266
- Average Triton speedup vs compiled:
267
-
268
- - random-best: `2.4414293463407355x`
269
- - surrogate-best: `2.4156697207038382x`
270
-
271
- Interpretation:
272
-
273
- - `torch.compile` is materially slower than eager on this grouped-GEMM set.
274
- - Tuned Triton is faster than both eager and compiled.
275
- - Endpoint difference between surrogate-best and random-best remains small.
276
-
277
- ## Triton Autotune Comparison
278
-
279
- We also compared the surrogate search workflow directly against `triton.autotune` on three large softmax tasks:
280
-
281
- - `softmax_m4096_n4096`
282
- - `softmax_m4096_n6144`
283
- - `softmax_m4096_n8192`
284
-
285
- This comparison measures two things:
286
-
287
- - search plus compile cost to the first usable tuned kernel
288
- - steady-state runtime of the resulting tuned kernel
289
-
290
- ### Per-task Results
291
-
292
- #### `softmax_m4096_n4096`
293
-
294
- - oracle best: `0.02127999998629093 ms`
295
- - surrogate:
296
- - decision time: `33.06370500649791 ms`
297
- - compile plus first call: `294.734695009538 ms`
298
- - steady-state: `0.02127999998629093 ms`
299
- - regret vs oracle: `0.0`
300
- - `triton.autotune`:
301
- - first call: `8970.702438004082 ms`
302
- - steady-state: `0.021856000646948814 ms`
303
- - regret vs oracle: `0.0270677002363231`
304
-
305
- #### `softmax_m4096_n6144`
306
-
307
- - oracle best: `0.030719999223947525 ms`
308
- - surrogate:
309
- - decision time: `15.47088599181734 ms`
310
- - compile plus first call: `0.9627069957787171 ms`
311
- - steady-state: `0.031007999554276466 ms`
312
- - regret vs oracle: `0.009375010989727928`
313
- - `triton.autotune`:
314
- - first call: `1990.3547260037158 ms`
315
- - steady-state: `0.031039999797940254 ms`
316
- - regret vs oracle: `0.010416685614473398`
317
-
318
- #### `softmax_m4096_n8192`
319
-
320
- - oracle best: `0.03747199848294258 ms`
321
- - surrogate:
322
- - decision time: `15.47144899086561 ms`
323
- - compile plus first call: `0.09679699724074453 ms`
324
- - steady-state: `0.03798399865627289 ms`
325
- - regret vs oracle: `0.013663540618560122`
326
- - `triton.autotune`:
327
- - first call: `1988.913670007605 ms`
328
- - steady-state: `0.03747199848294258 ms`
329
- - regret vs oracle: `0.0`
330
-
331
- ### Summary
332
-
333
- - mean surrogate compile plus first call: `98.59806633418582 ms`
334
- - mean surrogate steady-state latency: `0.03009066606561343 ms`
335
- - mean `triton.autotune` first call: `4316.656944671801 ms`
336
- - mean `triton.autotune` steady-state latency: `0.030122666309277218 ms`
337
-
338
- Interpretation:
339
-
340
- - The surrogate reaches a competitive tuned kernel far faster than `triton.autotune` on these tasks.
341
- - Steady-state performance is effectively the same on average:
342
- - surrogate mean steady-state: `0.0300907 ms`
343
- - `triton.autotune` mean steady-state: `0.0301227 ms`
344
- - On one task the surrogate exactly matched the oracle best.
345
- - On the other two tasks the surrogate was slightly off the oracle, but still close.
346
- - `triton.autotune` won one task in steady-state quality, but paid a much larger first-call search cost.
347
-
348
- This is an important result because it shows the surrogate is not only better than a random search baseline. It is also competitive with Triton's built-in autotuning in final kernel quality while being dramatically cheaper in tuning-time-to-first-good-kernel on these tested shapes.
349
-
350
- ## What The Results Mean
351
-
352
- The results support the following conclusions:
353
-
354
- 1. The surrogate optimizer is genuinely useful as a search policy.
355
- - It is substantially more sample-efficient than random.
356
- - It reaches good kernels much earlier in the budget.
357
- - This effect is very strong on held-out shapes and still meaningful on held-out family transfer.
358
-
359
- 2. The resulting Triton kernels are genuinely useful as runtime implementations.
360
- - They are faster than eager PyTorch.
361
- - They are faster than `torch.compile`.
362
- - The advantage is strongest on `softmax`, then `layernorm`, then `grouped_gemm`.
363
-
364
- 3. The surrogate is also competitive with `triton.autotune` on final steady-state runtime while being much cheaper in first-call tuning cost on the tested softmax shapes.
365
-
366
- 4. The main value of the surrogate is search efficiency, not necessarily a dramatically better final endpoint than a lucky random search.
367
- - By the end of the short search budget, random and surrogate can still land on very similar endpoint kernels.
368
- - This is visible in the runtime benchmark, where random-best and surrogate-best final kernels are often close in ms.
369
- - The surrogate still wins decisively on regret and AUC.
370
-
371
- This is a coherent and valuable result. The optimizer is improving how quickly good kernels are found, and the kernels it finds are fast in absolute runtime.
372
-
373
- ## Caveats and Professional Notes
374
-
375
- This benchmark is strong for a hackathon project, but it should not be oversold.
376
-
377
- Important caveats:
378
-
379
- - The benchmark is not publication-grade yet.
380
- - Only one `M` value is used.
381
- - Family/task distributions are still narrow.
382
- - Grouped GEMM was added recently and should be validated more deeply.
383
- - `torch.compile` compile time measurement should be interpreted carefully.
384
- - The measured compile-plus-first-call times vary across tasks.
385
- - A more publication-grade measurement would reset compiler state more aggressively and isolate cold-start behavior more carefully.
386
- - The runtime benchmark uses live reruns.
387
- - That means small run-to-run variation is expected.
388
- - The endpoint runtime gap between surrogate-best and random-best is small.
389
- - This is not a failure of the benchmark.
390
- - It means the benchmark currently demonstrates search-efficiency gains more strongly than final-endpoint gains.
391
-
392
- These caveats do not invalidate the results. They define the proper scope of the claims.
393
-
394
- ## Recommended Project Framing
395
-
396
- The most honest and compelling project framing is:
397
-
398
- > We built a measured-oracle GPU autotuning benchmark with held-out-shape and held-out-family evaluation. Our surrogate-guided optimizer substantially outperforms a random baseline on short-budget search efficiency, and the Triton kernels it finds are materially faster than both eager PyTorch and `torch.compile`.
399
-
400
- That statement matches the data.
401
-
402
- ## Recommended Next Steps
403
-
404
- Highest-value next steps:
405
-
406
- 1. Expand the runtime benchmark:
407
- - more held-out tasks
408
- - more end-to-end summaries
409
-
410
- 2. Improve search without changing the measurement cache:
411
- - stronger acquisition strategies
412
- - family-aware priors
413
- - feature engineering before simply increasing parameter count
414
-
415
- 3. Make the collector more production-friendly:
416
- - incremental writes
417
- - progress logging
418
- - resumable measurement collection
419
-
420
- 4. If needed, increase benchmark rigor:
421
- - broader shape sets
422
- - more families
423
- - more careful cold-start `torch.compile` accounting
424
-
425
- ## Deliverable Status
426
-
427
- At the current stage, this project is:
428
-
429
- - hackathon-ready
430
- - technically credible
431
- - professionally explainable
432
- - useful as a platform for next-stage kernel autotuning work
433
-
434
- It is not yet:
435
-
436
- - a finished research benchmark
437
- - a final systems paper artifact
438
-
439
- That is the correct level of rigor for the current results.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
HACKATHON_SUMMARY.md DELETED
@@ -1,179 +0,0 @@
1
- # GPU Kernel Autotuning: Hackathon Summary
2
-
3
- ## What We Built
4
-
5
- We built a measured-oracle GPU autotuning benchmark for Triton kernels on an H200-class machine.
6
-
7
- The system has two layers:
8
-
9
- - Search benchmark
10
- - compares `surrogate` vs `random`
11
- - measures how quickly each method finds good kernel configs
12
- - Runtime benchmark
13
- - compares the selected Triton kernels against eager PyTorch and `torch.compile`
14
- - measures actual execution latency in milliseconds
15
-
16
- The benchmark is discrete-action, uses real measured kernel runtimes, and supports held-out-shape and held-out-family evaluation.
17
-
18
- ## Kernel Families
19
-
20
- Current benchmark families:
21
-
22
- - Family A:
23
- - `softmax`
24
- - `layernorm`
25
- - Family B:
26
- - `grouped_gemm`
27
-
28
- Current shape set:
29
-
30
- - `M = 4096`
31
- - `N ∈ {256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192}`
32
-
33
- Total tasks:
34
-
35
- - `27` tasks
36
-
37
- ## Why This Matters
38
-
39
- There are two distinct questions in GPU autotuning:
40
-
41
- 1. Can an optimizer find good kernels quickly under a short tuning budget?
42
- 2. Are the kernels it finds actually fast in real runtime?
43
-
44
- This project answers both.
45
-
46
- ## Search Results
47
-
48
- ### Held-out Shapes
49
-
50
- Against the `random` baseline, the `surrogate` optimizer achieved:
51
-
52
- - `98.4%` lower best-so-far regret AUC
53
- - `96.4%` lower `regret@6`
54
- - `+15.8` percentage points higher final oracle-hit rate
55
-
56
- Interpretation:
57
-
58
- - On held-out shapes, the surrogate is dramatically more sample-efficient than random.
59
- - It finds near-optimal kernels almost immediately.
60
-
61
- ### Held-out Family (`grouped_gemm`)
62
-
63
- Against the `random` baseline, the `surrogate` optimizer achieved:
64
-
65
- - `66.9%` lower best-so-far regret AUC
66
- - `15.3%` lower `regret@6`
67
- - `+16.7` percentage points higher final oracle-hit rate
68
-
69
- Interpretation:
70
-
71
- - The surrogate also transfers across kernel families.
72
- - The cross-family problem is harder, but the optimizer still wins clearly.
73
-
74
- ## Runtime Results
75
-
76
- We then reran the selected Triton kernels live and compared them to:
77
-
78
- - eager PyTorch
79
- - `torch.compile`
80
- - `triton.autotune` on selected large softmax tasks
81
-
82
- ### Softmax
83
-
84
- Average surrogate-best Triton runtime:
85
-
86
- - about `0.0336 ms`
87
-
88
- Average `torch.compile` runtime:
89
-
90
- - about `0.1115 ms`
91
-
92
- Result:
93
-
94
- - surrogate-selected Triton is about `3.32x` faster than `torch.compile`
95
-
96
- ### LayerNorm
97
-
98
- Average surrogate-best Triton runtime:
99
-
100
- - about `0.0323 ms`
101
-
102
- Average `torch.compile` runtime:
103
-
104
- - about `0.0950 ms`
105
-
106
- Result:
107
-
108
- - surrogate-selected Triton is about `2.86x` faster than `torch.compile`
109
-
110
- ### Grouped GEMM
111
-
112
- Average surrogate-best Triton runtime:
113
-
114
- - about `0.0442 ms`
115
-
116
- Average `torch.compile` runtime:
117
-
118
- - about `0.1018 ms`
119
-
120
- Result:
121
-
122
- - surrogate-selected Triton is about `2.42x` faster than `torch.compile`
123
-
124
- ### Triton Autotune on Large Softmax
125
-
126
- We also compared the surrogate directly against `triton.autotune` on:
127
-
128
- - `softmax_m4096_n4096`
129
- - `softmax_m4096_n6144`
130
- - `softmax_m4096_n8192`
131
-
132
- Result:
133
-
134
- - mean surrogate compile plus first call: `98.6 ms`
135
- - mean `triton.autotune` first call: `4316.7 ms`
136
- - mean steady-state latency was effectively the same:
137
- - surrogate: `0.03009 ms`
138
- - `triton.autotune`: `0.03012 ms`
139
-
140
- Interpretation:
141
-
142
- - The surrogate reaches a strong tuned kernel far faster than `triton.autotune` on these tested softmax shapes.
143
- - Final steady-state kernel quality is essentially matched on average.
144
-
145
- ## What The Results Mean
146
-
147
- The current system shows two real advantages:
148
-
149
- - The surrogate is much better than random at short-budget tuning.
150
- - The resulting Triton kernels are materially faster than `torch.compile`.
151
- - On selected large softmax tasks, the surrogate also reaches competitive tuned kernels far faster than `triton.autotune`.
152
-
153
- The subtle point is that the surrogate's biggest gain is in search efficiency, not necessarily in a huge final-endpoint gap over a lucky random search. That is why regret/AUC is the right optimization metric here.
154
-
155
- ## What Is Strong
156
-
157
- - Real GPU measurements, not synthetic rewards
158
- - Hidden-oracle benchmark protocol
159
- - Discrete action space
160
- - Held-out-shape evaluation
161
- - Held-out-family evaluation
162
- - Runtime comparison against eager PyTorch and `torch.compile`
163
-
164
- ## What Is Not Final Yet
165
-
166
- - This is not publication-grade yet
167
- - Only one `M` dimension is used
168
- - Grouped GEMM was added recently and should be stress-validated further
169
- - `torch.compile` cold-start accounting could be measured even more rigorously
170
-
171
- ## Bottom Line
172
-
173
- This project demonstrates a credible autotuning workflow:
174
-
175
- - surrogate-guided search beats random on search quality
176
- - tuned Triton kernels beat `torch.compile` on runtime
177
- - the benchmark already shows cross-family transfer
178
-
179
- That is a strong hackathon result with professional-grade measurement discipline.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,129 +1,12 @@
1
- # Multi-family GPU autotuning benchmark
2
-
3
- ## TLDR
4
-
5
- ```bash
6
- python3 scripts/check_torch_cuda.py
7
- ./scripts/run_full_pipeline.sh
8
- ```
9
-
10
- The pipeline now:
11
- - collects a shared measurement cache for `softmax`, `layernorm`, and `grouped_gemm`
12
- - builds split manifests for shape holdout and family holdout
13
- - smoke-tests the local OpenEnv-style environment
14
- - evaluates `random` vs `surrogate` using `regret@k` and best-so-far AUC
15
- - benchmarks eager PyTorch and `torch.compile` against best Triton configs
16
-
17
- ## Measurement cache
18
-
19
- Main collector:
20
-
21
- ```bash
22
- python3 scripts/collect_multifamily_measurements.py \
23
- --output data/autotune_measurements.csv \
24
- --families softmax layernorm grouped_gemm \
25
- --n-cols 256 512 1024 1536 2048 3072 4096 6144 8192 \
26
- --m 4096 \
27
- --repeats 200 \
28
- --warmup 25 \
29
- --seed 0 \
30
- --append
31
- ```
32
-
33
- Current implemented families:
34
- - `softmax`
35
- - `layernorm`
36
- - `grouped_gemm`
37
-
38
- The shared CSV schema includes:
39
- - `family_group`
40
- - `family`
41
- - `task_id`
42
- - `m`, `n`
43
- - `config_id`
44
- - `block_size`, `num_warps`, `num_stages`
45
- - `shape_json`, `config_json`
46
- - `median_ms`, `effective_gbps`, `score`, `validation_error`
47
-
48
- ## Splits and eval
49
-
50
- Build split manifests:
51
-
52
- ```bash
53
- python3 scripts/build_benchmark_splits.py \
54
- --measurement-path data/autotune_measurements.csv \
55
- --output data/benchmark_splits.json \
56
- --heldout-family grouped_gemm
57
- ```
58
-
59
- Run split-based evaluation:
60
-
61
- ```bash
62
- python3 scripts/eval_generalization.py \
63
- --measurement-path data/autotune_measurements.csv \
64
- --splits data/benchmark_splits.json \
65
- --episodes 20 \
66
- --budget 6 \
67
- --seed 2 \
68
- --acquisition ucb \
69
- --beta 2.0
70
- ```
71
-
72
- Benchmark absolute runtime references:
73
-
74
- ```bash
75
- python3 scripts/benchmark_runtime_references.py \
76
- --generalization-results outputs/generalization_eval.json \
77
- --repeats 100 \
78
- --warmup 10 \
79
- --seed 123
80
- ```
81
-
82
- Metrics:
83
- - `mean_regret_at`
84
- - `median_regret_at`
85
- - `mean_best_so_far_auc`
86
- - `mean_oracle_hit_rate_final`
87
- - `eager_latency_ms`
88
- - `compile_plus_first_call_ms`
89
- - `compiled_latency_ms`
90
- - Triton speedups vs eager / compiled
91
-
92
- ## Environment
93
-
94
- OpenEnv metadata is in:
95
- - `openenv.yaml`
96
-
97
- The environment server still uses:
98
- - `server/app.py`
99
- - `server/softmax_surrogate_environment.py`
100
-
101
- Despite the filename, the env is now multi-family and supports a train-task prior for held-out-shape / held-out-family evaluation.
102
-
103
- ## Qwen2.5-0.5B exact-kernel pipeline
104
-
105
- This repo now also includes a model-specific benchmark pipeline for the exact inference kernel roles needed by `Qwen/Qwen2.5-0.5B`.
106
-
107
- Kernel roles covered:
108
- - `rmsnorm`
109
- - attention `softmax`
110
- - `q_proj`, `k_proj`, `v_proj`, `o_proj`
111
- - `gate_proj`, `up_proj`, `down_proj`
112
-
113
- Run it with:
114
-
115
- ```bash
116
- ./scripts/run_qwen_05b_pipeline.sh
117
- ```
118
-
119
- Key files:
120
- - `scripts/qwen_05b_spec.py`
121
- - `scripts/collect_qwen_05b_measurements.py`
122
- - `scripts/build_qwen_05b_splits.py`
123
- - `scripts/benchmark_qwen_05b_runtime.py`
124
-
125
- Outputs:
126
- - `data/qwen_05b_measurements.csv`
127
- - `data/qwen_05b_splits.json`
128
- - `outputs/qwen_05b_generalization_eval.json`
129
- - `outputs/qwen_05b_runtime_references.json`
 
1
+ ---
2
+ title: RL Surrogate ENV
3
+ emoji: 👀
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: static
7
+ pinned: false
8
+ license: apache-2.0
9
+ short_description: This is a surrogate discovery environment to improve performance
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/main.py DELETED
@@ -1,177 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import sys
4
- from pathlib import Path
5
- from typing import Any, Dict, List, Optional
6
-
7
- ROOT = Path(__file__).resolve().parents[1]
8
- if str(ROOT) not in sys.path:
9
- sys.path.insert(0, str(ROOT))
10
-
11
- from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
12
- from fastapi.middleware.cors import CORSMiddleware
13
- from pydantic import BaseModel
14
-
15
- from backend.terminal_manager import ALLOWED_JOBS, TerminalManager
16
-
17
- app = FastAPI(
18
- title="RL Autotuning Backend",
19
- description="Backend API for the multi-family GPU autotuning benchmark",
20
- version="0.1.0",
21
- )
22
-
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=[
26
- "http://localhost:5173",
27
- "http://127.0.0.1:5173",
28
- "http://localhost:4173",
29
- "http://127.0.0.1:4173",
30
- ],
31
- allow_credentials=False,
32
- allow_methods=["*"],
33
- allow_headers=["*"],
34
- )
35
-
36
- env = None
37
- terminal_manager = TerminalManager()
38
-
39
-
40
- def _get_env():
41
- global env
42
- if env is None:
43
- try:
44
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
45
- env = SoftmaxSurrogateEnvironment()
46
- except ImportError as exc:
47
- raise HTTPException(
48
- status_code=503,
49
- detail=f"Environment unavailable – missing dependency: {exc.name}",
50
- )
51
- return env
52
-
53
-
54
- class ResetRequest(BaseModel):
55
- task: Optional[str] = None
56
- seed: Optional[int] = None
57
-
58
-
59
- class StepRequest(BaseModel):
60
- config_id: Optional[int] = None
61
- x: Optional[List[float]] = None
62
-
63
-
64
- class SessionRequest(BaseModel):
65
- job_id: str
66
- restart: bool = False
67
-
68
-
69
- class SessionInputRequest(BaseModel):
70
- data: str
71
- append_newline: bool = True
72
-
73
-
74
- class SessionResizeRequest(BaseModel):
75
- cols: int
76
- rows: int
77
-
78
-
79
- @app.get("/health")
80
- def health() -> Dict[str, str]:
81
- return {"status": "ok"}
82
-
83
-
84
- @app.post("/reset")
85
- def reset(payload: ResetRequest) -> Dict[str, Any]:
86
- return _get_env().reset(task=payload.task, seed=payload.seed)
87
-
88
-
89
- @app.post("/step")
90
- def step(payload: StepRequest) -> Dict[str, Any]:
91
- e = _get_env()
92
- if payload.config_id is not None:
93
- return e.step({"config_id": payload.config_id})
94
- if payload.x is not None:
95
- return e.step({"x": payload.x})
96
- raise HTTPException(status_code=400, detail="Provide config_id or x.")
97
-
98
-
99
- @app.get("/state")
100
- def state() -> Dict[str, Any]:
101
- return _get_env().state()
102
-
103
-
104
- @app.get("/terminal/jobs")
105
- def terminal_jobs() -> Dict[str, Any]:
106
- return {"jobs": terminal_manager.list_jobs()}
107
-
108
-
109
- @app.post("/terminal/sessions")
110
- async def create_terminal_session(payload: SessionRequest) -> Dict[str, Any]:
111
- if payload.job_id not in ALLOWED_JOBS:
112
- raise HTTPException(status_code=404, detail=f"Unknown job_id: {payload.job_id}")
113
- session = await terminal_manager.ensure_session(payload.job_id, restart=payload.restart)
114
- return session.snapshot()
115
-
116
-
117
- @app.get("/terminal/sessions/{session_id}")
118
- def terminal_session_snapshot(session_id: str) -> Dict[str, Any]:
119
- session = terminal_manager.get_session(session_id)
120
- if session is None:
121
- raise HTTPException(status_code=404, detail="Session not found")
122
- return session.snapshot()
123
-
124
-
125
- @app.post("/terminal/sessions/{session_id}/input")
126
- def terminal_session_input(session_id: str, payload: SessionInputRequest) -> Dict[str, Any]:
127
- session = terminal_manager.get_session(session_id)
128
- if session is None:
129
- raise HTTPException(status_code=404, detail="Session not found")
130
- if not session.is_active:
131
- raise HTTPException(status_code=409, detail="Session is not running")
132
- session.write(payload.data, append_newline=payload.append_newline)
133
- return {"ok": True}
134
-
135
-
136
- @app.post("/terminal/sessions/{session_id}/resize")
137
- def terminal_session_resize(session_id: str, payload: SessionResizeRequest) -> Dict[str, Any]:
138
- session = terminal_manager.get_session(session_id)
139
- if session is None:
140
- raise HTTPException(status_code=404, detail="Session not found")
141
- session.resize(payload.cols, payload.rows)
142
- return {"ok": True}
143
-
144
-
145
- @app.post("/terminal/sessions/{session_id}/stop")
146
- def terminal_session_stop(session_id: str) -> Dict[str, Any]:
147
- session = terminal_manager.get_session(session_id)
148
- if session is None:
149
- raise HTTPException(status_code=404, detail="Session not found")
150
- session.interrupt()
151
- return {"ok": True}
152
-
153
-
154
- @app.websocket("/terminal/sessions/{session_id}/stream")
155
- async def terminal_session_stream(websocket: WebSocket, session_id: str) -> None:
156
- session = terminal_manager.get_session(session_id)
157
- if session is None:
158
- await websocket.close(code=4404)
159
- return
160
-
161
- await websocket.accept()
162
- queue = await session.subscribe()
163
- try:
164
- await websocket.send_json(session.snapshot())
165
- while True:
166
- event = await queue.get()
167
- await websocket.send_json(event)
168
- except WebSocketDisconnect:
169
- pass
170
- finally:
171
- session.unsubscribe(queue)
172
-
173
-
174
- if __name__ == "__main__":
175
- import uvicorn
176
-
177
- uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/requirements.txt DELETED
@@ -1,3 +0,0 @@
1
- fastapi
2
- uvicorn[standard]
3
- pydantic
 
 
 
 
backend/terminal_manager.py DELETED
@@ -1,342 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import importlib.util
5
- import json
6
- import os
7
- import pty
8
- import signal
9
- import struct
10
- import subprocess
11
- import termios
12
- import threading
13
- import time
14
- import uuid
15
- from dataclasses import dataclass
16
- from pathlib import Path
17
- from typing import Any
18
-
19
-
20
- REPO_ROOT = Path(__file__).resolve().parents[1]
21
- BUFFER_LIMIT = 160_000
22
- DEFAULT_COLS = 120
23
- DEFAULT_ROWS = 36
24
- PYTHON_CANDIDATES = (
25
- "/usr/local/bin/python3",
26
- "/opt/homebrew/bin/python3",
27
- "/Users/amannindra/miniconda3/bin/python3",
28
- )
29
-
30
-
31
- @dataclass(frozen=True)
32
- class AllowedJob:
33
- job_id: str
34
- label: str
35
- description: str
36
- command: tuple[str, ...]
37
- cwd: Path
38
-
39
- def as_dict(self) -> dict[str, Any]:
40
- return {
41
- "job_id": self.job_id,
42
- "label": self.label,
43
- "description": self.description,
44
- "command": list(self.command),
45
- "cwd": str(self.cwd),
46
- }
47
-
48
-
49
- ALLOWED_JOBS: dict[str, AllowedJob] = {
50
- "qwen": AllowedJob(
51
- job_id="qwen",
52
- label="Qwen Baseline",
53
- description="Runs the exact-kernel Qwen2.5-0.5B benchmark pipeline.",
54
- command=("bash", "scripts/run_qwen_05b_pipeline.sh"),
55
- cwd=REPO_ROOT,
56
- ),
57
- "rl-agent": AllowedJob(
58
- job_id="rl-agent",
59
- label="RL Agent",
60
- description="Runs the multi-family surrogate and runtime benchmark pipeline.",
61
- command=("bash", "scripts/run_full_pipeline.sh"),
62
- cwd=REPO_ROOT,
63
- ),
64
- }
65
-
66
-
67
- def _probe_python(path: str) -> dict[str, Any] | None:
68
- if not Path(path).exists():
69
- return None
70
-
71
- script = (
72
- "import importlib.util, json, sys; "
73
- "print(json.dumps({"
74
- "'executable': sys.executable, "
75
- "'torch': bool(importlib.util.find_spec('torch')), "
76
- "'triton': bool(importlib.util.find_spec('triton'))"
77
- "}))"
78
- )
79
- try:
80
- result = subprocess.run(
81
- [path, "-c", script],
82
- check=True,
83
- capture_output=True,
84
- text=True,
85
- )
86
- except (OSError, subprocess.CalledProcessError):
87
- return None
88
-
89
- try:
90
- payload = json.loads(result.stdout.strip())
91
- except json.JSONDecodeError:
92
- return None
93
- payload["path"] = path
94
- return payload
95
-
96
-
97
- def _best_python_runtime() -> dict[str, Any] | None:
98
- explicit = os.environ.get("TERMINAL_PYTHON_BIN")
99
- if explicit:
100
- probe = _probe_python(explicit)
101
- if probe is not None:
102
- probe["score"] = int(probe["torch"]) + int(probe["triton"])
103
- probe["explicit"] = True
104
- return probe
105
-
106
- best: dict[str, Any] | None = None
107
- for candidate in PYTHON_CANDIDATES:
108
- probe = _probe_python(candidate)
109
- if probe is None:
110
- continue
111
- score = int(probe["torch"]) + int(probe["triton"])
112
- probe["score"] = score
113
- if best is None or score > best["score"]:
114
- best = probe
115
- return best
116
-
117
-
118
- class TerminalSession:
119
- def __init__(self, job: AllowedJob, loop: asyncio.AbstractEventLoop) -> None:
120
- self.id = uuid.uuid4().hex
121
- self.job = job
122
- self.loop = loop
123
- self.created_at = time.time()
124
- self.started_at: float | None = None
125
- self.finished_at: float | None = None
126
- self.exit_code: int | None = None
127
- self.status = "starting"
128
- self.cols = DEFAULT_COLS
129
- self.rows = DEFAULT_ROWS
130
- self.python_runtime = _best_python_runtime()
131
-
132
- self._buffer = ""
133
- self._buffer_lock = threading.Lock()
134
- self._subscribers: set[asyncio.Queue[dict[str, Any]]] = set()
135
- self._subscriber_lock = threading.Lock()
136
-
137
- self._master_fd, slave_fd = pty.openpty()
138
- self._resize_fd(self.cols, self.rows)
139
-
140
- env = os.environ.copy()
141
- env.setdefault("TERM", "xterm-256color")
142
- env.setdefault("PYTHONUNBUFFERED", "1")
143
- env.setdefault("FORCE_COLOR", "1")
144
- if self.python_runtime is not None:
145
- python_dir = str(Path(self.python_runtime["path"]).parent)
146
- env["PATH"] = f"{python_dir}:{env.get('PATH', '')}"
147
- env["PYTHON_BIN"] = self.python_runtime["path"]
148
-
149
- self._append_buffer(self._launcher_banner())
150
-
151
- self.process = subprocess.Popen(
152
- self.job.command,
153
- cwd=str(self.job.cwd),
154
- stdin=slave_fd,
155
- stdout=slave_fd,
156
- stderr=slave_fd,
157
- env=env,
158
- preexec_fn=os.setsid,
159
- close_fds=True,
160
- )
161
- os.close(slave_fd)
162
-
163
- self.started_at = time.time()
164
- self.status = "running"
165
-
166
- self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
167
- self._waiter_thread = threading.Thread(target=self._wait_loop, daemon=True)
168
- self._reader_thread.start()
169
- self._waiter_thread.start()
170
-
171
- @property
172
- def command_display(self) -> str:
173
- return " ".join(self.job.command)
174
-
175
- @property
176
- def is_active(self) -> bool:
177
- return self.process.poll() is None
178
-
179
- def snapshot(self) -> dict[str, Any]:
180
- with self._buffer_lock:
181
- buffer = self._buffer
182
- return {
183
- "type": "snapshot",
184
- "session": {
185
- "id": self.id,
186
- "job_id": self.job.job_id,
187
- "label": self.job.label,
188
- "description": self.job.description,
189
- "cwd": str(self.job.cwd),
190
- "command": self.command_display,
191
- "status": self.status,
192
- "created_at": self.created_at,
193
- "started_at": self.started_at,
194
- "finished_at": self.finished_at,
195
- "exit_code": self.exit_code,
196
- "cols": self.cols,
197
- "rows": self.rows,
198
- },
199
- "buffer": buffer,
200
- }
201
-
202
- async def subscribe(self) -> asyncio.Queue[dict[str, Any]]:
203
- queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
204
- with self._subscriber_lock:
205
- self._subscribers.add(queue)
206
- return queue
207
-
208
- def unsubscribe(self, queue: asyncio.Queue[dict[str, Any]]) -> None:
209
- with self._subscriber_lock:
210
- self._subscribers.discard(queue)
211
-
212
- def write(self, data: str, append_newline: bool = True) -> None:
213
- if not data:
214
- return
215
- payload = data + ("\n" if append_newline else "")
216
- os.write(self._master_fd, payload.encode("utf-8", errors="replace"))
217
-
218
- def resize(self, cols: int, rows: int) -> None:
219
- self.cols = max(20, cols)
220
- self.rows = max(8, rows)
221
- self._resize_fd(self.cols, self.rows)
222
-
223
- def interrupt(self) -> None:
224
- if self.process.poll() is None:
225
- os.killpg(os.getpgid(self.process.pid), signal.SIGINT)
226
-
227
- def terminate(self) -> None:
228
- if self.process.poll() is None:
229
- os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
230
-
231
- def _resize_fd(self, cols: int, rows: int) -> None:
232
- winsize = struct.pack("HHHH", rows, cols, 0, 0)
233
- try:
234
- termios.tcsetwinsize(self._master_fd, (rows, cols))
235
- except AttributeError:
236
- pass
237
- try:
238
- import fcntl
239
-
240
- fcntl.ioctl(self._master_fd, termios.TIOCSWINSZ, winsize)
241
- except OSError:
242
- pass
243
-
244
- def _append_buffer(self, chunk: str) -> None:
245
- with self._buffer_lock:
246
- self._buffer = (self._buffer + chunk)[-BUFFER_LIMIT:]
247
-
248
- def _launcher_banner(self) -> str:
249
- lines = [
250
- f"[launcher] job: {self.job.label}",
251
- f"[launcher] cwd: {self.job.cwd}",
252
- f"[launcher] command: {self.command_display}",
253
- ]
254
- if self.python_runtime is not None:
255
- modules = []
256
- modules.append(f"torch={'yes' if self.python_runtime['torch'] else 'no'}")
257
- modules.append(f"triton={'yes' if self.python_runtime['triton'] else 'no'}")
258
- lines.append(f"[launcher] python3: {self.python_runtime['path']} ({', '.join(modules)})")
259
- if self.python_runtime.get("explicit"):
260
- lines.append("[launcher] python3 source: TERMINAL_PYTHON_BIN")
261
- if not self.python_runtime["triton"]:
262
- lines.append("[launcher] warning: Triton is not installed in the selected Python runtime.")
263
- else:
264
- lines.append("[launcher] warning: no preferred Python runtime detected; falling back to PATH lookup.")
265
- return "\n".join(lines) + "\n\n"
266
-
267
- def _publish(self, event: dict[str, Any]) -> None:
268
- with self._subscriber_lock:
269
- subscribers = tuple(self._subscribers)
270
- for queue in subscribers:
271
- self.loop.call_soon_threadsafe(self._safe_put, queue, event)
272
-
273
- @staticmethod
274
- def _safe_put(queue: asyncio.Queue[dict[str, Any]], event: dict[str, Any]) -> None:
275
- try:
276
- queue.put_nowait(event)
277
- except asyncio.QueueFull:
278
- pass
279
-
280
- def _reader_loop(self) -> None:
281
- while True:
282
- try:
283
- data = os.read(self._master_fd, 4096)
284
- except OSError:
285
- break
286
- if not data:
287
- break
288
- text = data.decode("utf-8", errors="replace")
289
- self._append_buffer(text)
290
- self._publish({"type": "output", "data": text})
291
-
292
- def _wait_loop(self) -> None:
293
- exit_code = self.process.wait()
294
- self.exit_code = exit_code
295
- self.finished_at = time.time()
296
- self.status = "exited" if exit_code == 0 else "failed"
297
- self._publish(
298
- {
299
- "type": "exit",
300
- "exit_code": exit_code,
301
- "status": self.status,
302
- "finished_at": self.finished_at,
303
- }
304
- )
305
- try:
306
- os.close(self._master_fd)
307
- except OSError:
308
- pass
309
-
310
-
311
- class TerminalManager:
312
- def __init__(self) -> None:
313
- self._sessions: dict[str, TerminalSession] = {}
314
- self._latest_by_job: dict[str, str] = {}
315
- self._lock = threading.Lock()
316
-
317
- def list_jobs(self) -> list[dict[str, Any]]:
318
- return [job.as_dict() for job in ALLOWED_JOBS.values()]
319
-
320
- def get_session(self, session_id: str) -> TerminalSession | None:
321
- with self._lock:
322
- return self._sessions.get(session_id)
323
-
324
- async def ensure_session(self, job_id: str, restart: bool = False) -> TerminalSession:
325
- if job_id not in ALLOWED_JOBS:
326
- raise KeyError(job_id)
327
-
328
- with self._lock:
329
- existing_id = self._latest_by_job.get(job_id)
330
- existing = self._sessions.get(existing_id) if existing_id else None
331
-
332
- if existing and existing.is_active and not restart:
333
- return existing
334
-
335
- if existing and restart:
336
- existing.interrupt()
337
-
338
- session = TerminalSession(ALLOWED_JOBS[job_id], asyncio.get_running_loop())
339
- with self._lock:
340
- self._sessions[session.id] = session
341
- self._latest_by_job[job_id] = session.id
342
- return session
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
client.py DELETED
@@ -1,77 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import sys
5
- from pathlib import Path
6
- from typing import Any, Optional
7
-
8
- import requests
9
-
10
- ROOT = Path(__file__).resolve().parents[1]
11
- if str(ROOT) not in sys.path:
12
- sys.path.append(str(ROOT))
13
-
14
- from server.softmax_surrogate_environment import DEFAULT_BUDGET, SoftmaxSurrogateEnvironment
15
-
16
-
17
- class SoftmaxSurrogateEnvClient:
18
- def __init__(
19
- self,
20
- base_url: Optional[str] = None,
21
- measurement_path: str = "data/autotune_measurements.csv",
22
- budget: int = DEFAULT_BUDGET,
23
- seed: int = 0,
24
- ) -> None:
25
- self.base_url = base_url
26
- self._local_env = None
27
- if base_url is None:
28
- self._local_env = SoftmaxSurrogateEnvironment(
29
- measurement_path=measurement_path,
30
- budget=budget,
31
- seed=seed,
32
- )
33
-
34
- def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> dict:
35
- if self._local_env is not None:
36
- return self._local_env.reset(task=task, seed=seed)
37
- payload = {}
38
- if task is not None:
39
- payload["task"] = task
40
- if seed is not None:
41
- payload["seed"] = seed
42
- resp = requests.post(f"{self.base_url}/reset", json=payload, timeout=60)
43
- resp.raise_for_status()
44
- return resp.json()
45
-
46
- def step(self, action: Any) -> dict:
47
- if self._local_env is not None:
48
- return self._local_env.step(action)
49
- payload = action if isinstance(action, dict) else {"x": action}
50
- resp = requests.post(f"{self.base_url}/step", json=payload, timeout=120)
51
- resp.raise_for_status()
52
- return resp.json()
53
-
54
- def state(self) -> dict:
55
- if self._local_env is not None:
56
- return self._local_env.state()
57
- resp = requests.get(f"{self.base_url}/state", timeout=60)
58
- resp.raise_for_status()
59
- return resp.json()
60
-
61
-
62
- def parse_args() -> argparse.Namespace:
63
- parser = argparse.ArgumentParser()
64
- parser.add_argument("--remote", default=None, help="Optional base URL (e.g. http://127.0.0.1:8000)")
65
- parser.add_argument("--task", default=None)
66
- parser.add_argument("--seed", type=int, default=0)
67
- return parser.parse_args()
68
-
69
-
70
- def main() -> None:
71
- args = parse_args()
72
- client = SoftmaxSurrogateEnvClient(base_url=args.remote, seed=args.seed)
73
- print(client.reset(task=args.task))
74
-
75
-
76
- if __name__ == "__main__":
77
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/.gitignore DELETED
@@ -1,24 +0,0 @@
1
- # Logs
2
- logs
3
- *.log
4
- npm-debug.log*
5
- yarn-debug.log*
6
- yarn-error.log*
7
- pnpm-debug.log*
8
- lerna-debug.log*
9
-
10
- node_modules
11
- dist
12
- dist-ssr
13
- *.local
14
-
15
- # Editor directories and files
16
- .vscode/*
17
- !.vscode/extensions.json
18
- .idea
19
- .DS_Store
20
- *.suo
21
- *.ntvs*
22
- *.njsproj
23
- *.sln
24
- *.sw?
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/README.md DELETED
@@ -1,16 +0,0 @@
1
- # React + Vite
2
-
3
- This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
-
5
- Currently, two official plugins are available:
6
-
7
- - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
8
- - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
-
10
- ## React Compiler
11
-
12
- The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
13
-
14
- ## Expanding the ESLint configuration
15
-
16
- If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/eslint.config.js DELETED
@@ -1,29 +0,0 @@
1
- import js from '@eslint/js'
2
- import globals from 'globals'
3
- import reactHooks from 'eslint-plugin-react-hooks'
4
- import reactRefresh from 'eslint-plugin-react-refresh'
5
- import { defineConfig, globalIgnores } from 'eslint/config'
6
-
7
- export default defineConfig([
8
- globalIgnores(['dist']),
9
- {
10
- files: ['**/*.{js,jsx}'],
11
- extends: [
12
- js.configs.recommended,
13
- reactHooks.configs.flat.recommended,
14
- reactRefresh.configs.vite,
15
- ],
16
- languageOptions: {
17
- ecmaVersion: 2020,
18
- globals: globals.browser,
19
- parserOptions: {
20
- ecmaVersion: 'latest',
21
- ecmaFeatures: { jsx: true },
22
- sourceType: 'module',
23
- },
24
- },
25
- rules: {
26
- 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
- },
28
- },
29
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/index.html DELETED
@@ -1,13 +0,0 @@
1
- <!doctype html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8" />
5
- <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
- <title>frontend</title>
8
- </head>
9
- <body>
10
- <div id="root"></div>
11
- <script type="module" src="/src/main.jsx"></script>
12
- </body>
13
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/package-lock.json DELETED
The diff for this file is too large to render. See raw diff
 
frontend/package.json DELETED
@@ -1,27 +0,0 @@
1
- {
2
- "name": "frontend",
3
- "private": true,
4
- "version": "0.0.0",
5
- "type": "module",
6
- "scripts": {
7
- "dev": "vite",
8
- "build": "vite build",
9
- "lint": "eslint .",
10
- "preview": "vite preview"
11
- },
12
- "dependencies": {
13
- "react": "^19.2.0",
14
- "react-dom": "^19.2.0"
15
- },
16
- "devDependencies": {
17
- "@eslint/js": "^9.39.1",
18
- "@types/react": "^19.2.7",
19
- "@types/react-dom": "^19.2.3",
20
- "@vitejs/plugin-react": "^5.1.1",
21
- "eslint": "^9.39.1",
22
- "eslint-plugin-react-hooks": "^7.0.1",
23
- "eslint-plugin-react-refresh": "^0.4.24",
24
- "globals": "^16.5.0",
25
- "vite": "^7.3.1"
26
- }
27
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/public/vite.svg DELETED
frontend/src/App.jsx DELETED
@@ -1,195 +0,0 @@
1
- import { useEffect, useRef, useState } from 'react'
2
- import TerminalPane from './components/TerminalPane'
3
-
4
- const panes = [
5
- { jobId: 'qwen', title: 'Qwen Baseline', tone: 'cyan' },
6
- { jobId: 'rl-agent', title: 'RL Agent', tone: 'green' },
7
- ]
8
-
9
- function formatMs(value) {
10
- if (value == null) {
11
- return '--'
12
- }
13
- if (value < 1000) {
14
- return `${Math.round(value)} ms`
15
- }
16
- return `${(value / 1000).toFixed(2)} s`
17
- }
18
-
19
- function buildRunStats(telemetry, issuedAt) {
20
- if (!telemetry || !issuedAt) {
21
- return {
22
- responseMs: null,
23
- completionMs: null,
24
- waiting: true,
25
- }
26
- }
27
-
28
- const responseMs = telemetry.lastOutputAt && telemetry.lastOutputAt >= issuedAt ? telemetry.lastOutputAt - issuedAt : null
29
- const finishedAt = telemetry.session?.finished_at ? telemetry.session.finished_at * 1000 : null
30
- const completionMs = finishedAt && finishedAt >= issuedAt ? finishedAt - issuedAt : null
31
-
32
- return {
33
- responseMs,
34
- completionMs,
35
- waiting: responseMs == null && completionMs == null,
36
- }
37
- }
38
-
39
- function App() {
40
- const [split, setSplit] = useState(50)
41
- const [dragging, setDragging] = useState(false)
42
- const [command, setCommand] = useState('')
43
- const [comparisonRun, setComparisonRun] = useState(null)
44
- const [telemetry, setTelemetry] = useState({
45
- qwen: null,
46
- 'rl-agent': null,
47
- })
48
- const workspaceRef = useRef(null)
49
- const leftPaneRef = useRef(null)
50
- const rightPaneRef = useRef(null)
51
-
52
- useEffect(() => {
53
- if (!dragging) {
54
- return undefined
55
- }
56
-
57
- const handlePointerMove = (event) => {
58
- const bounds = workspaceRef.current?.getBoundingClientRect()
59
- if (!bounds) {
60
- return
61
- }
62
-
63
- const next = ((event.clientX - bounds.left) / bounds.width) * 100
64
- const clamped = Math.min(75, Math.max(25, next))
65
- setSplit(clamped)
66
- }
67
-
68
- const handlePointerUp = () => {
69
- setDragging(false)
70
- }
71
-
72
- window.addEventListener('pointermove', handlePointerMove)
73
- window.addEventListener('pointerup', handlePointerUp)
74
-
75
- return () => {
76
- window.removeEventListener('pointermove', handlePointerMove)
77
- window.removeEventListener('pointerup', handlePointerUp)
78
- }
79
- }, [dragging])
80
-
81
- const handleBroadcast = async (event) => {
82
- event.preventDefault()
83
- const value = command.trim()
84
- if (!value) {
85
- return
86
- }
87
-
88
- const issuedAt = Date.now()
89
- setComparisonRun({
90
- command: value,
91
- issuedAt,
92
- })
93
-
94
- await Promise.allSettled([
95
- leftPaneRef.current?.submit(value),
96
- rightPaneRef.current?.submit(value),
97
- ])
98
- setCommand('')
99
- }
100
-
101
- const handleTelemetryChange = (payload) => {
102
- setTelemetry((previous) => ({
103
- ...previous,
104
- [payload.jobId]: payload,
105
- }))
106
- }
107
-
108
- const qwenStats = buildRunStats(telemetry.qwen, comparisonRun?.issuedAt)
109
- const agentStats = buildRunStats(telemetry['rl-agent'], comparisonRun?.issuedAt)
110
-
111
- let comparisonHeadline = 'Send a shared command to compare runtime.'
112
- if (comparisonRun) {
113
- if (qwenStats.completionMs != null && agentStats.completionMs != null) {
114
- const fasterJob = qwenStats.completionMs <= agentStats.completionMs ? panes[0].title : panes[1].title
115
- const delta = Math.abs(qwenStats.completionMs - agentStats.completionMs)
116
- comparisonHeadline = `${fasterJob} finished ${formatMs(delta)} faster.`
117
- } else if (qwenStats.responseMs != null && agentStats.responseMs != null) {
118
- const fasterJob = qwenStats.responseMs <= agentStats.responseMs ? panes[0].title : panes[1].title
119
- const delta = Math.abs(qwenStats.responseMs - agentStats.responseMs)
120
- comparisonHeadline = `${fasterJob} responded ${formatMs(delta)} faster.`
121
- } else {
122
- comparisonHeadline = `Running shared command: ${comparisonRun.command}`
123
- }
124
- }
125
-
126
- return (
127
- <main className="desktop">
128
- <div className="desktop__glow" />
129
-
130
- <section className="comparison-bar">
131
- <div className="comparison-bar__copy">
132
- <span className="comparison-bar__eyebrow">Runtime compare</span>
133
- <strong>{comparisonHeadline}</strong>
134
- <small>{comparisonRun ? `Command: ${comparisonRun.command}` : 'Broadcast one command to both panes.'}</small>
135
- </div>
136
-
137
- <div className="comparison-bar__stats">
138
- <article className="comparison-card comparison-card--cyan">
139
- <span>{panes[0].title}</span>
140
- <strong>{formatMs(qwenStats.completionMs ?? qwenStats.responseMs)}</strong>
141
- <small>{qwenStats.completionMs != null ? 'completion time' : 'first output latency'}</small>
142
- </article>
143
-
144
- <article className="comparison-card comparison-card--green">
145
- <span>{panes[1].title}</span>
146
- <strong>{formatMs(agentStats.completionMs ?? agentStats.responseMs)}</strong>
147
- <small>{agentStats.completionMs != null ? 'completion time' : 'first output latency'}</small>
148
- </article>
149
- </div>
150
- </section>
151
-
152
- <section ref={workspaceRef} className="workspace">
153
- <div className="workspace__pane" style={{ width: `${split}%` }}>
154
- <TerminalPane ref={leftPaneRef} {...panes[0]} onTelemetryChange={handleTelemetryChange} />
155
- </div>
156
-
157
- <button
158
- type="button"
159
- className={dragging ? 'workspace__divider is-dragging' : 'workspace__divider'}
160
- onPointerDown={() => setDragging(true)}
161
- aria-label="Resize terminal panes"
162
- aria-valuemin={25}
163
- aria-valuemax={75}
164
- aria-valuenow={Math.round(split)}
165
- aria-orientation="vertical"
166
- >
167
- <span />
168
- </button>
169
-
170
- <div className="workspace__pane" style={{ width: `${100 - split}%` }}>
171
- <TerminalPane ref={rightPaneRef} {...panes[1]} onTelemetryChange={handleTelemetryChange} />
172
- </div>
173
- </section>
174
-
175
- <form className="broadcast-bar" onSubmit={handleBroadcast}>
176
- <label className="broadcast-bar__label" htmlFor="broadcast-input">
177
- Shared input
178
- </label>
179
- <div className="broadcast-bar__field">
180
- <span className="broadcast-bar__prompt">$</span>
181
- <input
182
- id="broadcast-input"
183
- value={command}
184
- onChange={(event) => setCommand(event.target.value)}
185
- placeholder="Send the same command to both terminals"
186
- spellCheck="false"
187
- />
188
- <button type="submit">Send to both</button>
189
- </div>
190
- </form>
191
- </main>
192
- )
193
- }
194
-
195
- export default App
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/api/terminal.js DELETED
@@ -1,57 +0,0 @@
1
- const RAW_API_BASE = import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8000'
2
- const API_BASE = RAW_API_BASE.replace(/\/+$/, '')
3
- const WS_BASE = API_BASE.replace(/^http/, 'ws')
4
-
5
- async function request(path, options = {}) {
6
- const response = await fetch(`${API_BASE}${path}`, {
7
- headers: {
8
- 'Content-Type': 'application/json',
9
- ...(options.headers || {}),
10
- },
11
- ...options,
12
- })
13
-
14
- if (!response.ok) {
15
- let message = `Request failed with status ${response.status}`
16
- try {
17
- const payload = await response.json()
18
- message = payload.detail || message
19
- } catch {
20
- // Keep default error message when the payload is not JSON.
21
- }
22
- throw new Error(message)
23
- }
24
-
25
- return response.json()
26
- }
27
-
28
- export function createOrAttachSession(jobId, { restart = false } = {}) {
29
- return request('/terminal/sessions', {
30
- method: 'POST',
31
- body: JSON.stringify({ job_id: jobId, restart }),
32
- })
33
- }
34
-
35
- export function stopTerminalSession(sessionId) {
36
- return request(`/terminal/sessions/${sessionId}/stop`, {
37
- method: 'POST',
38
- })
39
- }
40
-
41
- export function sendTerminalInput(sessionId, data, appendNewline = true) {
42
- return request(`/terminal/sessions/${sessionId}/input`, {
43
- method: 'POST',
44
- body: JSON.stringify({ data, append_newline: appendNewline }),
45
- })
46
- }
47
-
48
- export function resizeTerminalSession(sessionId, cols, rows) {
49
- return request(`/terminal/sessions/${sessionId}/resize`, {
50
- method: 'POST',
51
- body: JSON.stringify({ cols, rows }),
52
- })
53
- }
54
-
55
- export function openTerminalSocket(sessionId) {
56
- return new WebSocket(`${WS_BASE}/terminal/sessions/${sessionId}/stream`)
57
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/assets/react.svg DELETED
frontend/src/components/TerminalPane.jsx DELETED
@@ -1,147 +0,0 @@
1
- import { forwardRef, useEffect, useImperativeHandle, useMemo, useRef } from 'react'
2
- import { useTerminalSession } from '../hooks/useTerminalSession'
3
-
4
- function formatTime(timestamp) {
5
- if (!timestamp) {
6
- return 'Idle'
7
- }
8
- return new Date(timestamp * 1000).toLocaleTimeString([], {
9
- hour: '2-digit',
10
- minute: '2-digit',
11
- second: '2-digit',
12
- })
13
- }
14
-
15
- function statusLabel(status) {
16
- if (status === 'running') {
17
- return 'Running'
18
- }
19
- if (status === 'failed') {
20
- return 'Failed'
21
- }
22
- if (status === 'exited') {
23
- return 'Completed'
24
- }
25
- return 'Starting'
26
- }
27
-
28
- const TerminalPane = forwardRef(function TerminalPane({ jobId, title, tone, onTelemetryChange }, ref) {
29
- const { session, buffer, connectionState, error, lastOutputAt, restart, resize, sendInput, start, stop } =
30
- useTerminalSession(jobId)
31
- const viewportRef = useRef(null)
32
- const scrollRef = useRef(null)
33
-
34
- useImperativeHandle(
35
- ref,
36
- () => ({
37
- submit: async (value) => {
38
- await sendInput(value, true)
39
- },
40
- }),
41
- [sendInput],
42
- )
43
-
44
- useEffect(() => {
45
- onTelemetryChange?.({
46
- jobId,
47
- session,
48
- connectionState,
49
- error,
50
- lastOutputAt,
51
- })
52
- }, [connectionState, error, jobId, lastOutputAt, onTelemetryChange, session])
53
-
54
- useEffect(() => {
55
- const container = scrollRef.current
56
- if (container) {
57
- container.scrollTop = container.scrollHeight
58
- }
59
- }, [buffer])
60
-
61
- useEffect(() => {
62
- const element = viewportRef.current
63
- if (!element) {
64
- return undefined
65
- }
66
-
67
- let frameId = 0
68
- const measure = () => {
69
- cancelAnimationFrame(frameId)
70
- frameId = requestAnimationFrame(() => {
71
- const style = getComputedStyle(element)
72
- const fontSize = parseFloat(style.fontSize) || 15
73
- const lineHeight = parseFloat(style.lineHeight) || 24
74
- const cols = Math.max(48, Math.floor(element.clientWidth / (fontSize * 0.61)))
75
- const rows = Math.max(14, Math.floor(element.clientHeight / lineHeight))
76
- resize(cols, rows)
77
- })
78
- }
79
-
80
- measure()
81
- const observer = new ResizeObserver(measure)
82
- observer.observe(element)
83
-
84
- return () => {
85
- cancelAnimationFrame(frameId)
86
- observer.disconnect()
87
- }
88
- }, [resize])
89
-
90
- const footerMeta = useMemo(
91
- () => [
92
- session?.status ? statusLabel(session.status) : 'Connecting',
93
- session?.started_at ? `Started ${formatTime(session.started_at)}` : null,
94
- session?.exit_code != null ? `Exit ${session.exit_code}` : null,
95
- connectionState === 'connected' ? 'WS live' : connectionState,
96
- ].filter(Boolean),
97
- [connectionState, session],
98
- )
99
-
100
- return (
101
- <article className={`terminal-pane terminal-pane--${tone}`}>
102
- <header className="terminal-pane__header">
103
- <div className="terminal-pane__heading">
104
- <div className="terminal-pane__title-row">
105
- <span className="terminal-pane__dot" />
106
- <h2>{title}</h2>
107
- <span className={`status-chip status-chip--${session?.status || 'starting'}`}>
108
- {statusLabel(session?.status)}
109
- </span>
110
- </div>
111
- <p>{session?.command || 'Waiting for backend session...'}</p>
112
- <small>{session?.cwd || 'No working directory available yet.'}</small>
113
- </div>
114
-
115
- <div className="terminal-pane__actions">
116
- <button type="button" onClick={start}>
117
- Attach
118
- </button>
119
- <button type="button" onClick={restart}>
120
- Restart
121
- </button>
122
- <button type="button" onClick={stop}>
123
- Stop
124
- </button>
125
- </div>
126
- </header>
127
-
128
- <div ref={viewportRef} className="terminal-pane__viewport">
129
- <div ref={scrollRef} className="terminal-pane__scroll">
130
- <pre className="terminal-pane__buffer">{buffer || 'Starting session...\n'}</pre>
131
- {session?.status === 'running' ? <span className="terminal-pane__cursor" aria-hidden="true" /> : null}
132
- </div>
133
- </div>
134
-
135
- <footer className="terminal-pane__footer">
136
- <div className="terminal-pane__meta">
137
- {footerMeta.map((item) => (
138
- <span key={item}>{item}</span>
139
- ))}
140
- {error ? <span className="terminal-pane__error">{error}</span> : null}
141
- </div>
142
- </footer>
143
- </article>
144
- )
145
- })
146
-
147
- export default TerminalPane
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/hooks/useTerminalSession.js DELETED
@@ -1,166 +0,0 @@
1
- import { useCallback, useEffect, useRef, useState } from 'react'
2
- import {
3
- createOrAttachSession,
4
- openTerminalSocket,
5
- resizeTerminalSession,
6
- sendTerminalInput,
7
- stopTerminalSession,
8
- } from '../api/terminal'
9
-
10
- const BUFFER_LIMIT = 160000
11
-
12
- function trimBuffer(text) {
13
- return text.length > BUFFER_LIMIT ? text.slice(-BUFFER_LIMIT) : text
14
- }
15
-
16
- export function useTerminalSession(jobId) {
17
- const [session, setSession] = useState(null)
18
- const [buffer, setBuffer] = useState('')
19
- const [connectionState, setConnectionState] = useState('connecting')
20
- const [error, setError] = useState('')
21
- const [lastOutputAt, setLastOutputAt] = useState(null)
22
-
23
- const socketRef = useRef(null)
24
- const resizeRef = useRef({ cols: null, rows: null })
25
-
26
- const attachSocket = useCallback((sessionId) => {
27
- if (socketRef.current) {
28
- socketRef.current.close()
29
- }
30
-
31
- const socket = openTerminalSocket(sessionId)
32
- socketRef.current = socket
33
- setConnectionState('connecting')
34
-
35
- socket.addEventListener('open', () => {
36
- setConnectionState('connected')
37
- })
38
-
39
- socket.addEventListener('message', (event) => {
40
- const payload = JSON.parse(event.data)
41
-
42
- if (payload.type === 'snapshot') {
43
- setSession(payload.session)
44
- setBuffer(payload.buffer || '')
45
- return
46
- }
47
-
48
- if (payload.type === 'output') {
49
- setLastOutputAt(Date.now())
50
- setBuffer((previous) => trimBuffer(previous + payload.data))
51
- return
52
- }
53
-
54
- if (payload.type === 'exit') {
55
- setSession((previous) =>
56
- previous
57
- ? {
58
- ...previous,
59
- status: payload.status,
60
- exit_code: payload.exit_code,
61
- finished_at: payload.finished_at,
62
- }
63
- : previous,
64
- )
65
- }
66
- })
67
-
68
- socket.addEventListener('close', () => {
69
- setConnectionState('disconnected')
70
- })
71
-
72
- socket.addEventListener('error', () => {
73
- setConnectionState('error')
74
- })
75
- }, [])
76
-
77
- const bootSession = useCallback(
78
- async (restart = false) => {
79
- try {
80
- setError('')
81
- const payload = await createOrAttachSession(jobId, { restart })
82
- setSession(payload.session)
83
- setBuffer(payload.buffer || '')
84
- attachSocket(payload.session.id)
85
- } catch (caughtError) {
86
- setError(caughtError.message)
87
- setConnectionState('error')
88
- }
89
- },
90
- [attachSocket, jobId],
91
- )
92
-
93
- useEffect(() => {
94
- const timeoutId = window.setTimeout(() => {
95
- void bootSession(false)
96
- }, 0)
97
-
98
- return () => {
99
- window.clearTimeout(timeoutId)
100
- if (socketRef.current) {
101
- socketRef.current.close()
102
- }
103
- }
104
- }, [bootSession])
105
-
106
- const restart = useCallback(() => bootSession(true), [bootSession])
107
-
108
- const stop = useCallback(async () => {
109
- if (!session?.id) {
110
- return
111
- }
112
- try {
113
- await stopTerminalSession(session.id)
114
- } catch (caughtError) {
115
- setError(caughtError.message)
116
- }
117
- }, [session])
118
-
119
- const sendInput = useCallback(
120
- async (value, appendNewline = true) => {
121
- if (!session?.id || !value.trim()) {
122
- return
123
- }
124
- try {
125
- await sendTerminalInput(session.id, value, appendNewline)
126
- } catch (caughtError) {
127
- setError(caughtError.message)
128
- }
129
- },
130
- [session],
131
- )
132
-
133
- const resize = useCallback(
134
- async (cols, rows) => {
135
- if (!session?.id) {
136
- return
137
- }
138
-
139
- const previous = resizeRef.current
140
- if (previous.cols === cols && previous.rows === rows) {
141
- return
142
- }
143
- resizeRef.current = { cols, rows }
144
-
145
- try {
146
- await resizeTerminalSession(session.id, cols, rows)
147
- } catch {
148
- // Ignore resize errors so rendering stays responsive.
149
- }
150
- },
151
- [session],
152
- )
153
-
154
- return {
155
- buffer,
156
- connectionState,
157
- error,
158
- lastOutputAt,
159
- restart,
160
- resize,
161
- sendInput,
162
- session,
163
- start: () => bootSession(false),
164
- stop,
165
- }
166
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/index.css DELETED
@@ -1,486 +0,0 @@
1
- :root {
2
- color: #ebeff5;
3
- background:
4
- radial-gradient(circle at top, rgba(88, 108, 146, 0.16), transparent 24%),
5
- linear-gradient(180deg, #0d1015 0%, #090b0f 100%);
6
- font-synthesis: none;
7
- text-rendering: optimizeLegibility;
8
- -webkit-font-smoothing: antialiased;
9
- -moz-osx-font-smoothing: grayscale;
10
- --page: #090b0f;
11
- --pane: rgba(14, 17, 22, 0.98);
12
- --pane-border: rgba(255, 255, 255, 0.07);
13
- --pane-soft: rgba(255, 255, 255, 0.03);
14
- --text: #ebeff5;
15
- --muted: #8f97a6;
16
- --cyan: #8fd1ff;
17
- --green: #93e09f;
18
- --red: #ff7f8c;
19
- --amber: #e6c171;
20
- }
21
-
22
- * {
23
- box-sizing: border-box;
24
- }
25
-
26
- html,
27
- body,
28
- #root {
29
- min-height: 100%;
30
- }
31
-
32
- body {
33
- margin: 0;
34
- min-width: 320px;
35
- background: var(--page);
36
- color: var(--text);
37
- font-family: "IBM Plex Sans", "SF Pro Display", "Segoe UI", sans-serif;
38
- }
39
-
40
- button,
41
- input {
42
- font: inherit;
43
- }
44
-
45
- button {
46
- cursor: pointer;
47
- }
48
-
49
- .desktop {
50
- position: relative;
51
- min-height: 100vh;
52
- overflow: hidden;
53
- }
54
-
55
- .desktop__glow {
56
- position: absolute;
57
- inset: 0;
58
- background:
59
- radial-gradient(circle at 30% 0%, rgba(109, 132, 180, 0.14), transparent 20%),
60
- radial-gradient(circle at 100% 0%, rgba(70, 110, 168, 0.12), transparent 16%);
61
- pointer-events: none;
62
- }
63
-
64
- .workspace {
65
- position: relative;
66
- z-index: 1;
67
- display: flex;
68
- width: 100vw;
69
- height: calc(100vh - 194px);
70
- background: rgba(7, 10, 13, 0.98);
71
- }
72
-
73
- .comparison-bar {
74
- position: relative;
75
- z-index: 1;
76
- display: flex;
77
- align-items: stretch;
78
- justify-content: space-between;
79
- gap: 18px;
80
- padding: 14px 18px;
81
- border-bottom: 1px solid var(--pane-border);
82
- background:
83
- linear-gradient(180deg, rgba(16, 19, 25, 0.98), rgba(11, 13, 18, 0.98));
84
- }
85
-
86
- .comparison-bar__copy {
87
- display: grid;
88
- gap: 5px;
89
- min-width: 0;
90
- }
91
-
92
- .comparison-bar__copy strong,
93
- .comparison-bar__copy small {
94
- overflow: hidden;
95
- text-overflow: ellipsis;
96
- white-space: nowrap;
97
- }
98
-
99
- .comparison-bar__eyebrow {
100
- color: var(--muted);
101
- font:
102
- 500 0.74rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
103
- text-transform: uppercase;
104
- letter-spacing: 0.14em;
105
- }
106
-
107
- .comparison-bar__copy strong {
108
- font-size: 1rem;
109
- font-weight: 600;
110
- }
111
-
112
- .comparison-bar__copy small {
113
- color: var(--muted);
114
- }
115
-
116
- .comparison-bar__stats {
117
- display: flex;
118
- gap: 12px;
119
- }
120
-
121
- .comparison-card {
122
- display: grid;
123
- gap: 4px;
124
- min-width: 170px;
125
- padding: 12px 14px;
126
- border: 1px solid var(--pane-border);
127
- border-radius: 14px;
128
- background: rgba(255, 255, 255, 0.03);
129
- }
130
-
131
- .comparison-card span,
132
- .comparison-card small {
133
- color: var(--muted);
134
- }
135
-
136
- .comparison-card strong {
137
- font:
138
- 600 1.2rem/1.1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
139
- }
140
-
141
- .comparison-card--cyan strong {
142
- color: var(--cyan);
143
- }
144
-
145
- .comparison-card--green strong {
146
- color: var(--green);
147
- }
148
-
149
- .workspace__pane {
150
- min-width: 0;
151
- height: 100%;
152
- }
153
-
154
- .workspace__divider {
155
- position: relative;
156
- flex: 0 0 12px;
157
- border: 0;
158
- padding: 0;
159
- background: linear-gradient(180deg, rgba(255, 255, 255, 0.04), rgba(255, 255, 255, 0.015));
160
- cursor: col-resize;
161
- }
162
-
163
- .workspace__divider span {
164
- position: absolute;
165
- top: 50%;
166
- left: 50%;
167
- width: 3px;
168
- height: 72px;
169
- border-radius: 999px;
170
- background: rgba(255, 255, 255, 0.22);
171
- transform: translate(-50%, -50%);
172
- }
173
-
174
- .workspace__divider:hover span,
175
- .workspace__divider.is-dragging span {
176
- background: rgba(143, 209, 255, 0.8);
177
- }
178
-
179
- .workspace__divider:focus-visible {
180
- outline: none;
181
- box-shadow: inset 0 0 0 1px rgba(143, 209, 255, 0.5);
182
- }
183
-
184
- .terminal-pane {
185
- display: grid;
186
- grid-template-rows: auto minmax(0, 1fr) auto;
187
- width: 100%;
188
- height: 100%;
189
- background:
190
- linear-gradient(180deg, rgba(255, 255, 255, 0.015), transparent 10%),
191
- var(--pane);
192
- }
193
-
194
- .terminal-pane + .terminal-pane {
195
- border-left: 1px solid var(--pane-border);
196
- }
197
-
198
- .terminal-pane__header {
199
- display: flex;
200
- align-items: flex-start;
201
- justify-content: space-between;
202
- gap: 18px;
203
- padding: 18px 18px 14px;
204
- border-bottom: 1px solid var(--pane-border);
205
- background: rgba(255, 255, 255, 0.02);
206
- }
207
-
208
- .terminal-pane__heading {
209
- min-width: 0;
210
- }
211
-
212
- .terminal-pane__title-row {
213
- display: flex;
214
- align-items: center;
215
- gap: 10px;
216
- margin-bottom: 6px;
217
- }
218
-
219
- .terminal-pane__title-row h2 {
220
- margin: 0;
221
- font-size: 1rem;
222
- font-weight: 600;
223
- letter-spacing: 0.01em;
224
- }
225
-
226
- .terminal-pane__heading p,
227
- .terminal-pane__heading small {
228
- display: block;
229
- margin: 0;
230
- overflow: hidden;
231
- text-overflow: ellipsis;
232
- white-space: nowrap;
233
- }
234
-
235
- .terminal-pane__heading p {
236
- color: var(--text);
237
- font:
238
- 500 0.84rem/1.4 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
239
- }
240
-
241
- .terminal-pane__heading small {
242
- margin-top: 4px;
243
- color: var(--muted);
244
- }
245
-
246
- .terminal-pane__dot {
247
- width: 9px;
248
- height: 9px;
249
- border-radius: 999px;
250
- background: var(--cyan);
251
- box-shadow: 0 0 24px rgba(143, 209, 255, 0.35);
252
- }
253
-
254
- .terminal-pane--green .terminal-pane__dot {
255
- background: var(--green);
256
- box-shadow: 0 0 24px rgba(147, 224, 159, 0.35);
257
- }
258
-
259
- .terminal-pane__actions {
260
- display: flex;
261
- gap: 8px;
262
- }
263
-
264
- .terminal-pane__actions button {
265
- padding: 8px 10px;
266
- border: 1px solid var(--pane-border);
267
- border-radius: 10px;
268
- background: rgba(255, 255, 255, 0.03);
269
- color: var(--text);
270
- }
271
-
272
- .terminal-pane__actions button:hover {
273
- background: rgba(255, 255, 255, 0.06);
274
- }
275
-
276
- .status-chip {
277
- padding: 4px 8px;
278
- border-radius: 999px;
279
- font-size: 0.7rem;
280
- text-transform: uppercase;
281
- letter-spacing: 0.12em;
282
- border: 1px solid transparent;
283
- }
284
-
285
- .status-chip--running {
286
- color: var(--green);
287
- border-color: rgba(147, 224, 159, 0.24);
288
- background: rgba(147, 224, 159, 0.08);
289
- }
290
-
291
- .status-chip--failed {
292
- color: var(--red);
293
- border-color: rgba(255, 127, 140, 0.25);
294
- background: rgba(255, 127, 140, 0.08);
295
- }
296
-
297
- .status-chip--exited {
298
- color: var(--amber);
299
- border-color: rgba(230, 193, 113, 0.24);
300
- background: rgba(230, 193, 113, 0.08);
301
- }
302
-
303
- .status-chip--starting {
304
- color: var(--cyan);
305
- border-color: rgba(143, 209, 255, 0.24);
306
- background: rgba(143, 209, 255, 0.08);
307
- }
308
-
309
- .terminal-pane__viewport {
310
- min-height: 0;
311
- overflow: hidden;
312
- font:
313
- 500 0.95rem/1.72 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
314
- }
315
-
316
- .terminal-pane__scroll {
317
- height: 100%;
318
- overflow: auto;
319
- padding: 18px 18px 20px;
320
- }
321
-
322
- .terminal-pane__buffer {
323
- margin: 0;
324
- white-space: pre-wrap;
325
- word-break: break-word;
326
- color: #edf2fa;
327
- }
328
-
329
- .terminal-pane__cursor {
330
- display: inline-block;
331
- width: 0.62em;
332
- height: 1.05em;
333
- margin-top: 4px;
334
- background: rgba(237, 242, 250, 0.9);
335
- animation: blink 1s steps(1) infinite;
336
- }
337
-
338
- .terminal-pane__footer {
339
- display: grid;
340
- padding: 14px 18px 16px;
341
- border-top: 1px solid var(--pane-border);
342
- background: rgba(255, 255, 255, 0.02);
343
- }
344
-
345
- .terminal-pane__meta {
346
- display: flex;
347
- flex-wrap: wrap;
348
- gap: 10px 14px;
349
- color: var(--muted);
350
- font:
351
- 500 0.76rem/1.3 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
352
- }
353
-
354
- .terminal-pane__error {
355
- color: var(--red);
356
- }
357
-
358
- .broadcast-bar {
359
- position: relative;
360
- z-index: 1;
361
- display: grid;
362
- gap: 8px;
363
- padding: 14px 18px 18px;
364
- border-top: 1px solid var(--pane-border);
365
- background:
366
- linear-gradient(180deg, rgba(18, 21, 27, 0.98), rgba(10, 12, 16, 0.98));
367
- }
368
-
369
- .broadcast-bar__label {
370
- color: var(--muted);
371
- font:
372
- 500 0.78rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
373
- text-transform: uppercase;
374
- letter-spacing: 0.12em;
375
- }
376
-
377
- .broadcast-bar__field {
378
- display: flex;
379
- align-items: center;
380
- gap: 12px;
381
- padding: 14px 16px;
382
- border: 1px solid var(--pane-border);
383
- border-radius: 14px;
384
- background: rgba(255, 255, 255, 0.03);
385
- box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.02);
386
- }
387
-
388
- .broadcast-bar__prompt {
389
- color: var(--cyan);
390
- font:
391
- 600 0.92rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
392
- }
393
-
394
- .broadcast-bar__field input {
395
- flex: 1;
396
- min-width: 0;
397
- border: 0;
398
- background: transparent;
399
- color: var(--text);
400
- outline: none;
401
- font:
402
- 500 0.95rem/1.2 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
403
- }
404
-
405
- .broadcast-bar__field input::placeholder {
406
- color: var(--muted);
407
- }
408
-
409
- .broadcast-bar__field button {
410
- padding: 10px 14px;
411
- border: 1px solid rgba(143, 209, 255, 0.18);
412
- border-radius: 10px;
413
- background: rgba(143, 209, 255, 0.08);
414
- color: var(--text);
415
- }
416
-
417
- .broadcast-bar__field button:hover {
418
- background: rgba(143, 209, 255, 0.14);
419
- }
420
-
421
- @keyframes blink {
422
- 50% {
423
- opacity: 0;
424
- }
425
- }
426
-
427
- @media (max-width: 980px) {
428
- .comparison-bar {
429
- flex-direction: column;
430
- }
431
-
432
- .comparison-bar__stats {
433
- width: 100%;
434
- }
435
-
436
- .comparison-card {
437
- flex: 1;
438
- min-width: 0;
439
- }
440
-
441
- .terminal-pane__header {
442
- flex-direction: column;
443
- }
444
-
445
- .terminal-pane__actions {
446
- width: 100%;
447
- }
448
-
449
- .terminal-pane__actions button {
450
- flex: 1;
451
- }
452
- }
453
-
454
- @media (max-width: 780px) {
455
- .workspace {
456
- height: calc(100vh - 244px);
457
- }
458
-
459
- .workspace__divider {
460
- flex-basis: 10px;
461
- }
462
-
463
- .terminal-pane__scroll {
464
- padding: 16px;
465
- }
466
-
467
- .terminal-pane__viewport {
468
- font-size: 0.88rem;
469
- }
470
-
471
- .terminal-pane__footer {
472
- padding: 12px 14px 14px;
473
- }
474
-
475
- .broadcast-bar {
476
- padding: 12px 14px 14px;
477
- }
478
-
479
- .broadcast-bar__field {
480
- padding: 12px 14px;
481
- }
482
-
483
- .comparison-bar {
484
- padding: 12px 14px;
485
- }
486
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/main.jsx DELETED
@@ -1,10 +0,0 @@
1
- import { StrictMode } from 'react'
2
- import { createRoot } from 'react-dom/client'
3
- import './index.css'
4
- import App from './App.jsx'
5
-
6
- createRoot(document.getElementById('root')).render(
7
- <StrictMode>
8
- <App />
9
- </StrictMode>,
10
- )
 
 
 
 
 
 
 
 
 
 
 
frontend/vite.config.js DELETED
@@ -1,7 +0,0 @@
1
- import { defineConfig } from 'vite'
2
- import react from '@vitejs/plugin-react'
3
-
4
- // https://vite.dev/config/
5
- export default defineConfig({
6
- plugins: [react()],
7
- })
 
 
 
 
 
 
 
 
index.html ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>RL Surrogate Kernel Autotuning</title>
7
+ <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap">
8
+ <style>
9
+ :root {
10
+ --primary: #2563eb;
11
+ --primary-dark: #1d4ed8;
12
+ --secondary: #10b981;
13
+ --accent: #f59e0b;
14
+ --bg: #0f172a;
15
+ --bg-card: #1e293b;
16
+ --text: #f1f5f9;
17
+ --text-muted: #94a3b8;
18
+ --border: #334155;
19
+ }
20
+ * {
21
+ margin: 0;
22
+ padding: 0;
23
+ box-sizing: border-box;
24
+ }
25
+ body {
26
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
27
+ background: var(--bg);
28
+ color: var(--text);
29
+ line-height: 1.6;
30
+ min-height: 100vh;
31
+ }
32
+ .container {
33
+ max-width: 1200px;
34
+ margin: 0 auto;
35
+ padding: 2rem;
36
+ }
37
+ header {
38
+ text-align: center;
39
+ padding: 4rem 0;
40
+ position: relative;
41
+ }
42
+ header::before {
43
+ content: '';
44
+ position: absolute;
45
+ top: 0;
46
+ left: 50%;
47
+ transform: translateX(-50%);
48
+ width: 600px;
49
+ height: 600px;
50
+ background: radial-gradient(circle, rgba(37, 99, 235, 0.15) 0%, transparent 70%);
51
+ pointer-events: none;
52
+ }
53
+ h1 {
54
+ font-size: 3rem;
55
+ font-weight: 700;
56
+ margin-bottom: 1rem;
57
+ background: linear-gradient(135deg, #60a5fa 0%, #a78bfa 100%);
58
+ -webkit-background-clip: text;
59
+ -webkit-text-fill-color: transparent;
60
+ background-clip: text;
61
+ }
62
+ .subtitle {
63
+ font-size: 1.25rem;
64
+ color: var(--text-muted);
65
+ max-width: 600px;
66
+ margin: 0 auto;
67
+ }
68
+ .badge {
69
+ display: inline-block;
70
+ padding: 0.5rem 1rem;
71
+ background: var(--primary);
72
+ border-radius: 9999px;
73
+ font-size: 0.875rem;
74
+ font-weight: 500;
75
+ margin-top: 1.5rem;
76
+ }
77
+ .grid {
78
+ display: grid;
79
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
80
+ gap: 1.5rem;
81
+ margin: 3rem 0;
82
+ }
83
+ .card {
84
+ background: var(--bg-card);
85
+ border: 1px solid var(--border);
86
+ border-radius: 1rem;
87
+ padding: 1.5rem;
88
+ transition: transform 0.2s, box-shadow 0.2s;
89
+ }
90
+ .card:hover {
91
+ transform: translateY(-2px);
92
+ box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
93
+ }
94
+ .card-icon {
95
+ width: 48px;
96
+ height: 48px;
97
+ border-radius: 0.75rem;
98
+ display: flex;
99
+ align-items: center;
100
+ justify-content: center;
101
+ font-size: 1.5rem;
102
+ margin-bottom: 1rem;
103
+ }
104
+ .card-icon.blue { background: rgba(37, 99, 235, 0.2); }
105
+ .card-icon.green { background: rgba(16, 185, 129, 0.2); }
106
+ .card-icon.amber { background: rgba(245, 158, 11, 0.2); }
107
+ .card-icon.purple { background: rgba(139, 92, 246, 0.2); }
108
+ .card h3 {
109
+ font-size: 1.125rem;
110
+ font-weight: 600;
111
+ margin-bottom: 0.5rem;
112
+ }
113
+ .card p {
114
+ color: var(--text-muted);
115
+ font-size: 0.875rem;
116
+ }
117
+ .features {
118
+ margin: 3rem 0;
119
+ }
120
+ .features h2 {
121
+ font-size: 1.75rem;
122
+ margin-bottom: 1.5rem;
123
+ text-align: center;
124
+ }
125
+ .feature-list {
126
+ display: grid;
127
+ grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
128
+ gap: 1rem;
129
+ }
130
+ .feature-item {
131
+ display: flex;
132
+ align-items: center;
133
+ gap: 0.75rem;
134
+ padding: 1rem;
135
+ background: var(--bg-card);
136
+ border-radius: 0.75rem;
137
+ border: 1px solid var(--border);
138
+ }
139
+ .feature-item span {
140
+ color: var(--secondary);
141
+ font-size: 1.25rem;
142
+ }
143
+ .code-block {
144
+ background: #0d1117;
145
+ border: 1px solid var(--border);
146
+ border-radius: 0.75rem;
147
+ padding: 1.5rem;
148
+ margin: 2rem 0;
149
+ overflow-x: auto;
150
+ }
151
+ .code-block pre {
152
+ font-family: 'JetBrains Mono', monospace;
153
+ font-size: 0.875rem;
154
+ color: #e2e8f0;
155
+ white-space: pre;
156
+ }
157
+ .code-block .comment { color: #6b7280; }
158
+ .code-block .keyword { color: #c084fc; }
159
+ .code-block .string { color: #34d399; }
160
+ .code-block .number { color: #f472b6; }
161
+ .cta {
162
+ text-align: center;
163
+ padding: 3rem;
164
+ background: var(--bg-card);
165
+ border-radius: 1rem;
166
+ border: 1px solid var(--border);
167
+ margin-top: 3rem;
168
+ }
169
+ .cta h2 {
170
+ font-size: 1.5rem;
171
+ margin-bottom: 1rem;
172
+ }
173
+ .cta p {
174
+ color: var(--text-muted);
175
+ margin-bottom: 1.5rem;
176
+ }
177
+ .btn {
178
+ display: inline-block;
179
+ padding: 0.75rem 1.5rem;
180
+ background: var(--primary);
181
+ color: white;
182
+ text-decoration: none;
183
+ border-radius: 0.5rem;
184
+ font-weight: 500;
185
+ transition: background 0.2s;
186
+ }
187
+ .btn:hover {
188
+ background: var(--primary-dark);
189
+ }
190
+ footer {
191
+ text-align: center;
192
+ padding: 2rem;
193
+ color: var(--text-muted);
194
+ font-size: 0.875rem;
195
+ border-top: 1px solid var(--border);
196
+ margin-top: 4rem;
197
+ }
198
+ .kernel-families {
199
+ display: flex;
200
+ justify-content: center;
201
+ gap: 1rem;
202
+ margin-top: 2rem;
203
+ flex-wrap: wrap;
204
+ }
205
+ .kernel-tag {
206
+ padding: 0.5rem 1rem;
207
+ background: rgba(139, 92, 246, 0.2);
208
+ border: 1px solid rgba(139, 92, 246, 0.3);
209
+ border-radius: 0.5rem;
210
+ font-family: 'JetBrains Mono', monospace;
211
+ font-size: 0.875rem;
212
+ }
213
+ </style>
214
+ </head>
215
+ <body>
216
+ <div class="container">
217
+ <header>
218
+ <h1>RL Surrogate Kernel Autotuning</h1>
219
+ <p class="subtitle">
220
+ A reinforcement learning environment for optimizing GPU kernels using surrogate models and self-improving AI agents
221
+ </p>
222
+ <span class="badge">🚀 HuggingFace Space</span>
223
+ </header>
224
+ <div class="grid">
225
+ <div class="card">
226
+ <div class="card-icon blue">🎯</div>
227
+ <h3>Multi-Family Support</h3>
228
+ <p>Optimizes three kernel families: Softmax, LayerNorm, and Grouped GEMM for NVIDIA GPUs</p>
229
+ </div>
230
+ <div class="card">
231
+ <div class="card-icon green">📊</div>
232
+ <h3>Surrogate Modeling</h3>
233
+ <p>Uses Gaussian Process surrogates to predict kernel performance and guide exploration</p>
234
+ </div>
235
+ <div class="card">
236
+ <div class="card-icon amber">🤖</div>
237
+ <h3>Self-Improving AI</h3>
238
+ <p>Employs student-teacher LLM architecture for autonomous kernel optimization</p>
239
+ </div>
240
+ <div class="card">
241
+ <div class="card-icon purple">⚡</div>
242
+ <h3>Triton Backend</h3>
243
+ <p>Generates optimized Triton kernels with configurable block sizes, warps, and stages</p>
244
+ </div>
245
+ </div>
246
+ <div class="features">
247
+ <h2>Supported Kernel Families</h2>
248
+ <div class="kernel-families">
249
+ <span class="kernel-tag">softmax</span>
250
+ <span class="kernel-tag">layernorm</span>
251
+ <span class="kernel-tag">grouped_gemm</span>
252
+ </div>
253
+ </div>
254
+ <div class="features">
255
+ <h2>How It Works</h2>
256
+ <div class="feature-list">
257
+ <div class="feature-item">
258
+ <span>1.</span>
259
+ <div><strong>Reset</strong> - Initialize environment with a kernel task (e.g., softmax_m4096_n256)</div>
260
+ </div>
261
+ <div class="feature-item">
262
+ <span>2.</span>
263
+ <div><strong>Configure</strong> - Select Triton config (block_size, num_warps, num_stages)</div>
264
+ </div>
265
+ <div class="feature-item">
266
+ <span>3.</span>
267
+ <div><strong>Step</strong> - Benchmark configuration and receive reward based on latency</div>
268
+ </div>
269
+ <div class="feature-item">
270
+ <span>4.</span>
271
+ <div><strong>Optimize</strong> - Use LLM to generate better kernel variants automatically</div>
272
+ </div>
273
+ <div class="feature-item">
274
+ <span>5.</span>
275
+ <div><strong>Learn</strong> - Surrogate model improves with more observations</div>
276
+ </div>
277
+ <div class="feature-item">
278
+ <span>6.</span>
279
+ <div><strong>Compare</strong> - Student vs Teacher LLM competition for best optimization</div>
280
+ </div>
281
+ </div>
282
+ </div>
283
+ <div class="code-block">
284
+ <pre><span class="comment"># Example: Using the RL Environment</span>
285
+ <span class="comment"># 1. Reset with a task</span>
286
+ POST /reset
287
+ {
288
+ <span class="string">"task"</span>: <span class="string">"softmax_m4096_n256"</span>,
289
+ <span class="string">"seed"</span>: <span class="number">0</span>
290
+ }
291
+ <span class="comment"># 2. Take a step with a config</span>
292
+ POST /step
293
+ {
294
+ <span class="string">"config_id"</span>: <span class="number">0</span>,
295
+ <span class="string">"source"</span>: <span class="string">"def benchmark_generated_kernel..."</span> <span class="comment"># optional kernel code</span>
296
+ }
297
+ <span class="comment"># Response includes reward based on latency improvement</span>
298
+ {
299
+ <span class="string">"reward"</span>: <span class="number">0.093</span>,
300
+ <span class="string">"observation"</span>: {
301
+ <span class="string">"best_so_far_ms"</span>: <span class="number">0.41</span>,
302
+ <span class="string">"oracle_best_ms"</span>: <span class="number">0.38</span>
303
+ }
304
+ }</pre>
305
+ </div>
306
+ <div class="cta">
307
+ <h2>Try It Now</h2>
308
+ <p>Use the interactive Gradio interface above to benchmark kernel configurations and explore the RL environment.</p>
309
+ <a href="/" class="btn">Launch Interface →</a>
310
+ </div>
311
+ <footer>
312
+ <p>RL Surrogate Kernel Autotuning Environment</p>
313
+ <p>Built with FastAPI, Gradio, Triton, and Transformers</p>
314
+ </footer>
315
+ </div>
316
+ </body>
317
+ </html>
models.py DELETED
@@ -1,27 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Sequence
5
-
6
-
7
- @dataclass
8
- class Action:
9
- config_id: int
10
-
11
-
12
- @dataclass
13
- class StepResult:
14
- observation: Dict[str, Any]
15
- reward: float
16
- done: bool
17
- state: Dict[str, Any]
18
- info: Dict[str, Any]
19
-
20
-
21
- @dataclass
22
- class ResetResult:
23
- observation: Dict[str, Any]
24
- reward: float
25
- done: bool
26
- state: Dict[str, Any]
27
- info: Dict[str, Any]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
openenv.yaml DELETED
@@ -1,15 +0,0 @@
1
- name: softmax_surrogate_env
2
- version: 0.1.0
3
- description: >
4
- OpenEnv scaffold for discrete-action multi-family GPU autotuning with real
5
- H100 measurements and held-out-shape / held-out-family evaluation.
6
- python:
7
- package: .
8
- server:
9
- entrypoint: server.app:app
10
- environment:
11
- class: server.softmax_surrogate_environment.SoftmaxSurrogateEnvironment
12
- init_kwargs:
13
- measurement_path: data/autotune_measurements.csv
14
- budget: 6
15
- seed: 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- torch>=2.4
2
- triton>=2.1
3
- numpy>=1.23
4
- fastapi>=0.111
5
- uvicorn>=0.30
6
- pydantic>=2.8
7
- requests>=2.32
8
-
 
 
 
 
 
 
 
 
 
scripts/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Package marker for direct imports from the scripts folder.
 
 
scripts/benchmark_qwen_05b_runtime.py DELETED
@@ -1,158 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- import sys
7
- import time
8
- from pathlib import Path
9
- from typing import Any, Dict, Tuple
10
-
11
- import numpy as np
12
- import torch
13
-
14
- ROOT = Path(__file__).resolve().parents[1]
15
- if str(ROOT) not in sys.path:
16
- sys.path.append(str(ROOT))
17
-
18
- from scripts.collect_qwen_05b_measurements import EPS, benchmark_qwen_task
19
- from scripts.qwen_05b_spec import QwenKernelTask, qwen_05b_tasks
20
-
21
-
22
- TASK_BY_ID = {task.task_id: task for task in qwen_05b_tasks()}
23
-
24
-
25
- def _bench_callable(fn, args: Tuple[Any, ...], repeats: int, warmup: int) -> float:
26
- for _ in range(max(1, warmup)):
27
- fn(*args)
28
- torch.cuda.synchronize()
29
- start = torch.cuda.Event(enable_timing=True)
30
- end = torch.cuda.Event(enable_timing=True)
31
- durations = []
32
- for _ in range(max(1, repeats)):
33
- torch.cuda.synchronize()
34
- start.record()
35
- fn(*args)
36
- end.record()
37
- end.synchronize()
38
- durations.append(start.elapsed_time(end))
39
- return float(np.median(np.asarray(durations, dtype=np.float32)))
40
-
41
-
42
- def _build_qwen_callable(task: QwenKernelTask, seed: int):
43
- torch.manual_seed(seed)
44
- if task.family == "softmax":
45
- x = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
46
-
47
- def fn(inp: torch.Tensor):
48
- return torch.softmax(inp, dim=-1)
49
-
50
- return fn, (x,)
51
-
52
- if task.family == "rmsnorm":
53
- x = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
54
-
55
- def fn(inp: torch.Tensor):
56
- return inp.float() * torch.rsqrt(inp.float().pow(2).mean(dim=-1, keepdim=True) + EPS)
57
-
58
- return fn, (x,)
59
-
60
- if task.family == "gemm":
61
- a = torch.randn((task.m, task.k), device="cuda", dtype=torch.float16)
62
- b = torch.randn((task.k, task.n), device="cuda", dtype=torch.float16)
63
-
64
- def fn(lhs: torch.Tensor, rhs: torch.Tensor):
65
- return torch.matmul(lhs, rhs)
66
-
67
- return fn, (a, b)
68
-
69
- raise ValueError(f"Unsupported family: {task.family}")
70
-
71
-
72
- def _benchmark_torch(task: QwenKernelTask, seed: int, repeats: int, warmup: int) -> Dict[str, float]:
73
- eager_fn, args = _build_qwen_callable(task, seed)
74
- eager_latency_ms = _bench_callable(eager_fn, args, repeats=repeats, warmup=warmup)
75
-
76
- compiled_fn = torch.compile(eager_fn)
77
- torch.cuda.synchronize()
78
- start = time.perf_counter()
79
- compiled_fn(*args)
80
- torch.cuda.synchronize()
81
- compile_plus_first_call_ms = float((time.perf_counter() - start) * 1000.0)
82
- compiled_latency_ms = _bench_callable(compiled_fn, args, repeats=repeats, warmup=warmup)
83
- return {
84
- "eager_latency_ms": eager_latency_ms,
85
- "compile_plus_first_call_ms": compile_plus_first_call_ms,
86
- "compiled_latency_ms": compiled_latency_ms,
87
- }
88
-
89
-
90
- def _task_best_configs(eval_results: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, Any]]]:
91
- task_map: Dict[str, Dict[str, Dict[str, Any]]] = {}
92
- for section in eval_results["results"].values():
93
- for method in ("random", "surrogate"):
94
- for run in section["task_runs"][method]:
95
- task_map.setdefault(run["task"], {})[method] = run["best_overall"]["config"]
96
- return task_map
97
-
98
-
99
- def main() -> None:
100
- parser = argparse.ArgumentParser(description="Benchmark eager/torch.compile and best Triton configs for Qwen2.5-0.5B exact kernels.")
101
- parser.add_argument("--generalization-results", type=Path, default=Path("outputs/qwen_05b_generalization_eval.json"))
102
- parser.add_argument("--repeats", type=int, default=100)
103
- parser.add_argument("--warmup", type=int, default=10)
104
- parser.add_argument("--seed", type=int, default=123)
105
- parser.add_argument("--output", type=Path, default=Path("outputs/qwen_05b_runtime_references.json"))
106
- args = parser.parse_args()
107
-
108
- generalization_results = json.loads(args.generalization_results.read_text(encoding="utf-8"))
109
- task_configs = _task_best_configs(generalization_results)
110
-
111
- results = {}
112
- for idx, task_id in enumerate(sorted(task_configs.keys())):
113
- task = TASK_BY_ID[task_id]
114
- seed = args.seed + idx
115
- torch_metrics = _benchmark_torch(task, seed=seed, repeats=args.repeats, warmup=args.warmup)
116
- triton_results = {
117
- method: benchmark_qwen_task(
118
- task=task,
119
- block_size=int(config["block_size"]),
120
- num_warps=int(config["num_warps"]),
121
- num_stages=int(config["num_stages"]),
122
- repeats=args.repeats,
123
- warmup=args.warmup,
124
- seed=seed,
125
- ).__dict__
126
- for method, config in task_configs[task_id].items()
127
- }
128
- results[task_id] = {
129
- "family": task.family,
130
- "role": task.role,
131
- "mode": task.mode,
132
- "torch": torch_metrics,
133
- "triton": triton_results,
134
- "speedups": {
135
- method: {
136
- "vs_eager": float(torch_metrics["eager_latency_ms"] / row["median_ms"]),
137
- "vs_compiled": float(torch_metrics["compiled_latency_ms"] / row["median_ms"]),
138
- }
139
- for method, row in triton_results.items()
140
- },
141
- }
142
-
143
- summary = {
144
- "generalization_results": str(args.generalization_results),
145
- "repeats": args.repeats,
146
- "warmup": args.warmup,
147
- "seed": args.seed,
148
- "task_count": len(results),
149
- "results": results,
150
- }
151
- args.output.parent.mkdir(parents=True, exist_ok=True)
152
- with args.output.open("w", encoding="utf-8") as handle:
153
- json.dump(summary, handle, indent=2)
154
- print(json.dumps(summary, indent=2))
155
-
156
-
157
- if __name__ == "__main__":
158
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/benchmark_runtime_references.py DELETED
@@ -1,222 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- import re
7
- import sys
8
- import time
9
- from pathlib import Path
10
- from typing import Any, Dict, Iterable, List, Tuple
11
-
12
- import numpy as np
13
- import torch
14
- import torch.nn.functional as F
15
-
16
- ROOT = Path(__file__).resolve().parents[1]
17
- if str(ROOT) not in sys.path:
18
- sys.path.append(str(ROOT))
19
-
20
- from scripts.collect_measurements import benchmark_single_config as benchmark_softmax_config
21
- from scripts.collect_multifamily_measurements import (
22
- EPS,
23
- GROUPED_GEMM_GROUP_COUNT,
24
- GROUPED_GEMM_K,
25
- benchmark_grouped_gemm_config,
26
- benchmark_layernorm_config,
27
- )
28
-
29
-
30
- TASK_RE = re.compile(
31
- r"^(?P<family>[a-z_]+?)(?:_g(?P<g>\d+)_k(?P<k>\d+))?_m(?P<m>\d+)_n(?P<n>\d+)$"
32
- )
33
-
34
-
35
- def _parse_task(task_id: str) -> Dict[str, int | str | None]:
36
- match = TASK_RE.match(task_id)
37
- if not match:
38
- raise ValueError(f"Cannot parse task id: {task_id}")
39
- data = match.groupdict()
40
- return {
41
- "family": data["family"],
42
- "m": int(data["m"]),
43
- "n": int(data["n"]),
44
- "g": int(data["g"]) if data["g"] is not None else None,
45
- "k": int(data["k"]) if data["k"] is not None else None,
46
- }
47
-
48
-
49
- def _bench_callable(fn, args: Tuple[Any, ...], repeats: int, warmup: int) -> float:
50
- for _ in range(max(1, warmup)):
51
- fn(*args)
52
- torch.cuda.synchronize()
53
-
54
- start = torch.cuda.Event(enable_timing=True)
55
- end = torch.cuda.Event(enable_timing=True)
56
- latencies_ms: List[float] = []
57
- for _ in range(max(1, repeats)):
58
- torch.cuda.synchronize()
59
- start.record()
60
- fn(*args)
61
- end.record()
62
- end.synchronize()
63
- latencies_ms.append(start.elapsed_time(end))
64
- return float(np.median(np.asarray(latencies_ms, dtype=np.float32)))
65
-
66
-
67
- def _build_family_callable(task_meta: Dict[str, Any], seed: int) -> Tuple[Any, Tuple[Any, ...]]:
68
- family = str(task_meta["family"])
69
- m = int(task_meta["m"])
70
- n = int(task_meta["n"])
71
- torch.manual_seed(seed)
72
-
73
- if family == "softmax":
74
- x = torch.randn((m, n), device="cuda", dtype=torch.float16)
75
-
76
- def fn(inp: torch.Tensor) -> torch.Tensor:
77
- return torch.softmax(inp, dim=-1)
78
-
79
- return fn, (x,)
80
-
81
- if family == "layernorm":
82
- x = torch.randn((m, n), device="cuda", dtype=torch.float16)
83
-
84
- def fn(inp: torch.Tensor) -> torch.Tensor:
85
- return F.layer_norm(inp, (inp.shape[-1],), eps=EPS)
86
-
87
- return fn, (x,)
88
-
89
- if family == "grouped_gemm":
90
- group_count = int(task_meta.get("g") or GROUPED_GEMM_GROUP_COUNT)
91
- k_dim = int(task_meta.get("k") or GROUPED_GEMM_K)
92
- group_m = max(64, m // group_count)
93
- a_groups = [torch.randn((group_m, k_dim), device="cuda", dtype=torch.float16) for _ in range(group_count)]
94
- b_groups = [torch.randn((k_dim, n), device="cuda", dtype=torch.float16) for _ in range(group_count)]
95
-
96
- def fn(*inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
97
- mid = len(inputs) // 2
98
- a_list = inputs[:mid]
99
- b_list = inputs[mid:]
100
- return tuple(torch.matmul(a, b) for a, b in zip(a_list, b_list))
101
-
102
- return fn, tuple(a_groups + b_groups)
103
-
104
- raise ValueError(f"Unsupported family: {family}")
105
-
106
-
107
- def _benchmark_torch_compile(task_meta: Dict[str, Any], seed: int, repeats: int, warmup: int) -> Dict[str, float]:
108
- eager_fn, args = _build_family_callable(task_meta, seed=seed)
109
- eager_latency_ms = _bench_callable(eager_fn, args, repeats=repeats, warmup=warmup)
110
-
111
- compiled_fn = torch.compile(eager_fn)
112
- torch.cuda.synchronize()
113
- start = time.perf_counter()
114
- compiled_fn(*args)
115
- torch.cuda.synchronize()
116
- compile_plus_first_call_ms = float((time.perf_counter() - start) * 1000.0)
117
- compiled_latency_ms = _bench_callable(compiled_fn, args, repeats=repeats, warmup=warmup)
118
- return {
119
- "eager_latency_ms": eager_latency_ms,
120
- "compile_plus_first_call_ms": compile_plus_first_call_ms,
121
- "compiled_latency_ms": compiled_latency_ms,
122
- }
123
-
124
-
125
- def _benchmark_triton_best(task_id: str, config: Dict[str, Any], repeats: int, warmup: int, seed: int) -> Dict[str, Any]:
126
- task = _parse_task(task_id)
127
- family = str(task["family"])
128
- common = {
129
- "n": int(task["n"]),
130
- "block_size": int(config["block_size"]),
131
- "num_warps": int(config["num_warps"]),
132
- "num_stages": int(config["num_stages"]),
133
- "m": int(task["m"]),
134
- "repeats": repeats,
135
- "warmup": warmup,
136
- "seed": seed,
137
- }
138
- if family == "softmax":
139
- row = benchmark_softmax_config(**common)
140
- return row.__dict__
141
- if family == "layernorm":
142
- row = benchmark_layernorm_config(**common)
143
- return row.__dict__
144
- if family == "grouped_gemm":
145
- row = benchmark_grouped_gemm_config(**common)
146
- return row.__dict__
147
- raise ValueError(f"Unsupported family: {family}")
148
-
149
-
150
- def _collect_task_best_configs(generalization_results: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, Any]]]:
151
- task_map: Dict[str, Dict[str, Dict[str, Any]]] = {}
152
- for section in generalization_results["results"].values():
153
- for method in ("random", "surrogate"):
154
- for run in section["task_runs"][method]:
155
- task_map.setdefault(run["task"], {})[method] = run["best_overall"]["config"]
156
- return task_map
157
-
158
-
159
- def main() -> None:
160
- parser = argparse.ArgumentParser(description="Benchmark eager/torch.compile and live Triton configs on held-out tasks.")
161
- parser.add_argument(
162
- "--generalization-results",
163
- type=Path,
164
- default=Path("outputs/generalization_eval.json"),
165
- )
166
- parser.add_argument("--repeats", type=int, default=100)
167
- parser.add_argument("--warmup", type=int, default=10)
168
- parser.add_argument("--seed", type=int, default=123)
169
- parser.add_argument(
170
- "--output",
171
- type=Path,
172
- default=Path("outputs/runtime_references.json"),
173
- )
174
- args = parser.parse_args()
175
-
176
- generalization_results = json.loads(args.generalization_results.read_text(encoding="utf-8"))
177
- task_best_configs = _collect_task_best_configs(generalization_results)
178
-
179
- results: Dict[str, Any] = {}
180
- for idx, task_id in enumerate(sorted(task_best_configs.keys())):
181
- task_seed = args.seed + idx
182
- task_meta = _parse_task(task_id)
183
- torch_metrics = _benchmark_torch_compile(task_meta, seed=task_seed, repeats=args.repeats, warmup=args.warmup)
184
- method_results = {
185
- method: _benchmark_triton_best(
186
- task_id=task_id,
187
- config=config,
188
- repeats=args.repeats,
189
- warmup=args.warmup,
190
- seed=task_seed,
191
- )
192
- for method, config in task_best_configs[task_id].items()
193
- }
194
- results[task_id] = {
195
- "family": task_meta["family"],
196
- "torch": torch_metrics,
197
- "triton": method_results,
198
- "speedups": {
199
- method: {
200
- "vs_eager": float(torch_metrics["eager_latency_ms"] / row["median_ms"]),
201
- "vs_compiled": float(torch_metrics["compiled_latency_ms"] / row["median_ms"]),
202
- }
203
- for method, row in method_results.items()
204
- },
205
- }
206
-
207
- summary = {
208
- "generalization_results": str(args.generalization_results),
209
- "repeats": args.repeats,
210
- "warmup": args.warmup,
211
- "seed": args.seed,
212
- "task_count": len(results),
213
- "results": results,
214
- }
215
- args.output.parent.mkdir(parents=True, exist_ok=True)
216
- with args.output.open("w", encoding="utf-8") as handle:
217
- json.dump(summary, handle, indent=2)
218
- print(json.dumps(summary, indent=2))
219
-
220
-
221
- if __name__ == "__main__":
222
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/benchmark_surrogate_vs_triton_autotune.py DELETED
@@ -1,361 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- import sys
7
- import time
8
- from pathlib import Path
9
- from typing import Any, Dict, List
10
-
11
- import numpy as np
12
- import torch
13
-
14
- ROOT = Path(__file__).resolve().parents[1]
15
- if str(ROOT) not in sys.path:
16
- sys.path.append(str(ROOT))
17
-
18
- import triton
19
- import triton.language as tl
20
- import triton.testing as ttesting
21
-
22
- from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS, fused_rowwise_softmax_kernel
23
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
24
-
25
-
26
- DEFAULT_TASKS = [
27
- "softmax_m4096_n4096",
28
- "softmax_m4096_n6144",
29
- "softmax_m4096_n8192",
30
- ]
31
-
32
-
33
- def _choose_surrogate_config_id(env: SoftmaxSurrogateEnvironment, acquisition: str, beta: float, xi: float) -> int:
34
- seen = set(env.seen_config_ids())
35
- best_config_id = -1
36
- best_score = float("-inf")
37
- for config_id in env.available_config_ids():
38
- if config_id in seen and len(seen) < len(env.available_config_ids()):
39
- continue
40
- score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
41
- if score > best_score:
42
- best_score = score
43
- best_config_id = config_id
44
- if best_config_id < 0:
45
- raise RuntimeError("Failed to select surrogate config.")
46
- return best_config_id
47
-
48
-
49
- def _valid_autotune_configs(n_cols: int) -> List[triton.Config]:
50
- configs = []
51
- for block_size in BLOCK_SIZES:
52
- if block_size < n_cols:
53
- continue
54
- for num_warps in NUM_WARPS:
55
- for num_stages in NUM_STAGES:
56
- configs.append(
57
- triton.Config(
58
- {"BLOCK_SIZE": block_size},
59
- num_warps=num_warps,
60
- num_stages=num_stages,
61
- )
62
- )
63
- return configs
64
-
65
-
66
- def _compile_plus_first_call_fixed(sample: torch.Tensor, config: Dict[str, int]) -> float:
67
- output = torch.empty_like(sample)
68
- grid = (sample.shape[0],)
69
- torch.cuda.synchronize()
70
- start = time.perf_counter()
71
- fused_rowwise_softmax_kernel[grid](
72
- sample,
73
- output,
74
- sample.stride(0),
75
- sample.stride(1),
76
- output.stride(0),
77
- output.stride(1),
78
- sample.shape[1],
79
- BLOCK_SIZE=int(config["block_size"]),
80
- num_warps=int(config["num_warps"]),
81
- num_stages=int(config["num_stages"]),
82
- )
83
- torch.cuda.synchronize()
84
- return float((time.perf_counter() - start) * 1000.0)
85
-
86
-
87
- def _steady_state_fixed(sample: torch.Tensor, config: Dict[str, int], repeats: int, warmup: int) -> float:
88
- output = torch.empty_like(sample)
89
- grid = (sample.shape[0],)
90
-
91
- def launch() -> None:
92
- fused_rowwise_softmax_kernel[grid](
93
- sample,
94
- output,
95
- sample.stride(0),
96
- sample.stride(1),
97
- output.stride(0),
98
- output.stride(1),
99
- sample.shape[1],
100
- BLOCK_SIZE=int(config["block_size"]),
101
- num_warps=int(config["num_warps"]),
102
- num_stages=int(config["num_stages"]),
103
- )
104
-
105
- return float(
106
- ttesting.do_bench(
107
- launch,
108
- warmup=warmup,
109
- rep=repeats,
110
- quantiles=[0.5],
111
- return_mode="median",
112
- )
113
- )
114
-
115
-
116
- def _benchmark_triton_autotune(sample: torch.Tensor, repeats: int, warmup: int) -> Dict[str, float]:
117
- output = torch.empty_like(sample)
118
- n_cols = sample.shape[1]
119
- configs = _valid_autotune_configs(n_cols)
120
-
121
- @triton.autotune(configs=configs, key=["n_cols"])
122
- @triton.jit
123
- def autotuned_softmax_kernel(
124
- X_ptr,
125
- Y_ptr,
126
- stride_xm,
127
- stride_xn,
128
- stride_ym,
129
- stride_yn,
130
- n_cols,
131
- BLOCK_SIZE: tl.constexpr,
132
- ):
133
- row_idx = tl.program_id(0)
134
- col_offsets = tl.arange(0, BLOCK_SIZE)
135
- x_ptr = X_ptr + row_idx * stride_xm + col_offsets
136
- y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
137
- mask = col_offsets < n_cols
138
-
139
- x = tl.load(x_ptr, mask=mask, other=-float("inf"))
140
- x = x - tl.max(x, axis=0)
141
- numerator = tl.exp(x)
142
- denominator = tl.sum(numerator, axis=0)
143
- y = numerator / denominator
144
- tl.store(y_ptr, y, mask=mask)
145
-
146
- grid = (sample.shape[0],)
147
- torch.cuda.synchronize()
148
- start = time.perf_counter()
149
- autotuned_softmax_kernel[grid](
150
- sample,
151
- output,
152
- sample.stride(0),
153
- sample.stride(1),
154
- output.stride(0),
155
- output.stride(1),
156
- n_cols,
157
- )
158
- torch.cuda.synchronize()
159
- first_call_ms = float((time.perf_counter() - start) * 1000.0)
160
-
161
- def launch() -> None:
162
- autotuned_softmax_kernel[grid](
163
- sample,
164
- output,
165
- sample.stride(0),
166
- sample.stride(1),
167
- output.stride(0),
168
- output.stride(1),
169
- n_cols,
170
- )
171
-
172
- steady_ms = float(
173
- ttesting.do_bench(
174
- launch,
175
- warmup=warmup,
176
- rep=repeats,
177
- quantiles=[0.5],
178
- return_mode="median",
179
- )
180
- )
181
- return {
182
- "autotune_first_call_ms": first_call_ms,
183
- "autotune_steady_ms": steady_ms,
184
- }
185
-
186
-
187
- def _build_summary(
188
- measurement_path: str,
189
- tasks: List[str],
190
- acquisition: str,
191
- beta: float,
192
- xi: float,
193
- results: Dict[str, Any],
194
- ) -> Dict[str, Any]:
195
- if results:
196
- surrogate_first = [task["surrogate"]["compile_plus_first_call_ms"] for task in results.values()]
197
- surrogate_steady = [task["surrogate"]["steady_ms"] for task in results.values()]
198
- autotune_first = [task["triton_autotune"]["autotune_first_call_ms"] for task in results.values()]
199
- autotune_steady = [task["triton_autotune"]["autotune_steady_ms"] for task in results.values()]
200
- surrogate_speedup = [
201
- task["triton_autotune"]["autotune_first_call_ms"] / max(task["surrogate"]["compile_plus_first_call_ms"], 1e-9)
202
- for task in results.values()
203
- ]
204
- summary = {
205
- "mean_surrogate_compile_plus_first_call_ms": float(np.mean(surrogate_first)),
206
- "mean_surrogate_steady_ms": float(np.mean(surrogate_steady)),
207
- "mean_autotune_first_call_ms": float(np.mean(autotune_first)),
208
- "mean_autotune_steady_ms": float(np.mean(autotune_steady)),
209
- "mean_search_time_speedup_surrogate_vs_autotune": float(np.mean(surrogate_speedup)),
210
- "completed_task_count": len(results),
211
- }
212
- else:
213
- summary = {
214
- "mean_surrogate_compile_plus_first_call_ms": None,
215
- "mean_surrogate_steady_ms": None,
216
- "mean_autotune_first_call_ms": None,
217
- "mean_autotune_steady_ms": None,
218
- "mean_search_time_speedup_surrogate_vs_autotune": None,
219
- "completed_task_count": 0,
220
- }
221
- return {
222
- "measurement_path": measurement_path,
223
- "tasks": tasks,
224
- "acquisition": acquisition,
225
- "beta": beta,
226
- "xi": xi,
227
- "results": results,
228
- "summary": summary,
229
- }
230
-
231
-
232
- def run_benchmark(
233
- measurement_path: str,
234
- tasks: List[str],
235
- repeats: int,
236
- warmup: int,
237
- seed: int,
238
- acquisition: str,
239
- beta: float,
240
- xi: float,
241
- output_path: Path,
242
- resume: bool,
243
- ) -> Dict[str, Any]:
244
- env_probe = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=1, seed=seed)
245
- available_tasks = [task for task in env_probe.available_tasks() if task.startswith("softmax_m4096_n")]
246
- train_task_ids = [task for task in available_tasks if task not in tasks]
247
-
248
- results: Dict[str, Any] = {}
249
- if resume and output_path.exists():
250
- try:
251
- existing = json.loads(output_path.read_text(encoding="utf-8"))
252
- results = dict(existing.get("results", {}))
253
- except Exception:
254
- results = {}
255
-
256
- for index, task in enumerate(tasks):
257
- if task in results:
258
- print(f"[progress] skipping completed task {index + 1}/{len(tasks)}: {task}")
259
- continue
260
- task_seed = seed + index
261
- print(f"[progress] starting task {index + 1}/{len(tasks)}: {task}")
262
- env = SoftmaxSurrogateEnvironment(
263
- measurement_path=measurement_path,
264
- budget=6,
265
- seed=task_seed,
266
- train_task_ids=train_task_ids,
267
- )
268
- reset_out = env.reset(task=task, seed=task_seed)
269
- decision_start = time.perf_counter()
270
- surrogate_config_id = _choose_surrogate_config_id(env, acquisition=acquisition, beta=beta, xi=xi)
271
- decision_ms = float((time.perf_counter() - decision_start) * 1000.0)
272
- surrogate_config = env.config_info(surrogate_config_id)
273
- sample = torch.randn((env._task_rows[0].m, env._task_rows[0].n), device="cuda", dtype=torch.float16)
274
-
275
- surrogate_first_call_ms = _compile_plus_first_call_fixed(sample, surrogate_config)
276
- surrogate_steady_ms = _steady_state_fixed(sample, surrogate_config, repeats=repeats, warmup=warmup)
277
- autotune_metrics = _benchmark_triton_autotune(sample, repeats=repeats, warmup=warmup)
278
- oracle_best = env.oracle_best()
279
-
280
- results[task] = {
281
- "seeded_config_ids": reset_out["observation"]["tried_config_ids"],
282
- "train_task_count": len(train_task_ids),
283
- "oracle_best_ms": oracle_best["median_ms"],
284
- "surrogate": {
285
- "config": surrogate_config,
286
- "decision_ms": decision_ms,
287
- "compile_plus_first_call_ms": surrogate_first_call_ms,
288
- "steady_ms": surrogate_steady_ms,
289
- "regret_vs_oracle": float(surrogate_steady_ms / oracle_best["median_ms"] - 1.0),
290
- },
291
- "triton_autotune": {
292
- **autotune_metrics,
293
- "regret_vs_oracle": float(autotune_metrics["autotune_steady_ms"] / oracle_best["median_ms"] - 1.0),
294
- },
295
- }
296
- print(
297
- "[progress] finished"
298
- f" task={task}"
299
- f" surrogate_first_ms={results[task]['surrogate']['compile_plus_first_call_ms']:.3f}"
300
- f" surrogate_steady_ms={results[task]['surrogate']['steady_ms']:.3f}"
301
- f" autotune_first_ms={results[task]['triton_autotune']['autotune_first_call_ms']:.3f}"
302
- f" autotune_steady_ms={results[task]['triton_autotune']['autotune_steady_ms']:.3f}"
303
- )
304
- snapshot = _build_summary(
305
- measurement_path=measurement_path,
306
- tasks=tasks,
307
- acquisition=acquisition,
308
- beta=beta,
309
- xi=xi,
310
- results=results,
311
- )
312
- output_path.parent.mkdir(parents=True, exist_ok=True)
313
- output_path.write_text(json.dumps(snapshot, indent=2), encoding="utf-8")
314
-
315
- return _build_summary(
316
- measurement_path=measurement_path,
317
- tasks=tasks,
318
- acquisition=acquisition,
319
- beta=beta,
320
- xi=xi,
321
- results=results,
322
- )
323
-
324
-
325
- def parse_args() -> argparse.Namespace:
326
- parser = argparse.ArgumentParser(description="Quick benchmark: surrogate-guided softmax config choice vs Triton autotune.")
327
- parser.add_argument("--measurement-path", default="data/autotune_measurements.csv")
328
- parser.add_argument("--tasks", nargs="+", default=DEFAULT_TASKS)
329
- parser.add_argument("--repeats", type=int, default=50)
330
- parser.add_argument("--warmup", type=int, default=10)
331
- parser.add_argument("--seed", type=int, default=7)
332
- parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
333
- parser.add_argument("--beta", type=float, default=2.0)
334
- parser.add_argument("--xi", type=float, default=0.0)
335
- parser.add_argument("--output", type=Path, default=Path("outputs/surrogate_vs_triton_autotune.json"))
336
- parser.add_argument("--no-resume", action="store_true")
337
- return parser.parse_args()
338
-
339
-
340
- def main() -> None:
341
- args = parse_args()
342
- summary = run_benchmark(
343
- measurement_path=args.measurement_path,
344
- tasks=args.tasks,
345
- repeats=args.repeats,
346
- warmup=args.warmup,
347
- seed=args.seed,
348
- acquisition=args.acquisition,
349
- beta=args.beta,
350
- xi=args.xi,
351
- output_path=args.output,
352
- resume=not args.no_resume,
353
- )
354
- args.output.parent.mkdir(parents=True, exist_ok=True)
355
- with args.output.open("w", encoding="utf-8") as handle:
356
- json.dump(summary, handle, indent=2)
357
- print(json.dumps(summary, indent=2))
358
-
359
-
360
- if __name__ == "__main__":
361
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/build_benchmark_splits.py DELETED
@@ -1,81 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import csv
6
- import json
7
- from collections import defaultdict
8
- from pathlib import Path
9
- from typing import Dict, List
10
-
11
-
12
- def _load_tasks(path: Path) -> Dict[str, List[dict]]:
13
- grouped: Dict[str, List[dict]] = defaultdict(list)
14
- with path.open("r", newline="", encoding="utf-8") as handle:
15
- for row in csv.DictReader(handle):
16
- grouped[row["task_id"]].append(row)
17
- return grouped
18
-
19
-
20
- def _task_n(task_id: str) -> int:
21
- return int(task_id.split("_n")[-1])
22
-
23
-
24
- def build_splits(grouped: Dict[str, List[dict]], heldout_family: str | None) -> Dict[str, object]:
25
- tasks_by_family: Dict[str, List[str]] = defaultdict(list)
26
- for task_id, rows in grouped.items():
27
- tasks_by_family[rows[0]["family"]].append(task_id)
28
-
29
- for family in tasks_by_family:
30
- tasks_by_family[family].sort(key=_task_n)
31
-
32
- families = sorted(tasks_by_family.keys())
33
- if not families:
34
- raise RuntimeError("No tasks found in measurement file.")
35
-
36
- shape_train: List[str] = []
37
- shape_test: List[str] = []
38
- for family, tasks in tasks_by_family.items():
39
- holdout_count = 2 if len(tasks) >= 4 else 1
40
- split_idx = max(1, len(tasks) - holdout_count)
41
- shape_train.extend(tasks[:split_idx])
42
- shape_test.extend(tasks[split_idx:])
43
-
44
- if heldout_family is None:
45
- heldout_family = families[-1]
46
- if heldout_family not in tasks_by_family:
47
- raise ValueError(f"Held-out family {heldout_family} is not present.")
48
-
49
- family_train = [task_id for family, tasks in tasks_by_family.items() if family != heldout_family for task_id in tasks]
50
- family_test = list(tasks_by_family[heldout_family])
51
-
52
- return {
53
- "families_present": families,
54
- "shape_generalization": {
55
- "train_tasks": sorted(shape_train),
56
- "test_tasks": sorted(shape_test),
57
- },
58
- "family_holdout": {
59
- "heldout_family": heldout_family,
60
- "train_tasks": sorted(family_train),
61
- "test_tasks": sorted(family_test),
62
- },
63
- }
64
-
65
-
66
- def main() -> None:
67
- parser = argparse.ArgumentParser(description="Build train/test split manifests for the multi-family benchmark.")
68
- parser.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
69
- parser.add_argument("--output", type=Path, default=Path("data/benchmark_splits.json"))
70
- parser.add_argument("--heldout-family", type=str, default=None)
71
- args = parser.parse_args()
72
-
73
- splits = build_splits(_load_tasks(args.measurement_path), args.heldout_family)
74
- args.output.parent.mkdir(parents=True, exist_ok=True)
75
- with args.output.open("w", encoding="utf-8") as handle:
76
- json.dump(splits, handle, indent=2)
77
- print(json.dumps(splits, indent=2))
78
-
79
-
80
- if __name__ == "__main__":
81
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/build_qwen_05b_splits.py DELETED
@@ -1,56 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- from pathlib import Path
7
-
8
- from scripts.qwen_05b_spec import DECODE_CTX_LENS, PREFILL_SEQ_LENS, qwen_05b_tasks
9
-
10
-
11
- def build_splits() -> dict:
12
- tasks = qwen_05b_tasks()
13
- long_prefill = max(PREFILL_SEQ_LENS)
14
- long_decode = max(DECODE_CTX_LENS)
15
-
16
- shape_train = []
17
- shape_test = []
18
- for task in tasks:
19
- if task.mode == "prefill" and task.seq_len == long_prefill:
20
- shape_test.append(task.task_id)
21
- elif task.mode == "decode" and task.ctx_len == long_decode:
22
- shape_test.append(task.task_id)
23
- else:
24
- shape_train.append(task.task_id)
25
-
26
- family_holdout_train = [task.task_id for task in tasks if task.family != "gemm"]
27
- family_holdout_test = [task.task_id for task in tasks if task.family == "gemm"]
28
-
29
- return {
30
- "model_id": "Qwen/Qwen2.5-0.5B",
31
- "shape_generalization": {
32
- "train_tasks": sorted(shape_train),
33
- "test_tasks": sorted(shape_test),
34
- },
35
- "family_holdout": {
36
- "heldout_family": "gemm",
37
- "train_tasks": sorted(family_holdout_train),
38
- "test_tasks": sorted(family_holdout_test),
39
- },
40
- }
41
-
42
-
43
- def main() -> None:
44
- parser = argparse.ArgumentParser(description="Build train/test split manifests for Qwen2.5-0.5B kernel tuning.")
45
- parser.add_argument("--output", type=Path, default=Path("data/qwen_05b_splits.json"))
46
- args = parser.parse_args()
47
-
48
- splits = build_splits()
49
- args.output.parent.mkdir(parents=True, exist_ok=True)
50
- with args.output.open("w", encoding="utf-8") as handle:
51
- json.dump(splits, handle, indent=2)
52
- print(json.dumps(splits, indent=2))
53
-
54
-
55
- if __name__ == "__main__":
56
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/check_torch_cuda.py DELETED
@@ -1,18 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import torch
5
-
6
-
7
- def main() -> None:
8
- print(f"python: {__import__('sys').executable}")
9
- print(f"torch: {torch.__version__}")
10
- print(f"cuda_available: {torch.cuda.is_available()}")
11
- if torch.cuda.is_available():
12
- print(f"cuda_device_name: {torch.cuda.get_device_name(0)}")
13
- print(f"cuda_capability: {torch.cuda.get_device_capability(0)}")
14
-
15
-
16
- if __name__ == "__main__":
17
- main()
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/collect_measurements.py DELETED
@@ -1,398 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import csv
6
- import math
7
- import statistics
8
- from dataclasses import dataclass
9
- from pathlib import Path
10
- from typing import Iterable, List, Dict, Optional
11
-
12
- import numpy as np
13
- import torch
14
-
15
- try:
16
- import triton
17
- import triton.language as tl
18
- import triton.testing as ttesting
19
- except Exception as err: # pragma: no cover
20
- raise RuntimeError(
21
- "Triton is required for this measurement collector. Install Triton and rerun."
22
- ) from err
23
-
24
-
25
- BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192)
26
- NUM_WARPS = (1, 2, 4, 8)
27
- NUM_STAGES = (1, 2, 3, 4)
28
- N_VALUES = (256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192)
29
-
30
-
31
- @dataclass(frozen=True)
32
- class MeasurementRow:
33
- task_id: str
34
- m: int
35
- n: int
36
- block_size: int
37
- num_warps: int
38
- num_stages: int
39
- median_ms: float
40
- effective_gbps: float
41
- score: float
42
- validation_error: float
43
-
44
-
45
- @triton.jit
46
- def fused_rowwise_softmax_kernel(
47
- X_ptr,
48
- Y_ptr,
49
- stride_xm,
50
- stride_xn,
51
- stride_ym,
52
- stride_yn,
53
- n_cols,
54
- BLOCK_SIZE: tl.constexpr,
55
- ):
56
- row_idx = tl.program_id(0)
57
- col_offsets = tl.arange(0, BLOCK_SIZE)
58
- x_ptr = X_ptr + row_idx * stride_xm + col_offsets
59
- y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
60
- mask = col_offsets < n_cols
61
-
62
- x = tl.load(x_ptr, mask=mask, other=-float("inf"))
63
- x = x - tl.max(x, axis=0)
64
- numerator = tl.exp(x)
65
- denominator = tl.sum(numerator, axis=0)
66
- y = numerator / denominator
67
- tl.store(y_ptr, y, mask=mask)
68
-
69
-
70
- def _task_id(m: int, n: int) -> str:
71
- return f"softmax_m{m}_n{n}"
72
-
73
-
74
- def _effective_gbps(sample: torch.Tensor, median_ms: float) -> float:
75
- bytes_processed = float(sample.numel() * sample.element_size() * 2)
76
- if median_ms <= 0:
77
- return 0.0
78
- return bytes_processed / (median_ms / 1000.0) / 1e9
79
-
80
-
81
- def _benchmark_config(
82
- sample: torch.Tensor,
83
- block_size: int,
84
- num_warps: int,
85
- num_stages: int,
86
- repeats: int,
87
- warmup: int,
88
- ) -> float:
89
- output = torch.empty_like(sample)
90
- m, n = sample.shape
91
- grid = (m,)
92
-
93
- def launch() -> None:
94
- fused_rowwise_softmax_kernel[grid](
95
- sample,
96
- output,
97
- sample.stride(0),
98
- sample.stride(1),
99
- output.stride(0),
100
- output.stride(1),
101
- n,
102
- BLOCK_SIZE=block_size,
103
- num_warps=num_warps,
104
- num_stages=num_stages,
105
- )
106
-
107
- # warmup to compile and stabilize caches / clocks.
108
- for _ in range(max(1, warmup)):
109
- launch()
110
- torch.cuda.synchronize()
111
-
112
- if ttesting is not None:
113
- result = ttesting.do_bench(
114
- launch,
115
- warmup=0,
116
- rep=repeats,
117
- quantiles=[0.5],
118
- return_mode="median",
119
- )
120
- if isinstance(result, dict):
121
- return float(result.get("median", 0.0))
122
- return float(result)
123
-
124
- start = torch.cuda.Event(enable_timing=True)
125
- end = torch.cuda.Event(enable_timing=True)
126
- durations_ms: List[float] = []
127
- for _ in range(max(1, repeats)):
128
- torch.cuda.synchronize()
129
- start.record()
130
- launch()
131
- end.record()
132
- end.synchronize()
133
- durations_ms.append(start.elapsed_time(end))
134
- return float(statistics.median(durations_ms))
135
-
136
-
137
- def _validate_correctness(sample: torch.Tensor, block_size: int, num_warps: int, num_stages: int) -> float:
138
- ref = torch.softmax(sample.float(), dim=-1).to(sample.dtype)
139
- out = torch.empty_like(sample)
140
-
141
- fused_rowwise_softmax_kernel[(sample.shape[0],)](
142
- sample,
143
- out,
144
- sample.stride(0),
145
- sample.stride(1),
146
- out.stride(0),
147
- out.stride(1),
148
- sample.shape[1],
149
- BLOCK_SIZE=block_size,
150
- num_warps=num_warps,
151
- num_stages=num_stages,
152
- )
153
- torch.cuda.synchronize()
154
-
155
- err = (out.to(torch.float32) - ref).abs().max().item()
156
- return float(err)
157
-
158
-
159
- def benchmark_single_config(
160
- n: int,
161
- block_size: int,
162
- num_warps: int,
163
- num_stages: int,
164
- m: int = 4096,
165
- repeats: int = 200,
166
- warmup: int = 25,
167
- seed: int = 0,
168
- ) -> MeasurementRow:
169
- if block_size < n:
170
- raise ValueError(f"Invalid config: BLOCK_SIZE {block_size} < N {n}")
171
-
172
- torch.manual_seed(seed)
173
- sample = torch.randn((m, n), device="cuda", dtype=torch.float16)
174
-
175
- if num_warps not in NUM_WARPS:
176
- raise ValueError(f"Unsupported num_warps={num_warps}")
177
- if num_stages not in NUM_STAGES:
178
- raise ValueError(f"Unsupported num_stages={num_stages}")
179
- if block_size not in BLOCK_SIZES:
180
- raise ValueError(f"Unsupported BLOCK_SIZE={block_size}")
181
-
182
- median_ms = _benchmark_config(
183
- sample=sample,
184
- block_size=block_size,
185
- num_warps=num_warps,
186
- num_stages=num_stages,
187
- repeats=repeats,
188
- warmup=warmup,
189
- )
190
-
191
- val_err = _validate_correctness(
192
- sample=sample,
193
- block_size=block_size,
194
- num_warps=num_warps,
195
- num_stages=num_stages,
196
- )
197
-
198
- gbps = _effective_gbps(sample, median_ms)
199
- score = -math.log(max(median_ms, np.finfo(float).tiny))
200
- return MeasurementRow(
201
- task_id=_task_id(m, n),
202
- m=m,
203
- n=n,
204
- block_size=block_size,
205
- num_warps=num_warps,
206
- num_stages=num_stages,
207
- median_ms=float(median_ms),
208
- effective_gbps=float(gbps),
209
- score=float(score),
210
- validation_error=float(val_err),
211
- )
212
-
213
-
214
- def collect_measurements(
215
- output_path: Path,
216
- n_values: Iterable[int],
217
- repeats: int,
218
- warmup: int,
219
- seed: int,
220
- m: int = 4096,
221
- append: bool = True,
222
- ) -> List[MeasurementRow]:
223
- if not torch.cuda.is_available():
224
- raise RuntimeError("CUDA is required to run measurements on H100.")
225
- if not torch.cuda.get_device_name(0):
226
- raise RuntimeError("No CUDA device found.")
227
-
228
- output_path.parent.mkdir(parents=True, exist_ok=True)
229
-
230
- existing: Dict[str, MeasurementRow] = {}
231
- if output_path.exists():
232
- with output_path.open("r", newline="", encoding="utf-8") as f:
233
- for row in csv.DictReader(f):
234
- key = _row_key(
235
- int(row["m"]),
236
- int(row["n"]),
237
- int(row["block_size"]),
238
- int(row["num_warps"]),
239
- int(row["num_stages"]),
240
- )
241
- existing[key] = MeasurementRow(
242
- task_id=row["task_id"],
243
- m=int(row["m"]),
244
- n=int(row["n"]),
245
- block_size=int(row["block_size"]),
246
- num_warps=int(row["num_warps"]),
247
- num_stages=int(row["num_stages"]),
248
- median_ms=float(row["median_ms"]),
249
- effective_gbps=float(row["effective_gbps"]),
250
- score=float(row["score"]),
251
- validation_error=float(row["validation_error"]),
252
- )
253
-
254
- results: List[MeasurementRow] = []
255
- for n in n_values:
256
- if n < 0:
257
- raise ValueError(f"Invalid n value: {n}")
258
- for block in BLOCK_SIZES:
259
- if block < n:
260
- continue
261
- for num_warps in NUM_WARPS:
262
- for num_stages in NUM_STAGES:
263
- key = _row_key(m, n, block, num_warps, num_stages)
264
- if append and key in existing:
265
- continue
266
-
267
- torch.cuda.synchronize()
268
- row = benchmark_single_config(
269
- n=n,
270
- block_size=block,
271
- num_warps=num_warps,
272
- num_stages=num_stages,
273
- m=m,
274
- repeats=repeats,
275
- warmup=warmup,
276
- seed=seed,
277
- )
278
- results.append(row)
279
-
280
- if results:
281
- write_mode = "a" if output_path.exists() and append else "w"
282
- with output_path.open(write_mode, newline="", encoding="utf-8") as f:
283
- writer = csv.writer(f)
284
- if write_mode == "w" or (output_path.stat().st_size == 0):
285
- writer.writerow(
286
- [
287
- "task_id",
288
- "m",
289
- "n",
290
- "block_size",
291
- "num_warps",
292
- "num_stages",
293
- "median_ms",
294
- "effective_gbps",
295
- "score",
296
- "validation_error",
297
- ]
298
- )
299
- for r in results:
300
- writer.writerow(
301
- [
302
- r.task_id,
303
- r.m,
304
- r.n,
305
- r.block_size,
306
- r.num_warps,
307
- r.num_stages,
308
- r.median_ms,
309
- r.effective_gbps,
310
- r.score,
311
- r.validation_error,
312
- ]
313
- )
314
- return results
315
-
316
-
317
- def _row_key(m: int, n: int, block_size: int, num_warps: int, num_stages: int) -> str:
318
- return f"{m}|{n}|{block_size}|{num_warps}|{num_stages}"
319
-
320
-
321
- def parse_args() -> argparse.Namespace:
322
- parser = argparse.ArgumentParser(
323
- description="Collect real H100 measurements for Triton row-wise fused softmax."
324
- )
325
- parser.add_argument(
326
- "--output",
327
- type=Path,
328
- default=Path("data/softmax_measurements.csv"),
329
- help="CSV output path.",
330
- )
331
- parser.add_argument(
332
- "--n-cols",
333
- type=int,
334
- nargs="+",
335
- default=N_VALUES,
336
- help="Softmax inner dimension N values to benchmark.",
337
- )
338
- parser.add_argument("--m", type=int, default=4096, help="Outer dimension M.")
339
- parser.add_argument("--repeats", type=int, default=200, help="Benchmark repeats.")
340
- parser.add_argument("--warmup", type=int, default=25, help="Benchmark warmup runs.")
341
- parser.add_argument("--seed", type=int, default=0, help="Torch/random seed.")
342
- parser.add_argument(
343
- "--single-run",
344
- action="store_true",
345
- help="Run one specific config and print JSON-like output.",
346
- )
347
- parser.add_argument("--block-size", type=int, default=1024)
348
- parser.add_argument("--num-warps", type=int, default=4)
349
- parser.add_argument("--num-stages", type=int, default=2)
350
- parser.add_argument(
351
- "--append",
352
- action="store_true",
353
- help="Append to existing CSV file (default). If false, overwrite.",
354
- )
355
- parser.add_argument(
356
- "--overwrite",
357
- action="store_true",
358
- help="Overwrite existing CSV data instead of appending.",
359
- )
360
- return parser.parse_args()
361
-
362
-
363
- def main() -> None:
364
- args = parse_args()
365
-
366
- if not torch.cuda.is_available():
367
- raise RuntimeError("CUDA is not available.")
368
- if not torch.cuda.get_device_name(0):
369
- raise RuntimeError("No CUDA device was found.")
370
-
371
- append = args.append and not args.overwrite
372
- if args.single_run:
373
- row = benchmark_single_config(
374
- n=args.n_cols[0],
375
- block_size=args.block_size,
376
- num_warps=args.num_warps,
377
- num_stages=args.num_stages,
378
- m=args.m,
379
- repeats=args.repeats,
380
- warmup=args.warmup,
381
- seed=args.seed,
382
- )
383
- print(row.__dict__)
384
- return
385
-
386
- collect_measurements(
387
- output_path=args.output,
388
- n_values=args.n_cols,
389
- repeats=args.repeats,
390
- warmup=args.warmup,
391
- seed=args.seed,
392
- m=args.m,
393
- append=append,
394
- )
395
-
396
-
397
- if __name__ == "__main__":
398
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/collect_multifamily_measurements.py DELETED
@@ -1,659 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import csv
6
- import json
7
- import math
8
- import sys
9
- from dataclasses import dataclass
10
- from pathlib import Path
11
- from typing import Dict, Iterable, List, Sequence, Tuple
12
-
13
- import numpy as np
14
- import torch
15
-
16
- ROOT = Path(__file__).resolve().parents[1]
17
- if str(ROOT) not in sys.path:
18
- sys.path.append(str(ROOT))
19
-
20
- try:
21
- import triton
22
- import triton.language as tl
23
- import triton.testing as ttesting
24
- except Exception as err: # pragma: no cover
25
- raise RuntimeError(
26
- "Triton is required for the multi-family measurement collector."
27
- ) from err
28
-
29
- from scripts.collect_measurements import (
30
- BLOCK_SIZES,
31
- NUM_STAGES,
32
- NUM_WARPS,
33
- N_VALUES,
34
- benchmark_single_config as benchmark_softmax_config,
35
- )
36
-
37
-
38
- EPS = 1e-5
39
- GROUPED_GEMM_GROUP_COUNT = 4
40
- GROUPED_GEMM_K = 512
41
-
42
-
43
- @dataclass(frozen=True)
44
- class MultiFamilyMeasurementRow:
45
- family_group: str
46
- family: str
47
- task_id: str
48
- m: int
49
- n: int
50
- config_id: int
51
- block_size: int
52
- num_warps: int
53
- num_stages: int
54
- shape_json: str
55
- config_json: str
56
- median_ms: float
57
- effective_gbps: float
58
- score: float
59
- validation_error: float
60
-
61
-
62
- @triton.jit
63
- def fused_rowwise_layernorm_kernel(
64
- X_ptr,
65
- Y_ptr,
66
- stride_xm,
67
- stride_xn,
68
- stride_ym,
69
- stride_yn,
70
- n_cols,
71
- eps,
72
- BLOCK_SIZE: tl.constexpr,
73
- ):
74
- row_idx = tl.program_id(0)
75
- col_offsets = tl.arange(0, BLOCK_SIZE)
76
- x_ptr = X_ptr + row_idx * stride_xm + col_offsets
77
- y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
78
- mask = col_offsets < n_cols
79
-
80
- x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
81
- mean = tl.sum(x, axis=0) / n_cols
82
- x_centered = x - mean
83
- var = tl.sum(x_centered * x_centered, axis=0) / n_cols
84
- inv_std = tl.rsqrt(var + eps)
85
- y = x_centered * inv_std
86
- tl.store(y_ptr, y.to(tl.float16), mask=mask)
87
-
88
-
89
- @triton.jit
90
- def matmul_kernel(
91
- A_ptr,
92
- B_ptr,
93
- C_ptr,
94
- M,
95
- N,
96
- K,
97
- stride_am,
98
- stride_ak,
99
- stride_bk,
100
- stride_bn,
101
- stride_cm,
102
- stride_cn,
103
- BLOCK_M: tl.constexpr,
104
- BLOCK_N: tl.constexpr,
105
- BLOCK_K: tl.constexpr,
106
- ):
107
- pid = tl.program_id(axis=0)
108
- num_pid_m = tl.cdiv(M, BLOCK_M)
109
- num_pid_n = tl.cdiv(N, BLOCK_N)
110
- pid_m = pid // num_pid_n
111
- pid_n = pid % num_pid_n
112
-
113
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
114
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
115
- offs_k = tl.arange(0, BLOCK_K)
116
-
117
- a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
118
- b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
119
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
120
- for _ in range(0, tl.cdiv(K, BLOCK_K)):
121
- a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
122
- b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
123
- acc += tl.dot(a, b)
124
- a_ptrs += BLOCK_K * stride_ak
125
- b_ptrs += BLOCK_K * stride_bk
126
- offs_k += BLOCK_K
127
-
128
- c = acc.to(tl.float16)
129
- c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
130
- c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
131
- tl.store(c_ptrs, c, mask=c_mask)
132
-
133
-
134
- def _task_id(family: str, m: int, n: int) -> str:
135
- if family == "grouped_gemm":
136
- return f"{family}_g{GROUPED_GEMM_GROUP_COUNT}_k{GROUPED_GEMM_K}_m{m}_n{n}"
137
- return f"{family}_m{m}_n{n}"
138
-
139
-
140
- def _valid_configs(family: str, n: int) -> List[Tuple[int, int, int]]:
141
- configs: List[Tuple[int, int, int]] = []
142
- if family == "grouped_gemm":
143
- candidate_blocks = tuple(block for block in BLOCK_SIZES if block <= 1024)
144
- else:
145
- candidate_blocks = BLOCK_SIZES
146
- for block_size in candidate_blocks:
147
- if family != "grouped_gemm" and block_size < n:
148
- continue
149
- for num_warps in NUM_WARPS:
150
- for num_stages in NUM_STAGES:
151
- configs.append((block_size, num_warps, num_stages))
152
- return configs
153
-
154
-
155
- def _effective_gbps(sample: torch.Tensor, median_ms: float) -> float:
156
- bytes_processed = float(sample.numel() * sample.element_size() * 2)
157
- if median_ms <= 0:
158
- return 0.0
159
- return bytes_processed / (median_ms / 1000.0) / 1e9
160
-
161
-
162
- def _benchmark_layernorm_config(
163
- sample: torch.Tensor,
164
- block_size: int,
165
- num_warps: int,
166
- num_stages: int,
167
- repeats: int,
168
- warmup: int,
169
- ) -> float:
170
- output = torch.empty_like(sample)
171
- grid = (sample.shape[0],)
172
-
173
- def launch() -> None:
174
- fused_rowwise_layernorm_kernel[grid](
175
- sample,
176
- output,
177
- sample.stride(0),
178
- sample.stride(1),
179
- output.stride(0),
180
- output.stride(1),
181
- sample.shape[1],
182
- EPS,
183
- BLOCK_SIZE=block_size,
184
- num_warps=num_warps,
185
- num_stages=num_stages,
186
- )
187
-
188
- for _ in range(max(1, warmup)):
189
- launch()
190
- torch.cuda.synchronize()
191
-
192
- if ttesting is not None:
193
- result = ttesting.do_bench(
194
- launch,
195
- warmup=0,
196
- rep=repeats,
197
- quantiles=[0.5],
198
- return_mode="median",
199
- )
200
- if isinstance(result, dict):
201
- return float(result.get("median", 0.0))
202
- return float(result)
203
-
204
- events: List[float] = []
205
- start = torch.cuda.Event(enable_timing=True)
206
- end = torch.cuda.Event(enable_timing=True)
207
- for _ in range(max(1, repeats)):
208
- torch.cuda.synchronize()
209
- start.record()
210
- launch()
211
- end.record()
212
- end.synchronize()
213
- events.append(start.elapsed_time(end))
214
- return float(np.median(np.asarray(events, dtype=np.float32)))
215
-
216
-
217
- def _validate_layernorm(sample: torch.Tensor, block_size: int, num_warps: int, num_stages: int) -> float:
218
- ref = torch.nn.functional.layer_norm(sample.float(), (sample.shape[1],), eps=EPS).to(sample.dtype)
219
- out = torch.empty_like(sample)
220
- fused_rowwise_layernorm_kernel[(sample.shape[0],)](
221
- sample,
222
- out,
223
- sample.stride(0),
224
- sample.stride(1),
225
- out.stride(0),
226
- out.stride(1),
227
- sample.shape[1],
228
- EPS,
229
- BLOCK_SIZE=block_size,
230
- num_warps=num_warps,
231
- num_stages=num_stages,
232
- )
233
- torch.cuda.synchronize()
234
- return float((out.to(torch.float32) - ref).abs().max().item())
235
-
236
-
237
- def benchmark_layernorm_config(
238
- n: int,
239
- block_size: int,
240
- num_warps: int,
241
- num_stages: int,
242
- m: int = 4096,
243
- repeats: int = 200,
244
- warmup: int = 25,
245
- seed: int = 0,
246
- ) -> MultiFamilyMeasurementRow:
247
- if block_size < n:
248
- raise ValueError(f"Invalid config: BLOCK_SIZE {block_size} < N {n}")
249
-
250
- torch.manual_seed(seed)
251
- sample = torch.randn((m, n), device="cuda", dtype=torch.float16)
252
- median_ms = _benchmark_layernorm_config(
253
- sample=sample,
254
- block_size=block_size,
255
- num_warps=num_warps,
256
- num_stages=num_stages,
257
- repeats=repeats,
258
- warmup=warmup,
259
- )
260
- val_err = _validate_layernorm(
261
- sample=sample,
262
- block_size=block_size,
263
- num_warps=num_warps,
264
- num_stages=num_stages,
265
- )
266
- gbps = _effective_gbps(sample, median_ms)
267
- return _make_row(
268
- family="layernorm",
269
- m=m,
270
- n=n,
271
- block_size=block_size,
272
- num_warps=num_warps,
273
- num_stages=num_stages,
274
- median_ms=median_ms,
275
- effective_gbps=gbps,
276
- validation_error=val_err,
277
- )
278
-
279
-
280
- def _grouped_gemm_shapes(m: int, n: int) -> Tuple[int, int, int]:
281
- group_m = max(64, m // GROUPED_GEMM_GROUP_COUNT)
282
- return GROUPED_GEMM_GROUP_COUNT, group_m, GROUPED_GEMM_K
283
-
284
-
285
- def _matmul_meta_from_block(block_size: int) -> Tuple[int, int, int]:
286
- block_m = max(32, min(block_size // 4, 256))
287
- block_n = max(32, min(block_size // 4, 256))
288
- block_k = 32
289
- return block_m, block_n, block_k
290
-
291
-
292
- def _benchmark_grouped_gemm_config(
293
- a_groups: Sequence[torch.Tensor],
294
- b_groups: Sequence[torch.Tensor],
295
- block_size: int,
296
- num_warps: int,
297
- num_stages: int,
298
- repeats: int,
299
- warmup: int,
300
- ) -> float:
301
- c_groups = [torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=a.dtype) for a, b in zip(a_groups, b_groups)]
302
- block_m, block_n, block_k = _matmul_meta_from_block(block_size)
303
-
304
- def launch() -> None:
305
- for a, b, c in zip(a_groups, b_groups, c_groups):
306
- grid = (triton.cdiv(a.shape[0], block_m) * triton.cdiv(b.shape[1], block_n),)
307
- matmul_kernel[grid](
308
- a,
309
- b,
310
- c,
311
- a.shape[0],
312
- b.shape[1],
313
- a.shape[1],
314
- a.stride(0),
315
- a.stride(1),
316
- b.stride(0),
317
- b.stride(1),
318
- c.stride(0),
319
- c.stride(1),
320
- BLOCK_M=block_m,
321
- BLOCK_N=block_n,
322
- BLOCK_K=block_k,
323
- num_warps=num_warps,
324
- num_stages=num_stages,
325
- )
326
-
327
- for _ in range(max(1, warmup)):
328
- launch()
329
- torch.cuda.synchronize()
330
-
331
- if ttesting is not None:
332
- result = ttesting.do_bench(
333
- launch,
334
- warmup=0,
335
- rep=repeats,
336
- quantiles=[0.5],
337
- return_mode="median",
338
- )
339
- if isinstance(result, dict):
340
- return float(result.get("median", 0.0))
341
- return float(result)
342
-
343
- durations_ms: List[float] = []
344
- start = torch.cuda.Event(enable_timing=True)
345
- end = torch.cuda.Event(enable_timing=True)
346
- for _ in range(max(1, repeats)):
347
- torch.cuda.synchronize()
348
- start.record()
349
- launch()
350
- end.record()
351
- end.synchronize()
352
- durations_ms.append(start.elapsed_time(end))
353
- return float(np.median(np.asarray(durations_ms, dtype=np.float32)))
354
-
355
-
356
- def _validate_grouped_gemm(
357
- a_groups: Sequence[torch.Tensor],
358
- b_groups: Sequence[torch.Tensor],
359
- block_size: int,
360
- num_warps: int,
361
- num_stages: int,
362
- ) -> float:
363
- c_groups = [torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=a.dtype) for a, b in zip(a_groups, b_groups)]
364
- block_m, block_n, block_k = _matmul_meta_from_block(block_size)
365
- for a, b, c in zip(a_groups, b_groups, c_groups):
366
- grid = (triton.cdiv(a.shape[0], block_m) * triton.cdiv(b.shape[1], block_n),)
367
- matmul_kernel[grid](
368
- a,
369
- b,
370
- c,
371
- a.shape[0],
372
- b.shape[1],
373
- a.shape[1],
374
- a.stride(0),
375
- a.stride(1),
376
- b.stride(0),
377
- b.stride(1),
378
- c.stride(0),
379
- c.stride(1),
380
- BLOCK_M=block_m,
381
- BLOCK_N=block_n,
382
- BLOCK_K=block_k,
383
- num_warps=num_warps,
384
- num_stages=num_stages,
385
- )
386
- torch.cuda.synchronize()
387
-
388
- max_err = 0.0
389
- for a, b, c in zip(a_groups, b_groups, c_groups):
390
- ref = torch.matmul(a.float(), b.float()).to(torch.float16)
391
- max_err = max(max_err, float((c.float() - ref.float()).abs().max().item()))
392
- return max_err
393
-
394
-
395
- def benchmark_grouped_gemm_config(
396
- n: int,
397
- block_size: int,
398
- num_warps: int,
399
- num_stages: int,
400
- m: int = 4096,
401
- repeats: int = 200,
402
- warmup: int = 25,
403
- seed: int = 0,
404
- ) -> MultiFamilyMeasurementRow:
405
- group_count, group_m, k_dim = _grouped_gemm_shapes(m, n)
406
- torch.manual_seed(seed)
407
- a_groups = [torch.randn((group_m, k_dim), device="cuda", dtype=torch.float16) for _ in range(group_count)]
408
- b_groups = [torch.randn((k_dim, n), device="cuda", dtype=torch.float16) for _ in range(group_count)]
409
-
410
- median_ms = _benchmark_grouped_gemm_config(
411
- a_groups=a_groups,
412
- b_groups=b_groups,
413
- block_size=block_size,
414
- num_warps=num_warps,
415
- num_stages=num_stages,
416
- repeats=repeats,
417
- warmup=warmup,
418
- )
419
- val_err = _validate_grouped_gemm(
420
- a_groups=a_groups,
421
- b_groups=b_groups,
422
- block_size=block_size,
423
- num_warps=num_warps,
424
- num_stages=num_stages,
425
- )
426
- bytes_processed = group_count * (
427
- a_groups[0].numel() * a_groups[0].element_size()
428
- + b_groups[0].numel() * b_groups[0].element_size()
429
- + group_m * n * a_groups[0].element_size()
430
- )
431
- effective_gbps = bytes_processed / (median_ms / 1000.0) / 1e9 if median_ms > 0 else 0.0
432
- return _make_row(
433
- family="grouped_gemm",
434
- m=m,
435
- n=n,
436
- block_size=block_size,
437
- num_warps=num_warps,
438
- num_stages=num_stages,
439
- median_ms=median_ms,
440
- effective_gbps=effective_gbps,
441
- validation_error=val_err,
442
- )
443
-
444
-
445
- def _make_row(
446
- family: str,
447
- m: int,
448
- n: int,
449
- block_size: int,
450
- num_warps: int,
451
- num_stages: int,
452
- median_ms: float,
453
- effective_gbps: float,
454
- validation_error: float,
455
- ) -> MultiFamilyMeasurementRow:
456
- configs = _valid_configs(family, n)
457
- config_id = configs.index((block_size, num_warps, num_stages))
458
- return MultiFamilyMeasurementRow(
459
- family_group="A" if family in {"softmax", "layernorm"} else "B",
460
- family=family,
461
- task_id=_task_id(family, m, n),
462
- m=m,
463
- n=n,
464
- config_id=config_id,
465
- block_size=block_size,
466
- num_warps=num_warps,
467
- num_stages=num_stages,
468
- shape_json=json.dumps(
469
- {
470
- "family": family,
471
- "m": m,
472
- "n": n,
473
- "group_count": GROUPED_GEMM_GROUP_COUNT if family == "grouped_gemm" else None,
474
- "k": GROUPED_GEMM_K if family == "grouped_gemm" else None,
475
- },
476
- sort_keys=True,
477
- ),
478
- config_json=json.dumps(
479
- {
480
- "block_size": block_size,
481
- "num_warps": num_warps,
482
- "num_stages": num_stages,
483
- },
484
- sort_keys=True,
485
- ),
486
- median_ms=float(median_ms),
487
- effective_gbps=float(effective_gbps),
488
- score=float(-math.log(max(median_ms, np.finfo(float).tiny))),
489
- validation_error=float(validation_error),
490
- )
491
-
492
-
493
- def _softmax_row_to_multi(row: object) -> MultiFamilyMeasurementRow:
494
- return _make_row(
495
- family="softmax",
496
- m=int(row.m),
497
- n=int(row.n),
498
- block_size=int(row.block_size),
499
- num_warps=int(row.num_warps),
500
- num_stages=int(row.num_stages),
501
- median_ms=float(row.median_ms),
502
- effective_gbps=float(row.effective_gbps),
503
- validation_error=float(row.validation_error),
504
- )
505
-
506
-
507
- def collect_multifamily_measurements(
508
- output_path: Path,
509
- families: Sequence[str],
510
- n_values: Iterable[int],
511
- repeats: int,
512
- warmup: int,
513
- seed: int,
514
- m: int = 4096,
515
- append: bool = True,
516
- ) -> List[MultiFamilyMeasurementRow]:
517
- if not torch.cuda.is_available():
518
- raise RuntimeError("CUDA is required to run measurements on GPU.")
519
-
520
- output_path.parent.mkdir(parents=True, exist_ok=True)
521
- existing = set()
522
- if output_path.exists():
523
- with output_path.open("r", newline="", encoding="utf-8") as handle:
524
- for row in csv.DictReader(handle):
525
- existing.add(_row_key(row["family"], int(row["m"]), int(row["n"]), int(row["config_id"])))
526
-
527
- results: List[MultiFamilyMeasurementRow] = []
528
- for family in families:
529
- if family not in {"softmax", "layernorm", "grouped_gemm"}:
530
- raise ValueError(f"Unsupported family: {family}")
531
- for n in n_values:
532
- for config_id, (block_size, num_warps, num_stages) in enumerate(_valid_configs(family, n)):
533
- key = _row_key(family, m, n, config_id)
534
- if append and key in existing:
535
- continue
536
- if family == "softmax":
537
- row = _softmax_row_to_multi(
538
- benchmark_softmax_config(
539
- n=n,
540
- block_size=block_size,
541
- num_warps=num_warps,
542
- num_stages=num_stages,
543
- m=m,
544
- repeats=repeats,
545
- warmup=warmup,
546
- seed=seed,
547
- )
548
- )
549
- elif family == "layernorm":
550
- row = benchmark_layernorm_config(
551
- n=n,
552
- block_size=block_size,
553
- num_warps=num_warps,
554
- num_stages=num_stages,
555
- m=m,
556
- repeats=repeats,
557
- warmup=warmup,
558
- seed=seed,
559
- )
560
- else:
561
- row = benchmark_grouped_gemm_config(
562
- n=n,
563
- block_size=block_size,
564
- num_warps=num_warps,
565
- num_stages=num_stages,
566
- m=m,
567
- repeats=repeats,
568
- warmup=warmup,
569
- seed=seed,
570
- )
571
- results.append(row)
572
-
573
- if results:
574
- write_mode = "a" if output_path.exists() and append else "w"
575
- with output_path.open(write_mode, newline="", encoding="utf-8") as handle:
576
- writer = csv.writer(handle)
577
- if write_mode == "w" or output_path.stat().st_size == 0:
578
- writer.writerow(
579
- [
580
- "family_group",
581
- "family",
582
- "task_id",
583
- "m",
584
- "n",
585
- "config_id",
586
- "block_size",
587
- "num_warps",
588
- "num_stages",
589
- "shape_json",
590
- "config_json",
591
- "median_ms",
592
- "effective_gbps",
593
- "score",
594
- "validation_error",
595
- ]
596
- )
597
- for row in results:
598
- writer.writerow(
599
- [
600
- row.family_group,
601
- row.family,
602
- row.task_id,
603
- row.m,
604
- row.n,
605
- row.config_id,
606
- row.block_size,
607
- row.num_warps,
608
- row.num_stages,
609
- row.shape_json,
610
- row.config_json,
611
- row.median_ms,
612
- row.effective_gbps,
613
- row.score,
614
- row.validation_error,
615
- ]
616
- )
617
- return results
618
-
619
-
620
- def _row_key(family: str, m: int, n: int, config_id: int) -> str:
621
- return f"{family}|{m}|{n}|{config_id}"
622
-
623
-
624
- def parse_args() -> argparse.Namespace:
625
- parser = argparse.ArgumentParser(description="Collect multi-family GPU measurements.")
626
- parser.add_argument("--output", type=Path, default=Path("data/autotune_measurements.csv"))
627
- parser.add_argument(
628
- "--families",
629
- nargs="+",
630
- default=("softmax", "layernorm", "grouped_gemm"),
631
- choices=("softmax", "layernorm", "grouped_gemm"),
632
- )
633
- parser.add_argument("--n-cols", type=int, nargs="+", default=N_VALUES)
634
- parser.add_argument("--m", type=int, default=4096)
635
- parser.add_argument("--repeats", type=int, default=200)
636
- parser.add_argument("--warmup", type=int, default=25)
637
- parser.add_argument("--seed", type=int, default=0)
638
- parser.add_argument("--append", action="store_true")
639
- parser.add_argument("--overwrite", action="store_true")
640
- return parser.parse_args()
641
-
642
-
643
- def main() -> None:
644
- args = parse_args()
645
- append = args.append and not args.overwrite
646
- collect_multifamily_measurements(
647
- output_path=args.output,
648
- families=args.families,
649
- n_values=args.n_cols,
650
- repeats=args.repeats,
651
- warmup=args.warmup,
652
- seed=args.seed,
653
- m=args.m,
654
- append=append,
655
- )
656
-
657
-
658
- if __name__ == "__main__":
659
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/collect_qwen_05b_measurements.py DELETED
@@ -1,449 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import csv
6
- import json
7
- import math
8
- import sys
9
- from dataclasses import dataclass
10
- from pathlib import Path
11
- from typing import Dict, List, Sequence
12
-
13
- import numpy as np
14
- import torch
15
-
16
- ROOT = Path(__file__).resolve().parents[1]
17
- if str(ROOT) not in sys.path:
18
- sys.path.append(str(ROOT))
19
-
20
- try:
21
- import triton
22
- import triton.language as tl
23
- import triton.testing as ttesting
24
- except Exception as err: # pragma: no cover
25
- raise RuntimeError("Triton is required for Qwen kernel measurement.") from err
26
-
27
- from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS, benchmark_single_config
28
- from scripts.qwen_05b_spec import QwenKernelTask, qwen_05b_tasks
29
-
30
-
31
- EPS = 1e-5
32
-
33
-
34
- @dataclass(frozen=True)
35
- class QwenMeasurementRow:
36
- family_group: str
37
- family: str
38
- task_id: str
39
- m: int
40
- n: int
41
- k: int
42
- config_id: int
43
- block_size: int
44
- num_warps: int
45
- num_stages: int
46
- shape_json: str
47
- config_json: str
48
- median_ms: float
49
- effective_gbps: float
50
- score: float
51
- validation_error: float
52
-
53
-
54
- @triton.jit
55
- def fused_rowwise_rmsnorm_kernel(
56
- X_ptr,
57
- Y_ptr,
58
- stride_xm,
59
- stride_xn,
60
- stride_ym,
61
- stride_yn,
62
- n_cols,
63
- eps,
64
- BLOCK_SIZE: tl.constexpr,
65
- ):
66
- row_idx = tl.program_id(0)
67
- col_offsets = tl.arange(0, BLOCK_SIZE)
68
- mask = col_offsets < n_cols
69
- x_ptr = X_ptr + row_idx * stride_xm + col_offsets
70
- y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
71
-
72
- x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
73
- mean_sq = tl.sum(x * x, axis=0) / n_cols
74
- inv_rms = tl.rsqrt(mean_sq + eps)
75
- y = x * inv_rms
76
- tl.store(y_ptr, y.to(tl.float16), mask=mask)
77
-
78
-
79
- @triton.jit
80
- def matmul_kernel(
81
- A_ptr,
82
- B_ptr,
83
- C_ptr,
84
- M,
85
- N,
86
- K,
87
- stride_am,
88
- stride_ak,
89
- stride_bk,
90
- stride_bn,
91
- stride_cm,
92
- stride_cn,
93
- BLOCK_M: tl.constexpr,
94
- BLOCK_N: tl.constexpr,
95
- BLOCK_K: tl.constexpr,
96
- ):
97
- pid = tl.program_id(axis=0)
98
- num_pid_m = tl.cdiv(M, BLOCK_M)
99
- num_pid_n = tl.cdiv(N, BLOCK_N)
100
- pid_m = pid // num_pid_n
101
- pid_n = pid % num_pid_n
102
-
103
- offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
104
- offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
105
- offs_k = tl.arange(0, BLOCK_K)
106
-
107
- a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
108
- b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
109
- acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
110
- for _ in range(0, tl.cdiv(K, BLOCK_K)):
111
- a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
112
- b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
113
- acc += tl.dot(a, b)
114
- a_ptrs += BLOCK_K * stride_ak
115
- b_ptrs += BLOCK_K * stride_bk
116
- offs_k += BLOCK_K
117
-
118
- c = acc.to(tl.float16)
119
- c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
120
- c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
121
- tl.store(c_ptrs, c, mask=c_mask)
122
-
123
-
124
- def _gemm_blocks(block_size: int) -> tuple[int, int, int]:
125
- block_m = max(16, min(block_size // 4, 128))
126
- block_n = max(16, min(block_size // 4, 128))
127
- block_k = 32
128
- return block_m, block_n, block_k
129
-
130
-
131
- def _effective_gbps(bytes_processed: int, median_ms: float) -> float:
132
- if median_ms <= 0:
133
- return 0.0
134
- return float(bytes_processed) / (median_ms / 1000.0) / 1e9
135
-
136
-
137
- def _score(ms: float) -> float:
138
- return float(-math.log(max(ms, np.finfo(float).tiny)))
139
-
140
-
141
- def _config_json(block_size: int, num_warps: int, num_stages: int) -> str:
142
- return json.dumps(
143
- {"block_size": block_size, "num_warps": num_warps, "num_stages": num_stages},
144
- sort_keys=True,
145
- )
146
-
147
-
148
- def _valid_configs(task: QwenKernelTask) -> List[tuple[int, int, int]]:
149
- configs: List[tuple[int, int, int]] = []
150
- for block_size in BLOCK_SIZES:
151
- if task.family in {"softmax", "rmsnorm"} and block_size < task.n:
152
- continue
153
- if task.family == "gemm" and block_size > 1024:
154
- continue
155
- for num_warps in NUM_WARPS:
156
- for num_stages in NUM_STAGES:
157
- configs.append((block_size, num_warps, num_stages))
158
- return configs
159
-
160
-
161
- def _benchmark_rmsnorm(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
162
- torch.manual_seed(seed)
163
- sample = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
164
- output = torch.empty_like(sample)
165
- grid = (sample.shape[0],)
166
-
167
- def launch() -> None:
168
- fused_rowwise_rmsnorm_kernel[grid](
169
- sample,
170
- output,
171
- sample.stride(0),
172
- sample.stride(1),
173
- output.stride(0),
174
- output.stride(1),
175
- sample.shape[1],
176
- EPS,
177
- BLOCK_SIZE=block_size,
178
- num_warps=num_warps,
179
- num_stages=num_stages,
180
- )
181
-
182
- for _ in range(max(1, warmup)):
183
- launch()
184
- torch.cuda.synchronize()
185
- if ttesting is not None:
186
- result = ttesting.do_bench(launch, warmup=0, rep=repeats, quantiles=[0.5], return_mode="median")
187
- median_ms = float(result.get("median", 0.0) if isinstance(result, dict) else result)
188
- else:
189
- start = torch.cuda.Event(enable_timing=True)
190
- end = torch.cuda.Event(enable_timing=True)
191
- durations: List[float] = []
192
- for _ in range(max(1, repeats)):
193
- torch.cuda.synchronize()
194
- start.record()
195
- launch()
196
- end.record()
197
- end.synchronize()
198
- durations.append(start.elapsed_time(end))
199
- median_ms = float(np.median(np.asarray(durations, dtype=np.float32)))
200
-
201
- ref = sample.float() * torch.rsqrt(sample.float().pow(2).mean(dim=-1, keepdim=True) + EPS)
202
- fused_rowwise_rmsnorm_kernel[grid](
203
- sample,
204
- output,
205
- sample.stride(0),
206
- sample.stride(1),
207
- output.stride(0),
208
- output.stride(1),
209
- sample.shape[1],
210
- EPS,
211
- BLOCK_SIZE=block_size,
212
- num_warps=num_warps,
213
- num_stages=num_stages,
214
- )
215
- torch.cuda.synchronize()
216
- max_err = float((output.float() - ref.float()).abs().max().item())
217
- config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
218
- return QwenMeasurementRow(
219
- family_group=task.family_group,
220
- family=task.family,
221
- task_id=task.task_id,
222
- m=task.m,
223
- n=task.n,
224
- k=task.k,
225
- config_id=config_id,
226
- block_size=block_size,
227
- num_warps=num_warps,
228
- num_stages=num_stages,
229
- shape_json=json.dumps(task.shape_fields(), sort_keys=True),
230
- config_json=_config_json(block_size, num_warps, num_stages),
231
- median_ms=median_ms,
232
- effective_gbps=_effective_gbps(sample.numel() * sample.element_size() * 2, median_ms),
233
- score=_score(median_ms),
234
- validation_error=max_err,
235
- )
236
-
237
-
238
- def _benchmark_gemm(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
239
- torch.manual_seed(seed)
240
- a = torch.randn((task.m, task.k), device="cuda", dtype=torch.float16)
241
- b = torch.randn((task.k, task.n), device="cuda", dtype=torch.float16)
242
- c = torch.empty((task.m, task.n), device="cuda", dtype=torch.float16)
243
- block_m, block_n, block_k = _gemm_blocks(block_size)
244
-
245
- def launch() -> None:
246
- grid = (triton.cdiv(task.m, block_m) * triton.cdiv(task.n, block_n),)
247
- matmul_kernel[grid](
248
- a,
249
- b,
250
- c,
251
- task.m,
252
- task.n,
253
- task.k,
254
- a.stride(0),
255
- a.stride(1),
256
- b.stride(0),
257
- b.stride(1),
258
- c.stride(0),
259
- c.stride(1),
260
- BLOCK_M=block_m,
261
- BLOCK_N=block_n,
262
- BLOCK_K=block_k,
263
- num_warps=num_warps,
264
- num_stages=num_stages,
265
- )
266
-
267
- for _ in range(max(1, warmup)):
268
- launch()
269
- torch.cuda.synchronize()
270
- if ttesting is not None:
271
- result = ttesting.do_bench(launch, warmup=0, rep=repeats, quantiles=[0.5], return_mode="median")
272
- median_ms = float(result.get("median", 0.0) if isinstance(result, dict) else result)
273
- else:
274
- start = torch.cuda.Event(enable_timing=True)
275
- end = torch.cuda.Event(enable_timing=True)
276
- durations: List[float] = []
277
- for _ in range(max(1, repeats)):
278
- torch.cuda.synchronize()
279
- start.record()
280
- launch()
281
- end.record()
282
- end.synchronize()
283
- durations.append(start.elapsed_time(end))
284
- median_ms = float(np.median(np.asarray(durations, dtype=np.float32)))
285
-
286
- ref = torch.matmul(a.float(), b.float()).to(torch.float16)
287
- launch()
288
- torch.cuda.synchronize()
289
- max_err = float((c.float() - ref.float()).abs().max().item())
290
- config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
291
- bytes_processed = a.numel() * a.element_size() + b.numel() * b.element_size() + c.numel() * c.element_size()
292
- return QwenMeasurementRow(
293
- family_group=task.family_group,
294
- family=task.family,
295
- task_id=task.task_id,
296
- m=task.m,
297
- n=task.n,
298
- k=task.k,
299
- config_id=config_id,
300
- block_size=block_size,
301
- num_warps=num_warps,
302
- num_stages=num_stages,
303
- shape_json=json.dumps(task.shape_fields(), sort_keys=True),
304
- config_json=_config_json(block_size, num_warps, num_stages),
305
- median_ms=median_ms,
306
- effective_gbps=_effective_gbps(bytes_processed, median_ms),
307
- score=_score(median_ms),
308
- validation_error=max_err,
309
- )
310
-
311
-
312
- def benchmark_qwen_task(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
313
- if task.family == "softmax":
314
- row = benchmark_single_config(
315
- n=task.n,
316
- block_size=block_size,
317
- num_warps=num_warps,
318
- num_stages=num_stages,
319
- m=task.m,
320
- repeats=repeats,
321
- warmup=warmup,
322
- seed=seed,
323
- )
324
- config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
325
- return QwenMeasurementRow(
326
- family_group=task.family_group,
327
- family=task.family,
328
- task_id=task.task_id,
329
- m=task.m,
330
- n=task.n,
331
- k=0,
332
- config_id=config_id,
333
- block_size=block_size,
334
- num_warps=num_warps,
335
- num_stages=num_stages,
336
- shape_json=json.dumps(task.shape_fields(), sort_keys=True),
337
- config_json=_config_json(block_size, num_warps, num_stages),
338
- median_ms=float(row.median_ms),
339
- effective_gbps=float(row.effective_gbps),
340
- score=float(row.score),
341
- validation_error=float(row.validation_error),
342
- )
343
- if task.family == "rmsnorm":
344
- return _benchmark_rmsnorm(task, block_size, num_warps, num_stages, repeats, warmup, seed)
345
- if task.family == "gemm":
346
- return _benchmark_gemm(task, block_size, num_warps, num_stages, repeats, warmup, seed)
347
- raise ValueError(f"Unsupported family: {task.family}")
348
-
349
-
350
- def collect_qwen_measurements(output_path: Path, repeats: int, warmup: int, seed: int, append: bool = True) -> List[QwenMeasurementRow]:
351
- output_path.parent.mkdir(parents=True, exist_ok=True)
352
- existing = set()
353
- if output_path.exists():
354
- with output_path.open("r", newline="", encoding="utf-8") as handle:
355
- for row in csv.DictReader(handle):
356
- existing.add((row["task_id"], int(row["config_id"])))
357
-
358
- results: List[QwenMeasurementRow] = []
359
- for task in qwen_05b_tasks():
360
- for config_id, (block_size, num_warps, num_stages) in enumerate(_valid_configs(task)):
361
- key = (task.task_id, config_id)
362
- if append and key in existing:
363
- continue
364
- results.append(
365
- benchmark_qwen_task(
366
- task=task,
367
- block_size=block_size,
368
- num_warps=num_warps,
369
- num_stages=num_stages,
370
- repeats=repeats,
371
- warmup=warmup,
372
- seed=seed,
373
- )
374
- )
375
-
376
- if results:
377
- write_mode = "a" if output_path.exists() and append else "w"
378
- with output_path.open(write_mode, newline="", encoding="utf-8") as handle:
379
- writer = csv.writer(handle)
380
- if write_mode == "w" or output_path.stat().st_size == 0:
381
- writer.writerow(
382
- [
383
- "family_group",
384
- "family",
385
- "task_id",
386
- "m",
387
- "n",
388
- "k",
389
- "config_id",
390
- "block_size",
391
- "num_warps",
392
- "num_stages",
393
- "shape_json",
394
- "config_json",
395
- "median_ms",
396
- "effective_gbps",
397
- "score",
398
- "validation_error",
399
- ]
400
- )
401
- for row in results:
402
- writer.writerow(
403
- [
404
- row.family_group,
405
- row.family,
406
- row.task_id,
407
- row.m,
408
- row.n,
409
- row.k,
410
- row.config_id,
411
- row.block_size,
412
- row.num_warps,
413
- row.num_stages,
414
- row.shape_json,
415
- row.config_json,
416
- row.median_ms,
417
- row.effective_gbps,
418
- row.score,
419
- row.validation_error,
420
- ]
421
- )
422
- return results
423
-
424
-
425
- def parse_args() -> argparse.Namespace:
426
- parser = argparse.ArgumentParser(description="Collect exact-kernel measurements for Qwen2.5-0.5B.")
427
- parser.add_argument("--output", type=Path, default=Path("data/qwen_05b_measurements.csv"))
428
- parser.add_argument("--repeats", type=int, default=200)
429
- parser.add_argument("--warmup", type=int, default=25)
430
- parser.add_argument("--seed", type=int, default=0)
431
- parser.add_argument("--append", action="store_true")
432
- parser.add_argument("--overwrite", action="store_true")
433
- return parser.parse_args()
434
-
435
-
436
- def main() -> None:
437
- args = parse_args()
438
- append = args.append and not args.overwrite
439
- collect_qwen_measurements(
440
- output_path=args.output,
441
- repeats=args.repeats,
442
- warmup=args.warmup,
443
- seed=args.seed,
444
- append=append,
445
- )
446
-
447
-
448
- if __name__ == "__main__":
449
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/demo_compare.py DELETED
@@ -1,200 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- from dataclasses import dataclass
7
- import sys
8
- from pathlib import Path
9
- from typing import Any, Dict, List
10
-
11
- ROOT = Path(__file__).resolve().parents[1]
12
- if str(ROOT) not in sys.path:
13
- sys.path.append(str(ROOT))
14
-
15
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
16
- from scripts.collect_measurements import benchmark_single_config
17
- from scripts.run_random_baseline import run_random_baseline
18
- from scripts.run_surrogate_baseline import run_surrogate_baseline
19
-
20
-
21
- @dataclass
22
- class BaselineResult:
23
- method: str
24
- latency_ms: float
25
- config: Dict[str, int]
26
- regret: float
27
-
28
-
29
- def _search_metric_key(summary: Dict[str, Any], oracle_best_ms: float) -> tuple[float, float, float]:
30
- metrics = summary.get("aggregate_metrics", {})
31
- mean_auc_regret = float(metrics.get("mean_auc_regret", float("inf")))
32
- oracle_hit_rate_final = float(metrics.get("oracle_hit_rate_final", 0.0))
33
- best_latency_ms = float(summary["best_overall"]["latency_ms"])
34
- latency_regret = best_latency_ms / oracle_best_ms - 1.0
35
- return (mean_auc_regret, -oracle_hit_rate_final, latency_regret)
36
-
37
-
38
- def _heuristic_for_task(task_id: str, task_rows: List[Dict[str, Any]], env: SoftmaxSurrogateEnvironment) -> BaselineResult:
39
- n = int(task_id.split("_n")[-1])
40
- block = min(row["block_size"] for row in task_rows if row["block_size"] >= n)
41
- warp = 4 if 4 in {row["num_warps"] for row in task_rows} else 2
42
- stage = 2 if 2 in {row["num_stages"] for row in task_rows} else 1
43
-
44
- candidate = None
45
- for row in task_rows:
46
- if row["block_size"] == block and row["num_warps"] == warp and row["num_stages"] == stage:
47
- candidate = row
48
- break
49
- if candidate is None:
50
- candidate = min(
51
- task_rows,
52
- key=lambda row: abs(row["block_size"] - block) + 10 * abs(row["num_warps"] - warp),
53
- )
54
-
55
- latency_ms = env.measured_latency_ms(candidate["config_id"])
56
- oracle_best_ms = env.oracle_best()["median_ms"]
57
- return BaselineResult(
58
- method="heuristic",
59
- latency_ms=float(latency_ms),
60
- config=candidate,
61
- regret=float(latency_ms / oracle_best_ms - 1.0),
62
- )
63
-
64
-
65
- def _pick_task(task_arg: str | None, measurement_path: str, budget: int) -> str:
66
- env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=0)
67
- if task_arg:
68
- env.reset(task=task_arg)
69
- else:
70
- env.reset()
71
- return env.state()["task_id"]
72
-
73
-
74
- def _run_all(
75
- task: str,
76
- budget: int,
77
- episodes: int,
78
- seed: int,
79
- measurement_path: str,
80
- acquisition: str,
81
- beta: float,
82
- xi: float,
83
- ) -> Dict[str, Any]:
84
- env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=seed)
85
- env.reset(task=task)
86
- task_rows = env.available_configs()
87
- oracle_best = env.oracle_best()
88
-
89
- heuristic = _heuristic_for_task(task, task_rows, env)
90
- random_summary = run_random_baseline(task, episodes=episodes, budget=budget, seed=seed, measurement_path=measurement_path)
91
- surrogate_summary = run_surrogate_baseline(
92
- task,
93
- episodes=episodes,
94
- budget=budget,
95
- seed=seed,
96
- measurement_path=measurement_path,
97
- acquisition=acquisition,
98
- beta=beta,
99
- xi=xi,
100
- )
101
-
102
- search_summaries = {
103
- "random": random_summary,
104
- "surrogate": surrogate_summary,
105
- }
106
- winner_method, winner_summary = min(
107
- search_summaries.items(),
108
- key=lambda item: _search_metric_key(item[1], oracle_best["median_ms"]),
109
- )
110
- winner_cfg = winner_summary["best_overall"]["config"]
111
- winner_regret = float(winner_summary["best_overall"]["latency_ms"] / oracle_best["median_ms"] - 1.0)
112
- n = int(task.split("_n")[-1])
113
-
114
- live = benchmark_single_config(
115
- n=n,
116
- block_size=winner_cfg["block_size"],
117
- num_warps=winner_cfg["num_warps"],
118
- num_stages=winner_cfg["num_stages"],
119
- repeats=max(200, budget * 20),
120
- warmup=25,
121
- seed=seed + 999,
122
- )
123
-
124
- return {
125
- "task": task,
126
- "seed": seed,
127
- "budget": budget,
128
- "episodes": episodes,
129
- "acquisition": acquisition,
130
- "beta": beta,
131
- "xi": xi,
132
- "oracle_best": oracle_best,
133
- "heuristic": heuristic.__dict__,
134
- "random": random_summary["best_overall"],
135
- "random_aggregate_metrics": random_summary.get("aggregate_metrics", {}),
136
- "surrogate": surrogate_summary["best_overall"],
137
- "surrogate_aggregate_metrics": surrogate_summary.get("aggregate_metrics", {}),
138
- "winner": {
139
- "method": winner_method,
140
- "selection_metric": "min(mean_auc_regret), tie-break max(oracle_hit_rate_final), then best latency",
141
- "latency_ms": winner_summary["best_overall"]["latency_ms"],
142
- "config": winner_cfg,
143
- "regret": winner_regret,
144
- "live_rerun": live.__dict__,
145
- },
146
- }
147
-
148
-
149
- def parse_args() -> argparse.Namespace:
150
- parser = argparse.ArgumentParser(description="Compare heuristic/random/surrogate baselines.")
151
- parser.add_argument(
152
- "--task",
153
- default="softmax_m4096_n2048",
154
- help="Task ID (e.g., softmax_m4096_n2048)",
155
- )
156
- parser.add_argument("--budget", type=int, default=6)
157
- parser.add_argument("--episodes", type=int, default=20)
158
- parser.add_argument("--seed", type=int, default=2)
159
- parser.add_argument(
160
- "--acquisition",
161
- type=str,
162
- choices=("mean", "ucb", "ei"),
163
- default="ucb",
164
- )
165
- parser.add_argument("--beta", type=float, default=2.0)
166
- parser.add_argument("--xi", type=float, default=0.0)
167
- parser.add_argument(
168
- "--measurement-path",
169
- type=str,
170
- default="data/autotune_measurements.csv",
171
- )
172
- parser.add_argument(
173
- "--output",
174
- type=Path,
175
- default=Path("outputs/demo_compare.json"),
176
- )
177
- return parser.parse_args()
178
-
179
-
180
- def main() -> None:
181
- args = parse_args()
182
- task = _pick_task(args.task, args.measurement_path, args.budget)
183
- summary = _run_all(
184
- task=task,
185
- budget=args.budget,
186
- episodes=args.episodes,
187
- seed=args.seed,
188
- measurement_path=args.measurement_path,
189
- acquisition=args.acquisition,
190
- beta=args.beta,
191
- xi=args.xi,
192
- )
193
- args.output.parent.mkdir(parents=True, exist_ok=True)
194
- with args.output.open("w", encoding="utf-8") as f:
195
- json.dump(summary, f, indent=2)
196
- print(json.dumps(summary, indent=2))
197
-
198
-
199
- if __name__ == "__main__":
200
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/eval_generalization.py DELETED
@@ -1,147 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- import sys
7
- from pathlib import Path
8
- from typing import Any, Dict, List
9
-
10
- import numpy as np
11
-
12
- ROOT = Path(__file__).resolve().parents[1]
13
- if str(ROOT) not in sys.path:
14
- sys.path.append(str(ROOT))
15
-
16
- from scripts.run_random_baseline import run_random_baseline
17
- from scripts.run_surrogate_baseline import run_surrogate_baseline
18
-
19
-
20
- def _average_metric_dict(records: List[Dict[str, float]]) -> Dict[str, float]:
21
- if not records:
22
- return {}
23
- keys = sorted({key for record in records for key in record.keys()}, key=lambda value: int(value))
24
- return {
25
- key: float(np.mean(np.asarray([record[key] for record in records if key in record], dtype=np.float32)))
26
- for key in keys
27
- }
28
-
29
-
30
- def _summarize_runs(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
31
- mean_regret_records = [run["aggregate_metrics"].get("mean_regret_at", {}) for run in runs]
32
- median_regret_records = [run["aggregate_metrics"].get("median_regret_at", {}) for run in runs]
33
- auc_values = [run["aggregate_metrics"].get("mean_auc_regret") for run in runs]
34
- oracle_hit_values = [run["aggregate_metrics"].get("oracle_hit_rate_final") for run in runs]
35
- return {
36
- "mean_regret_at": _average_metric_dict(mean_regret_records),
37
- "median_regret_at": _average_metric_dict(median_regret_records),
38
- "mean_best_so_far_auc": float(np.mean(np.asarray(auc_values, dtype=np.float32))) if auc_values else None,
39
- "mean_oracle_hit_rate_final": float(np.mean(np.asarray(oracle_hit_values, dtype=np.float32))) if oracle_hit_values else None,
40
- }
41
-
42
-
43
- def _evaluate_section(
44
- section_name: str,
45
- split: Dict[str, Any],
46
- measurement_path: str,
47
- episodes: int,
48
- budget: int,
49
- seed: int,
50
- acquisition: str,
51
- beta: float,
52
- xi: float,
53
- ) -> Dict[str, Any]:
54
- train_tasks = split["train_tasks"]
55
- test_tasks = split["test_tasks"]
56
- random_runs: List[Dict[str, Any]] = []
57
- surrogate_runs: List[Dict[str, Any]] = []
58
-
59
- for idx, task in enumerate(test_tasks):
60
- task_seed = seed + idx * 1000
61
- random_runs.append(
62
- run_random_baseline(
63
- task=task,
64
- episodes=episodes,
65
- budget=budget,
66
- seed=task_seed,
67
- measurement_path=measurement_path,
68
- )
69
- )
70
- surrogate_runs.append(
71
- run_surrogate_baseline(
72
- task=task,
73
- episodes=episodes,
74
- budget=budget,
75
- seed=task_seed,
76
- measurement_path=measurement_path,
77
- train_task_ids=train_tasks,
78
- acquisition=acquisition,
79
- beta=beta,
80
- xi=xi,
81
- )
82
- )
83
-
84
- return {
85
- "section": section_name,
86
- "train_tasks": train_tasks,
87
- "test_tasks": test_tasks,
88
- "random_summary": _summarize_runs(random_runs),
89
- "surrogate_summary": _summarize_runs(surrogate_runs),
90
- "task_runs": {
91
- "random": random_runs,
92
- "surrogate": surrogate_runs,
93
- },
94
- }
95
-
96
-
97
- def main() -> None:
98
- parser = argparse.ArgumentParser(description="Evaluate random vs surrogate on shape and family holdout splits.")
99
- parser.add_argument("--measurement-path", type=str, default="data/autotune_measurements.csv")
100
- parser.add_argument("--splits", type=Path, default=Path("data/benchmark_splits.json"))
101
- parser.add_argument("--episodes", type=int, default=20)
102
- parser.add_argument("--budget", type=int, default=6)
103
- parser.add_argument("--seed", type=int, default=2)
104
- parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
105
- parser.add_argument("--beta", type=float, default=2.0)
106
- parser.add_argument("--xi", type=float, default=0.0)
107
- parser.add_argument("--output", type=Path, default=Path("outputs/generalization_eval.json"))
108
- args = parser.parse_args()
109
-
110
- splits = json.loads(args.splits.read_text(encoding="utf-8"))
111
- sections = {
112
- "shape_generalization": splits["shape_generalization"],
113
- "family_holdout": splits["family_holdout"],
114
- }
115
- results = {
116
- name: _evaluate_section(
117
- section_name=name,
118
- split=section,
119
- measurement_path=args.measurement_path,
120
- episodes=args.episodes,
121
- budget=args.budget,
122
- seed=args.seed,
123
- acquisition=args.acquisition,
124
- beta=args.beta,
125
- xi=args.xi,
126
- )
127
- for name, section in sections.items()
128
- }
129
-
130
- summary = {
131
- "measurement_path": args.measurement_path,
132
- "splits_path": str(args.splits),
133
- "episodes": args.episodes,
134
- "budget": args.budget,
135
- "acquisition": args.acquisition,
136
- "beta": args.beta,
137
- "xi": args.xi,
138
- "results": results,
139
- }
140
- args.output.parent.mkdir(parents=True, exist_ok=True)
141
- with args.output.open("w", encoding="utf-8") as handle:
142
- json.dump(summary, handle, indent=2)
143
- print(json.dumps(summary, indent=2))
144
-
145
-
146
- if __name__ == "__main__":
147
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/qwen_05b_spec.py DELETED
@@ -1,115 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- from dataclasses import dataclass
5
- from typing import Dict, List
6
-
7
-
8
- MODEL_ID = "Qwen/Qwen2.5-0.5B"
9
- HIDDEN_SIZE = 896
10
- INTERMEDIATE_SIZE = 4864
11
- NUM_HIDDEN_LAYERS = 24
12
- NUM_ATTENTION_HEADS = 14
13
- NUM_KEY_VALUE_HEADS = 2
14
- HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
15
- KV_PROJ_SIZE = NUM_KEY_VALUE_HEADS * HEAD_DIM
16
- DTYPE = "bfloat16"
17
- BATCH_SIZE = 1
18
-
19
- PREFILL_SEQ_LENS = (128, 512, 2048)
20
- DECODE_CTX_LENS = (128, 512, 2048, 8192)
21
-
22
- MODE_IDS = {
23
- "prefill": 1,
24
- "decode": 2,
25
- }
26
-
27
- ROLE_IDS = {
28
- "rmsnorm": 1,
29
- "attn_softmax": 2,
30
- "q_proj": 3,
31
- "k_proj": 4,
32
- "v_proj": 5,
33
- "o_proj": 6,
34
- "gate_proj": 7,
35
- "up_proj": 8,
36
- "down_proj": 9,
37
- }
38
-
39
-
40
- @dataclass(frozen=True)
41
- class QwenKernelTask:
42
- family_group: str
43
- family: str
44
- task_id: str
45
- role: str
46
- mode: str
47
- m: int
48
- n: int
49
- k: int
50
- seq_len: int
51
- ctx_len: int
52
-
53
- def shape_fields(self) -> Dict[str, int | str]:
54
- return {
55
- "family_group": self.family_group,
56
- "family": self.family,
57
- "role": self.role,
58
- "mode": self.mode,
59
- "role_id": ROLE_IDS[self.role],
60
- "mode_id": MODE_IDS[self.mode],
61
- "m": self.m,
62
- "n": self.n,
63
- "k": self.k,
64
- "seq_len": self.seq_len,
65
- "ctx_len": self.ctx_len,
66
- "hidden_size": HIDDEN_SIZE,
67
- "intermediate_size": INTERMEDIATE_SIZE,
68
- "num_attention_heads": NUM_ATTENTION_HEADS,
69
- "num_key_value_heads": NUM_KEY_VALUE_HEADS,
70
- "head_dim": HEAD_DIM,
71
- "dtype": DTYPE,
72
- "model_id": MODEL_ID,
73
- }
74
-
75
-
76
- def _task_id(role: str, mode: str, m: int, n: int, k: int, seq_len: int, ctx_len: int) -> str:
77
- extra = f"_k{k}" if k > 0 else ""
78
- ctx = f"_ctx{ctx_len}" if ctx_len > 0 else ""
79
- return f"qwen05b_{role}_{mode}_m{m}_n{n}{extra}_seq{seq_len}{ctx}"
80
-
81
-
82
- def qwen_05b_tasks() -> List[QwenKernelTask]:
83
- tasks: List[QwenKernelTask] = []
84
-
85
- for seq_len in PREFILL_SEQ_LENS:
86
- tasks.extend(
87
- [
88
- QwenKernelTask("QWEN", "rmsnorm", _task_id("rmsnorm", "prefill", seq_len, HIDDEN_SIZE, 0, seq_len, seq_len), "rmsnorm", "prefill", seq_len, HIDDEN_SIZE, 0, seq_len, seq_len),
89
- QwenKernelTask("QWEN", "softmax", _task_id("attn_softmax", "prefill", NUM_ATTENTION_HEADS * seq_len, seq_len, 0, seq_len, seq_len), "attn_softmax", "prefill", NUM_ATTENTION_HEADS * seq_len, seq_len, 0, seq_len, seq_len),
90
- QwenKernelTask("QWEN", "gemm", _task_id("q_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len), "q_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len),
91
- QwenKernelTask("QWEN", "gemm", _task_id("k_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len), "k_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len),
92
- QwenKernelTask("QWEN", "gemm", _task_id("v_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len), "v_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len),
93
- QwenKernelTask("QWEN", "gemm", _task_id("o_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len), "o_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len),
94
- QwenKernelTask("QWEN", "gemm", _task_id("gate_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len), "gate_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len),
95
- QwenKernelTask("QWEN", "gemm", _task_id("up_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len), "up_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len),
96
- QwenKernelTask("QWEN", "gemm", _task_id("down_proj", "prefill", seq_len, HIDDEN_SIZE, INTERMEDIATE_SIZE, seq_len, seq_len), "down_proj", "prefill", seq_len, HIDDEN_SIZE, INTERMEDIATE_SIZE, seq_len, seq_len),
97
- ]
98
- )
99
-
100
- for ctx_len in DECODE_CTX_LENS:
101
- tasks.extend(
102
- [
103
- QwenKernelTask("QWEN", "rmsnorm", _task_id("rmsnorm", "decode", 1, HIDDEN_SIZE, 0, 1, ctx_len), "rmsnorm", "decode", 1, HIDDEN_SIZE, 0, 1, ctx_len),
104
- QwenKernelTask("QWEN", "softmax", _task_id("attn_softmax", "decode", NUM_ATTENTION_HEADS, ctx_len, 0, 1, ctx_len), "attn_softmax", "decode", NUM_ATTENTION_HEADS, ctx_len, 0, 1, ctx_len),
105
- QwenKernelTask("QWEN", "gemm", _task_id("q_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len), "q_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len),
106
- QwenKernelTask("QWEN", "gemm", _task_id("k_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len), "k_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len),
107
- QwenKernelTask("QWEN", "gemm", _task_id("v_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len), "v_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len),
108
- QwenKernelTask("QWEN", "gemm", _task_id("o_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len), "o_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len),
109
- QwenKernelTask("QWEN", "gemm", _task_id("gate_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len), "gate_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len),
110
- QwenKernelTask("QWEN", "gemm", _task_id("up_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len), "up_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len),
111
- QwenKernelTask("QWEN", "gemm", _task_id("down_proj", "decode", 1, HIDDEN_SIZE, INTERMEDIATE_SIZE, 1, ctx_len), "down_proj", "decode", 1, HIDDEN_SIZE, INTERMEDIATE_SIZE, 1, ctx_len),
112
- ]
113
- )
114
-
115
- return tasks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/report_task_hardness.py DELETED
@@ -1,42 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import csv
6
- from collections import defaultdict
7
- from pathlib import Path
8
- from typing import Dict, List
9
-
10
-
11
- def _load_rows(path: Path) -> Dict[str, List[float]]:
12
- grouped: Dict[str, List[float]] = defaultdict(list)
13
- with path.open("r", newline="", encoding="utf-8") as f:
14
- for row in csv.DictReader(f):
15
- grouped[row["task_id"]].append(float(row["median_ms"]))
16
- return grouped
17
-
18
-
19
- def main() -> None:
20
- parser = argparse.ArgumentParser(description="Report task hardness from measured latency table.")
21
- parser.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
22
- parser.add_argument("--budget", type=int, default=6)
23
- args = parser.parse_args()
24
-
25
- grouped = _load_rows(args.measurement_path)
26
- for task_id, vals in sorted(grouped.items()):
27
- vals = sorted(vals)
28
- best = vals[0]
29
- ncfg = len(vals)
30
- within1 = sum(v <= best * 1.01 for v in vals)
31
- within2 = sum(v <= best * 1.02 for v in vals)
32
- within5 = sum(v <= best * 1.05 for v in vals)
33
- hit_best = 1.0 - (1.0 - 1.0 / ncfg) ** args.budget
34
- print(
35
- f"{task_id} ncfg={ncfg} best_ms={best:.9f} "
36
- f"within1={within1} within2={within2} within5={within5} "
37
- f"random_hit_best@{args.budget}={hit_best:.4f}"
38
- )
39
-
40
-
41
- if __name__ == "__main__":
42
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_full_pipeline.sh DELETED
@@ -1,84 +0,0 @@
1
- #!/usr/bin/env bash
2
- set -euo pipefail
3
-
4
- SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
- REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
- cd "${REPO_ROOT}"
7
-
8
- MEASUREMENTS="${REPO_ROOT}/data/autotune_measurements.csv"
9
- SPLITS="${REPO_ROOT}/data/benchmark_splits.json"
10
- OUTPUT_DIR="${REPO_ROOT}/outputs"
11
- mkdir -p "${OUTPUT_DIR}"
12
- EPISODES=20
13
- BUDGET=6
14
- BETA=2.0
15
-
16
- echo "[step] 0) checking workspace"
17
- python3 - <<'PY'
18
- import torch
19
-
20
- print(f"python={__import__('sys').executable}")
21
- print(f"torch={getattr(__import__('torch'), '__version__', 'missing')}")
22
- print(f"torch.cuda.is_available={torch.cuda.is_available()}")
23
- if not torch.cuda.is_available():
24
- raise SystemExit("CUDA is not available. Run this pipeline on a GPU machine.")
25
- print("ready")
26
- PY
27
-
28
- echo "[step] 1) collect multi-family measurements"
29
- python3 scripts/collect_multifamily_measurements.py \
30
- --output "${MEASUREMENTS}" \
31
- --families softmax layernorm grouped_gemm \
32
- --n-cols 256 512 1024 1536 2048 3072 4096 6144 8192 \
33
- --m 4096 \
34
- --repeats 200 \
35
- --warmup 25 \
36
- --seed 0 \
37
- --append
38
-
39
- echo "[step] 2) build train/test splits"
40
- python3 scripts/build_benchmark_splits.py \
41
- --measurement-path "${MEASUREMENTS}" \
42
- --output "${SPLITS}" \
43
- --heldout-family grouped_gemm
44
-
45
- echo "[step] 3) local smoke test"
46
- python3 scripts/smoke_test_client.py | tee "${OUTPUT_DIR}/smoke_test_client.json"
47
-
48
- echo "[step] 4) evaluate held-out shapes and held-out family"
49
- python3 scripts/eval_generalization.py \
50
- --measurement-path "${MEASUREMENTS}" \
51
- --splits "${SPLITS}" \
52
- --episodes "${EPISODES}" \
53
- --budget "${BUDGET}" \
54
- --seed 2 \
55
- --acquisition ucb \
56
- --beta "${BETA}" \
57
- --output "${OUTPUT_DIR}/generalization_eval.json"
58
-
59
- echo "[step] 5) benchmark eager vs torch.compile vs best Triton configs"
60
- python3 scripts/benchmark_runtime_references.py \
61
- --generalization-results "${OUTPUT_DIR}/generalization_eval.json" \
62
- --repeats 100 \
63
- --warmup 10 \
64
- --seed 123 \
65
- --output "${OUTPUT_DIR}/runtime_references.json"
66
-
67
- python3 - <<'PY'
68
- import json
69
- from pathlib import Path
70
-
71
- summary = json.loads(Path("outputs/generalization_eval.json").read_text(encoding="utf-8"))
72
- for section_name, section in summary["results"].items():
73
- print(section_name)
74
- print(" random:", section["random_summary"])
75
- print(" surrogate:", section["surrogate_summary"])
76
-
77
- runtime = json.loads(Path("outputs/runtime_references.json").read_text(encoding="utf-8"))
78
- for task_id, task in runtime["results"].items():
79
- print(task_id)
80
- print(" torch:", task["torch"])
81
- print(" speedups:", task["speedups"])
82
- PY
83
-
84
- echo "[done] outputs in outputs/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_qwen_05b_pipeline.sh DELETED
@@ -1,90 +0,0 @@
1
- #!/usr/bin/env bash
2
- set -euo pipefail
3
-
4
- SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
- REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
- cd "${REPO_ROOT}"
7
-
8
- MEASUREMENTS="${REPO_ROOT}/data/qwen_05b_measurements.csv"
9
- SPLITS="${REPO_ROOT}/data/qwen_05b_splits.json"
10
- OUTPUT_DIR="${REPO_ROOT}/outputs"
11
- mkdir -p "${OUTPUT_DIR}"
12
- EPISODES=20
13
- BUDGET=6
14
- BETA=2.0
15
-
16
- echo "[step] 0) checking workspace"
17
- python3 - <<'PY'
18
- import torch
19
-
20
- print(f"python={__import__('sys').executable}")
21
- print(f"torch={getattr(__import__('torch'), '__version__', 'missing')}")
22
- print(f"torch.cuda.is_available={torch.cuda.is_available()}")
23
- if not torch.cuda.is_available():
24
- raise SystemExit("CUDA is not available. Run this pipeline on a GPU machine.")
25
- print("ready")
26
- PY
27
-
28
- echo "[step] 1) collect exact Qwen2.5-0.5B kernel measurements"
29
- python3 scripts/collect_qwen_05b_measurements.py \
30
- --output "${MEASUREMENTS}" \
31
- --repeats 200 \
32
- --warmup 25 \
33
- --seed 0 \
34
- --append
35
-
36
- echo "[step] 2) build Qwen splits"
37
- python3 scripts/build_qwen_05b_splits.py --output "${SPLITS}"
38
-
39
- echo "[step] 3) local smoke test against Qwen measurement cache"
40
- python3 - <<'PY'
41
- import json
42
- from pathlib import Path
43
- from client import SoftmaxSurrogateEnvClient
44
-
45
- client = SoftmaxSurrogateEnvClient(measurement_path="data/qwen_05b_measurements.csv")
46
- reset_out = client.reset()
47
- first_config = reset_out["observation"]["tried_config_ids"][0]
48
- step_out = client.step({"config_id": first_config})
49
- summary = {"reset": reset_out, "step": step_out}
50
- Path("outputs/qwen_05b_smoke_test.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
51
- print(json.dumps(summary, indent=2))
52
- PY
53
-
54
- echo "[step] 4) evaluate held-out shapes and held-out family"
55
- python3 scripts/eval_generalization.py \
56
- --measurement-path "${MEASUREMENTS}" \
57
- --splits "${SPLITS}" \
58
- --episodes "${EPISODES}" \
59
- --budget "${BUDGET}" \
60
- --seed 2 \
61
- --acquisition ucb \
62
- --beta "${BETA}" \
63
- --output "${OUTPUT_DIR}/qwen_05b_generalization_eval.json"
64
-
65
- echo "[step] 5) benchmark eager vs torch.compile vs best Triton configs"
66
- python3 scripts/benchmark_qwen_05b_runtime.py \
67
- --generalization-results "${OUTPUT_DIR}/qwen_05b_generalization_eval.json" \
68
- --repeats 100 \
69
- --warmup 10 \
70
- --seed 123 \
71
- --output "${OUTPUT_DIR}/qwen_05b_runtime_references.json"
72
-
73
- python3 - <<'PY'
74
- import json
75
- from pathlib import Path
76
-
77
- eval_summary = json.loads(Path("outputs/qwen_05b_generalization_eval.json").read_text(encoding="utf-8"))
78
- for section_name, section in eval_summary["results"].items():
79
- print(section_name)
80
- print(" random:", section["random_summary"])
81
- print(" surrogate:", section["surrogate_summary"])
82
-
83
- runtime = json.loads(Path("outputs/qwen_05b_runtime_references.json").read_text(encoding="utf-8"))
84
- for task_id, task in runtime["results"].items():
85
- print(task_id)
86
- print(" torch:", task["torch"])
87
- print(" speedups:", task["speedups"])
88
- PY
89
-
90
- echo "[done] Qwen outputs in outputs/"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_random_baseline.py DELETED
@@ -1,187 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- from dataclasses import dataclass
7
- import sys
8
- from pathlib import Path
9
- from typing import Any, Dict, List
10
-
11
- import numpy as np
12
-
13
- ROOT = Path(__file__).resolve().parents[1]
14
- if str(ROOT) not in sys.path:
15
- sys.path.append(str(ROOT))
16
-
17
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
18
-
19
-
20
- @dataclass
21
- class RunRecord:
22
- task_id: str
23
- episode: int
24
- best_latency_ms: float
25
- best_config: Dict[str, int]
26
- final_validation_mse: float
27
- final_state: Dict[str, Any]
28
- final_regret: float
29
- history: List[Dict[str, Any]]
30
-
31
-
32
- def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
33
- ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
34
- regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
35
- auc_regrets: List[float] = []
36
-
37
- for episode in episode_records:
38
- regrets = [float(step["regret"]) for step in episode["history"]]
39
- if regrets:
40
- auc_regrets.append(float(sum(regrets) / len(regrets)))
41
- for k in ks:
42
- if len(regrets) >= k:
43
- regrets_by_k[k].append(regrets[k - 1])
44
-
45
- return {
46
- "mean_regret_at": {
47
- str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
48
- },
49
- "median_regret_at": {
50
- str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals
51
- },
52
- "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
53
- "oracle_hit_rate_final": float(
54
- sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records)
55
- ) if episode_records else None,
56
- }
57
-
58
-
59
- def _pick_task_from_input(args: argparse.Namespace) -> str:
60
- if args.task:
61
- return args.task
62
- env = SoftmaxSurrogateEnvironment(
63
- measurement_path=args.measurement_path,
64
- budget=args.budget,
65
- seed=args.seed,
66
- )
67
- return env.reset()["observation"]["task_id"]
68
-
69
-
70
- def run_random_baseline(
71
- task: str,
72
- episodes: int,
73
- budget: int,
74
- seed: int,
75
- measurement_path: str,
76
- ) -> Dict[str, Any]:
77
- rng = np.random.default_rng(seed)
78
- best_overall: Dict[str, Any] = {"latency_ms": float("inf"), "config": None, "task_id": task}
79
- episode_records: List[Dict[str, Any]] = []
80
-
81
- env = SoftmaxSurrogateEnvironment(
82
- measurement_path=measurement_path,
83
- budget=budget,
84
- seed=seed,
85
- )
86
-
87
- for episode in range(episodes):
88
- env.reset(task=task, seed=seed + episode)
89
- done = False
90
- episode_best = float("inf")
91
- episode_best_cfg: Dict[str, int] | None = None
92
- history: List[Dict[str, Any]] = []
93
-
94
- while not done:
95
- unseen = [config_id for config_id in env.available_config_ids() if config_id not in env.seen_config_ids()]
96
- choice_pool = unseen if unseen else env.available_config_ids()
97
- config_id = int(rng.choice(choice_pool))
98
- step_out = env.step({"config_id": config_id})
99
- obs = step_out["observation"]
100
- trial = obs["last_trial"]
101
- history.append(
102
- {
103
- "config_id": config_id,
104
- "latency_ms": trial["latency_ms"],
105
- "config": trial["config"],
106
- "reward": step_out["reward"],
107
- "regret": step_out["info"]["current_regret"],
108
- "validation_mse": step_out["info"]["validation_mse"],
109
- }
110
- )
111
- if obs["best_so_far_ms"] < episode_best:
112
- episode_best = obs["best_so_far_ms"]
113
- best_id = env.seen_config_ids()[int(np.argmin([env.measured_latency_ms(cid) for cid in env.seen_config_ids()]))]
114
- episode_best_cfg = env.config_info(best_id)
115
- done = bool(step_out["done"])
116
-
117
- if episode_best < best_overall["latency_ms"]:
118
- best_overall = {
119
- "latency_ms": float(episode_best),
120
- "config": episode_best_cfg,
121
- "task_id": task,
122
- }
123
-
124
- diagnostics = env.diagnostics()
125
- episode_records.append(
126
- RunRecord(
127
- task_id=task,
128
- episode=episode,
129
- best_latency_ms=float(episode_best),
130
- best_config=episode_best_cfg or {},
131
- final_validation_mse=float(diagnostics["validation_mse"]),
132
- final_state=env.state(),
133
- final_regret=float(diagnostics["current_regret"]),
134
- history=history,
135
- ).__dict__
136
- )
137
-
138
- return {
139
- "task": task,
140
- "method": "random",
141
- "episodes": episodes,
142
- "budget": budget,
143
- "seed": seed,
144
- "oracle_best_ms": env.oracle_best()["median_ms"],
145
- "best_overall": best_overall,
146
- "aggregate_metrics": _aggregate_metrics(episode_records, budget),
147
- "episodes_summary": episode_records,
148
- }
149
-
150
-
151
- def parse_args() -> argparse.Namespace:
152
- parser = argparse.ArgumentParser(description="Random baseline for surrogate environment.")
153
- parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
154
- parser.add_argument("--episodes", type=int, default=20)
155
- parser.add_argument("--budget", type=int, default=6)
156
- parser.add_argument("--seed", type=int, default=0)
157
- parser.add_argument(
158
- "--measurement-path",
159
- type=str,
160
- default="data/autotune_measurements.csv",
161
- )
162
- parser.add_argument(
163
- "--output",
164
- type=Path,
165
- default=Path("outputs/random_baseline.json"),
166
- )
167
- return parser.parse_args()
168
-
169
-
170
- def main() -> None:
171
- args = parse_args()
172
- task = _pick_task_from_input(args)
173
- summary = run_random_baseline(
174
- task=task,
175
- episodes=args.episodes,
176
- budget=args.budget,
177
- seed=args.seed,
178
- measurement_path=args.measurement_path,
179
- )
180
- args.output.parent.mkdir(parents=True, exist_ok=True)
181
- with args.output.open("w", encoding="utf-8") as f:
182
- json.dump(summary, f, indent=2)
183
- print(json.dumps(summary, indent=2))
184
-
185
-
186
- if __name__ == "__main__":
187
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_surrogate_baseline.py DELETED
@@ -1,217 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import json
6
- import sys
7
- from pathlib import Path
8
- from typing import Any, Dict, List
9
-
10
- import numpy as np
11
-
12
- ROOT = Path(__file__).resolve().parents[1]
13
- if str(ROOT) not in sys.path:
14
- sys.path.append(str(ROOT))
15
-
16
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
17
-
18
-
19
- def _choose_surrogate_action(
20
- env: SoftmaxSurrogateEnvironment,
21
- acquisition: str,
22
- beta: float,
23
- xi: float,
24
- ) -> int:
25
- seen = set(env.seen_config_ids())
26
- best_config_id = -1
27
- best_score = float("-inf")
28
-
29
- for config_id in env.available_config_ids():
30
- if config_id in seen and len(seen) < len(env.available_config_ids()):
31
- continue
32
- score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
33
- if score > best_score:
34
- best_score = score
35
- best_config_id = config_id
36
-
37
- if best_config_id < 0:
38
- raise RuntimeError("Failed to choose a surrogate action.")
39
- return best_config_id
40
-
41
-
42
- def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
43
- ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
44
- regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
45
- auc_regrets: List[float] = []
46
-
47
- for episode in episode_records:
48
- regrets = [float(step["regret"]) for step in episode["history"]]
49
- if regrets:
50
- auc_regrets.append(float(sum(regrets) / len(regrets)))
51
- for k in ks:
52
- if len(regrets) >= k:
53
- regrets_by_k[k].append(regrets[k - 1])
54
-
55
- return {
56
- "mean_regret_at": {
57
- str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
58
- },
59
- "median_regret_at": {
60
- str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals
61
- },
62
- "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
63
- "oracle_hit_rate_final": float(
64
- sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records)
65
- ) if episode_records else None,
66
- }
67
-
68
-
69
- def run_surrogate_baseline(
70
- task: str,
71
- episodes: int,
72
- budget: int,
73
- seed: int,
74
- measurement_path: str,
75
- train_task_ids: List[str] | None = None,
76
- acquisition: str = "ucb",
77
- beta: float = 1.5,
78
- xi: float = 0.0,
79
- ) -> Dict[str, Any]:
80
- env = SoftmaxSurrogateEnvironment(
81
- measurement_path=measurement_path,
82
- budget=budget,
83
- seed=seed,
84
- train_task_ids=train_task_ids,
85
- )
86
-
87
- best_overall = {"latency_ms": float("inf"), "config": None, "task_id": task}
88
- episode_records: List[Dict[str, Any]] = []
89
-
90
- for episode in range(episodes):
91
- env.reset(task=task, seed=seed + episode)
92
- done = False
93
- episode_best = float("inf")
94
- episode_best_cfg: Dict[str, int] | None = None
95
- history: List[Dict[str, Any]] = []
96
-
97
- while not done:
98
- config_id = _choose_surrogate_action(env, acquisition=acquisition, beta=beta, xi=xi)
99
- out = env.step({"config_id": config_id})
100
- obs = out["observation"]
101
- trial = obs["last_trial"]
102
- history.append(
103
- {
104
- "config_id": config_id,
105
- "latency_ms": trial["latency_ms"],
106
- "config": trial["config"],
107
- "reward": out["reward"],
108
- "regret": out["info"]["current_regret"],
109
- "validation_mse": out["info"]["validation_mse"],
110
- }
111
- )
112
- if obs["best_so_far_ms"] < episode_best:
113
- episode_best = obs["best_so_far_ms"]
114
- best_seen = min(env.seen_config_ids(), key=env.measured_latency_ms)
115
- episode_best_cfg = env.config_info(best_seen)
116
- done = bool(out["done"])
117
-
118
- if episode_best < best_overall["latency_ms"]:
119
- best_overall = {
120
- "latency_ms": float(episode_best),
121
- "config": episode_best_cfg,
122
- "task_id": task,
123
- }
124
-
125
- diagnostics = env.diagnostics()
126
- episode_records.append(
127
- {
128
- "task_id": task,
129
- "episode": episode,
130
- "best_latency_ms": episode_best,
131
- "best_config": episode_best_cfg or {},
132
- "final_validation_mse": diagnostics["validation_mse"],
133
- "final_regret": diagnostics["current_regret"],
134
- "history": history,
135
- }
136
- )
137
-
138
- return {
139
- "task": task,
140
- "method": "surrogate",
141
- "episodes": episodes,
142
- "budget": budget,
143
- "seed": seed,
144
- "train_task_ids": list(train_task_ids or []),
145
- "acquisition": acquisition,
146
- "beta": beta,
147
- "xi": xi,
148
- "oracle_best_ms": env.oracle_best()["median_ms"],
149
- "best_overall": best_overall,
150
- "aggregate_metrics": _aggregate_metrics(episode_records, budget),
151
- "episodes_summary": episode_records,
152
- }
153
-
154
-
155
- def parse_args() -> argparse.Namespace:
156
- parser = argparse.ArgumentParser(description="Surrogate-guided baseline.")
157
- parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
158
- parser.add_argument("--episodes", type=int, default=20)
159
- parser.add_argument("--budget", type=int, default=6)
160
- parser.add_argument("--seed", type=int, default=0)
161
- parser.add_argument(
162
- "--acquisition",
163
- type=str,
164
- choices=("mean", "ucb", "ei"),
165
- default="ucb",
166
- help="Candidate selection mode: mean, ucb, or ei.",
167
- )
168
- parser.add_argument("--beta", type=float, default=1.5, help="UCB exploration strength.")
169
- parser.add_argument("--xi", type=float, default=0.0, help="Expected-improvement margin.")
170
- parser.add_argument(
171
- "--train-tasks-file",
172
- type=Path,
173
- default=None,
174
- help="Optional JSON file containing a list of train task ids.",
175
- )
176
- parser.add_argument(
177
- "--measurement-path",
178
- type=str,
179
- default="data/autotune_measurements.csv",
180
- )
181
- parser.add_argument(
182
- "--output",
183
- type=Path,
184
- default=Path("outputs/surrogate_baseline.json"),
185
- )
186
- return parser.parse_args()
187
-
188
-
189
- def main() -> None:
190
- args = parse_args()
191
- if not args.task:
192
- env = SoftmaxSurrogateEnvironment(measurement_path=args.measurement_path, budget=args.budget, seed=args.seed)
193
- args.task = env.reset()["observation"]["task_id"]
194
-
195
- train_task_ids = None
196
- if args.train_tasks_file is not None:
197
- train_task_ids = json.loads(args.train_tasks_file.read_text(encoding="utf-8"))
198
-
199
- summary = run_surrogate_baseline(
200
- task=args.task,
201
- episodes=args.episodes,
202
- budget=args.budget,
203
- seed=args.seed,
204
- measurement_path=args.measurement_path,
205
- train_task_ids=train_task_ids,
206
- acquisition=args.acquisition,
207
- beta=args.beta,
208
- xi=args.xi,
209
- )
210
- args.output.parent.mkdir(parents=True, exist_ok=True)
211
- with args.output.open("w", encoding="utf-8") as f:
212
- json.dump(summary, f, indent=2)
213
- print(json.dumps(summary, indent=2))
214
-
215
-
216
- if __name__ == "__main__":
217
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/smoke_test_client.py DELETED
@@ -1,28 +0,0 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import json
5
- import sys
6
- from pathlib import Path
7
-
8
- ROOT = Path(__file__).resolve().parents[1]
9
- if str(ROOT) not in sys.path:
10
- sys.path.append(str(ROOT))
11
-
12
- from client import SoftmaxSurrogateEnvClient
13
-
14
-
15
- def main() -> None:
16
- client = SoftmaxSurrogateEnvClient()
17
- reset_out = client.reset()
18
- step_out = client.step({"config_id": 0})
19
- summary = {"reset": reset_out, "step": step_out}
20
- out = Path("outputs/smoke_test_client.json")
21
- out.parent.mkdir(parents=True, exist_ok=True)
22
- with out.open("w", encoding="utf-8") as f:
23
- json.dump(summary, f, indent=2)
24
- print(json.dumps(summary, indent=2))
25
-
26
-
27
- if __name__ == "__main__":
28
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Package marker for importing server modules directly from workspace scripts.
 
 
server/app.py DELETED
@@ -1,72 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import argparse
4
- import sys
5
- from pathlib import Path
6
- from typing import Any, Dict, Optional
7
-
8
- ROOT = Path(__file__).resolve().parents[1]
9
- if str(ROOT) not in sys.path:
10
- sys.path.append(str(ROOT))
11
-
12
- from fastapi import FastAPI
13
- from fastapi import HTTPException
14
- from pydantic import BaseModel
15
-
16
- from models import ResetResult, StepResult
17
- from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
18
-
19
-
20
- app = FastAPI(title="Autotune Benchmark OpenEnv Server")
21
- env = SoftmaxSurrogateEnvironment()
22
-
23
-
24
- class ResetRequest(BaseModel):
25
- task: Optional[str] = None
26
- seed: Optional[int] = None
27
-
28
-
29
- class StepRequest(BaseModel):
30
- config_id: Optional[int] = None
31
- x: Optional[list[float]] = None
32
-
33
-
34
- @app.get("/health")
35
- def health() -> Dict[str, str]:
36
- return {"ok": "true"}
37
-
38
-
39
- @app.post("/reset")
40
- def reset(payload: ResetRequest) -> Dict[str, Any]:
41
- result = env.reset(task=payload.task, seed=payload.seed)
42
- return result
43
-
44
-
45
- @app.post("/step")
46
- def step(payload: StepRequest) -> Dict[str, Any]:
47
- if payload.config_id is not None:
48
- result = env.step({"config_id": payload.config_id})
49
- return result
50
- if payload.x is not None:
51
- result = env.step({"x": payload.x})
52
- return result
53
- raise HTTPException(status_code=400, detail="Missing config_id.")
54
- return result
55
-
56
-
57
- @app.get("/state")
58
- def state() -> Dict[str, Any]:
59
- return env.state()
60
-
61
-
62
- if __name__ == "__main__":
63
- parser = argparse.ArgumentParser(description="Run softmax surrogate environment server.")
64
- parser.add_argument("--host", default="0.0.0.0")
65
- parser.add_argument("--port", type=int, default=8000)
66
- args = parser.parse_args()
67
- try:
68
- import uvicorn
69
-
70
- uvicorn.run("app:app", host=args.host, port=args.port, reload=False)
71
- except Exception as err: # pragma: no cover
72
- raise RuntimeError("uvicorn not available") from err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/softmax_surrogate_environment.py DELETED
@@ -1,570 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import csv
4
- import json
5
- import math
6
- import random
7
- import sys
8
- from dataclasses import dataclass
9
- from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Sequence, Tuple
11
-
12
- import numpy as np
13
-
14
- ROOT = Path(__file__).resolve().parents[1]
15
- if str(ROOT) not in sys.path:
16
- sys.path.append(str(ROOT))
17
-
18
- from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS
19
-
20
-
21
- DEFAULT_MEASUREMENT_PATH = "data/autotune_measurements.csv"
22
- DEFAULT_BUDGET = 6
23
- INITIAL_DATASET_SIZE = 2
24
- DUPLICATE_PENALTY = -1e-4
25
- FAMILY_INDEX = {
26
- "softmax": 0,
27
- "layernorm": 1,
28
- "grouped_gemm": 2,
29
- "rmsnorm": 3,
30
- "gemm": 4,
31
- }
32
-
33
-
34
- @dataclass(frozen=True)
35
- class Measurement:
36
- family_group: str
37
- family: str
38
- task_id: str
39
- m: int
40
- n: int
41
- k: int
42
- config_id: int
43
- block_size: int
44
- num_warps: int
45
- num_stages: int
46
- shape_json: str
47
- config_json: str
48
- median_ms: float
49
- effective_gbps: float
50
- score: float
51
- validation_error: float
52
-
53
-
54
- def _normalize_discrete(values: Sequence[int], value: int) -> float:
55
- idx = list(values).index(int(value))
56
- if len(values) == 1:
57
- return 0.0
58
- return 2.0 * (idx / (len(values) - 1)) - 1.0
59
-
60
-
61
- class SoftmaxSurrogateEnvironment:
62
- """
63
- Generic discrete-action autotuning environment backed by measured GPU data.
64
-
65
- The class name is kept for compatibility with the existing local server and
66
- baseline scripts, but the task space is now multi-family.
67
- """
68
-
69
- def __init__(
70
- self,
71
- measurement_path: str = DEFAULT_MEASUREMENT_PATH,
72
- budget: int = DEFAULT_BUDGET,
73
- seed: int = 0,
74
- initial_samples: int = INITIAL_DATASET_SIZE,
75
- train_task_ids: Optional[Sequence[str]] = None,
76
- ) -> None:
77
- self.measurement_path = Path(measurement_path)
78
- self.budget = int(budget)
79
- self.seed = int(seed)
80
- self.initial_samples = max(1, int(initial_samples))
81
- self.train_task_ids = set(train_task_ids or [])
82
-
83
- self._measurements = self._load_measurements()
84
- self._task_ids = sorted(self._measurements.keys())
85
- if not self._task_ids:
86
- raise RuntimeError(
87
- "No measurement data found. Run the measurement collectors first."
88
- )
89
-
90
- self._rng = random.Random(self.seed)
91
- self._episode_counter = 0
92
-
93
- self._task_id: Optional[str] = None
94
- self._family: Optional[str] = None
95
- self._episode_id: Optional[str] = None
96
- self._task_rows: List[Measurement] = []
97
- self._prior_rows: List[Measurement] = []
98
- self._config_by_id: Dict[int, Measurement] = {}
99
- self._observed_ids: List[int] = []
100
- self._observed_id_set = set()
101
- self._observed_rows: List[Measurement] = []
102
- self._observed_latencies: List[float] = []
103
- self._steps_taken = 0
104
- self._steps_remaining = 0
105
- self._best_latency_ms = float("inf")
106
- self._best_config_id: Optional[int] = None
107
- self._validation_mse = float("inf")
108
-
109
- self._surrogate_version = 0
110
- self._surrogate_fitted_version = -1
111
- self._surrogate_x: Optional[np.ndarray] = None
112
- self._surrogate_y: Optional[np.ndarray] = None
113
- self._surrogate_alpha: Optional[np.ndarray] = None
114
- self._surrogate_k: Optional[np.ndarray] = None
115
- self._surrogate_length_scale: float = 0.5
116
-
117
- def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
118
- if seed is not None:
119
- self._rng = random.Random(int(seed))
120
-
121
- if task is None:
122
- task = self._rng.choice(self._task_ids)
123
- if task not in self._measurements:
124
- raise ValueError(f"Unknown task: {task}")
125
-
126
- rows = self._measurements[task]
127
- self._task_id = task
128
- self._family = rows[0].family
129
- self._task_rows = rows
130
- self._config_by_id = {row.config_id: row for row in rows}
131
- self._prior_rows = self._build_prior_rows(task)
132
- self._observed_ids = []
133
- self._observed_id_set = set()
134
- self._observed_rows = []
135
- self._observed_latencies = []
136
- self._steps_taken = 0
137
- self._steps_remaining = self.budget
138
- self._best_latency_ms = float("inf")
139
- self._best_config_id = None
140
- self._episode_counter += 1
141
- self._episode_id = f"{task}:{self.seed}:{self._episode_counter}"
142
-
143
- sample_count = min(self.initial_samples, len(rows))
144
- for config_id in self._rng.sample(list(self._config_by_id.keys()), k=sample_count):
145
- self._observe_config(config_id)
146
-
147
- self._invalidate_surrogate()
148
- self._validation_mse = self._compute_validation_mse()
149
-
150
- return self._format_step_output(
151
- observation=self._observation_payload(kind="reset"),
152
- reward=0.0,
153
- done=False,
154
- info=self.diagnostics(),
155
- )
156
-
157
- def step(self, action: Any) -> Dict[str, Any]:
158
- if self._task_id is None:
159
- raise RuntimeError("Call reset() before step().")
160
- if self._steps_remaining <= 0:
161
- return self._format_step_output(
162
- observation=self._observation_payload(kind="done"),
163
- reward=0.0,
164
- done=True,
165
- info=self.diagnostics(),
166
- )
167
-
168
- config_id = self._extract_config_id(action)
169
- row = self._row_for_id(config_id)
170
- prev_best = self._best_latency_ms
171
- duplicate = config_id in self._observed_id_set
172
-
173
- if not duplicate:
174
- self._observe_config(config_id)
175
- self._surrogate_version += 1
176
-
177
- self._steps_taken += 1
178
- self._steps_remaining -= 1
179
- self._validation_mse = self._compute_validation_mse()
180
-
181
- reward = DUPLICATE_PENALTY if duplicate else max(0.0, math.log(prev_best) - math.log(self._best_latency_ms))
182
- observation = self._observation_payload(
183
- kind="step",
184
- last_trial={
185
- "config_id": config_id,
186
- "config": self.config_info(config_id),
187
- "latency_ms": row.median_ms,
188
- "score": row.score,
189
- "duplicate": duplicate,
190
- },
191
- )
192
- return self._format_step_output(
193
- observation=observation,
194
- reward=reward,
195
- done=self._steps_remaining <= 0,
196
- info=self.diagnostics(),
197
- )
198
-
199
- def state(self) -> Dict[str, Any]:
200
- if self._task_id is None:
201
- return {"status": "uninitialized"}
202
- return {
203
- "episode_id": self._episode_id,
204
- "step_count": self._steps_taken,
205
- "task_id": self._task_id,
206
- "family": self._family,
207
- "tried_config_ids": list(self._observed_ids),
208
- }
209
-
210
- def diagnostics(self) -> Dict[str, Any]:
211
- if self._task_id is None:
212
- return {"status": "uninitialized"}
213
- oracle_best_ms = self.oracle_best()["median_ms"]
214
- regret = self._best_latency_ms / oracle_best_ms - 1.0
215
- return {
216
- "validation_mse": self._validation_mse,
217
- "best_so_far_ms": self._best_latency_ms,
218
- "oracle_best_ms": oracle_best_ms,
219
- "current_regret": regret,
220
- "observed_count": len(self._observed_ids),
221
- "prior_count": len(self._prior_rows),
222
- }
223
-
224
- def available_tasks(self) -> List[str]:
225
- return list(self._task_ids)
226
-
227
- def available_config_ids(self) -> List[int]:
228
- if self._task_id is None:
229
- raise RuntimeError("Call reset() before accessing config ids.")
230
- return sorted(self._config_by_id.keys())
231
-
232
- def available_configs(self) -> List[Dict[str, Any]]:
233
- return [self.config_info(config_id) for config_id in self.available_config_ids()]
234
-
235
- def config_info(self, config_id: int) -> Dict[str, Any]:
236
- row = self._row_for_id(config_id)
237
- return {
238
- "config_id": int(config_id),
239
- "family": row.family,
240
- "task_id": row.task_id,
241
- "block_size": row.block_size,
242
- "num_warps": row.num_warps,
243
- "num_stages": row.num_stages,
244
- }
245
-
246
- def measured_latency_ms(self, config_id: int) -> float:
247
- return self._row_for_id(config_id).median_ms
248
-
249
- def oracle_best(self) -> Dict[str, Any]:
250
- if self._task_id is None:
251
- raise RuntimeError("Call reset() before querying oracle_best().")
252
- best = min(self._task_rows, key=lambda row: row.median_ms)
253
- return {
254
- "config_id": best.config_id,
255
- "family": best.family,
256
- "task_id": best.task_id,
257
- "block_size": best.block_size,
258
- "num_warps": best.num_warps,
259
- "num_stages": best.num_stages,
260
- "median_ms": best.median_ms,
261
- "score": best.score,
262
- }
263
-
264
- def predict_score(self, config_id: int) -> float:
265
- return float(self._predict_with_uncertainty(config_id)[0])
266
-
267
- def acquisition_score(
268
- self,
269
- config_id: int,
270
- strategy: str = "ucb",
271
- beta: float = 1.0,
272
- xi: float = 0.0,
273
- ) -> float:
274
- mean, sigma = self._predict_with_uncertainty(config_id)
275
- if strategy == "mean":
276
- return float(mean)
277
- if strategy == "ucb":
278
- return float(mean + float(beta) * sigma)
279
- if strategy == "ei":
280
- best_observed = max(row.score for row in self._observed_rows) if self._observed_rows else mean
281
- delta = mean - best_observed - float(xi)
282
- if sigma <= 0.0:
283
- return float(max(delta, 0.0))
284
- z = delta / sigma
285
- return float(max(delta * _normal_cdf(z) + sigma * _normal_pdf(z), 0.0))
286
- raise ValueError(f"Unknown acquisition strategy: {strategy}")
287
-
288
- def seen_config_ids(self) -> List[int]:
289
- return list(self._observed_ids)
290
-
291
- def _build_prior_rows(self, current_task: str) -> List[Measurement]:
292
- if not self.train_task_ids:
293
- return []
294
- prior_rows: List[Measurement] = []
295
- for task_id in sorted(self.train_task_ids):
296
- if task_id == current_task or task_id not in self._measurements:
297
- continue
298
- prior_rows.extend(self._measurements[task_id])
299
- return prior_rows
300
-
301
- def _predict_with_uncertainty(self, config_id: int) -> Tuple[float, float]:
302
- if not self._observed_rows and not self._prior_rows:
303
- raise RuntimeError("No surrogate data available.")
304
- self._fit_surrogate()
305
- if self._surrogate_x is None or self._surrogate_y is None:
306
- raise RuntimeError("Surrogate model unavailable.")
307
- if self._surrogate_x.shape[0] == 1:
308
- return float(self._surrogate_y[0]), 0.0
309
-
310
- cfg = _config_to_vector(self._row_for_id(config_id)).reshape(1, -1)
311
- if self._surrogate_k is None or self._surrogate_alpha is None:
312
- raise RuntimeError("Surrogate model unavailable.")
313
- k = _rbf_kernel(self._surrogate_x, cfg, self._surrogate_length_scale).reshape(-1)
314
- pred = float(k @ self._surrogate_alpha)
315
- solve = np.linalg.solve(self._surrogate_k, k)
316
- var = max(0.0, float(1.0 - k @ solve))
317
- return pred, float(math.sqrt(max(var, 1e-12)))
318
-
319
- def _fit_surrogate(self) -> None:
320
- if self._surrogate_fitted_version == self._surrogate_version:
321
- return
322
- rows = self._prior_rows + self._observed_rows
323
- if not rows:
324
- self._surrogate_x = None
325
- self._surrogate_y = None
326
- self._surrogate_alpha = None
327
- self._surrogate_k = None
328
- self._surrogate_fitted_version = self._surrogate_version
329
- return
330
-
331
- self._surrogate_x = np.array([_config_to_vector(row) for row in rows], dtype=np.float32)
332
- self._surrogate_y = np.array([row.score for row in rows], dtype=np.float32)
333
- if self._surrogate_x.shape[0] == 1:
334
- self._surrogate_alpha = self._surrogate_y.copy()
335
- self._surrogate_k = None
336
- self._surrogate_fitted_version = self._surrogate_version
337
- return
338
-
339
- pairwise = _pairwise_sq_dists(self._surrogate_x)
340
- triu = pairwise[np.triu_indices(self._surrogate_x.shape[0], k=1)]
341
- med_dist = float(np.median(np.sqrt(triu))) if triu.size else 0.5
342
- self._surrogate_length_scale = max(0.15, med_dist)
343
- k = _rbf_kernel(self._surrogate_x, self._surrogate_x, self._surrogate_length_scale)
344
- k[np.diag_indices_from(k)] += 1e-3
345
- self._surrogate_k = k
346
- self._surrogate_alpha = np.linalg.solve(k, self._surrogate_y)
347
- self._surrogate_fitted_version = self._surrogate_version
348
-
349
- def _compute_validation_mse(self) -> float:
350
- if not self._task_rows:
351
- return float("inf")
352
- preds = np.array(
353
- [self._predict_with_uncertainty(config_id)[0] for config_id in self.available_config_ids()],
354
- dtype=np.float32,
355
- )
356
- target = np.array([self._row_for_id(config_id).score for config_id in self.available_config_ids()], dtype=np.float32)
357
- return float(np.mean((preds - target) ** 2))
358
-
359
- def _observe_config(self, config_id: int) -> None:
360
- row = self._row_for_id(config_id)
361
- self._observed_ids.append(config_id)
362
- self._observed_id_set.add(config_id)
363
- self._observed_rows.append(row)
364
- self._observed_latencies.append(row.median_ms)
365
- if row.median_ms < self._best_latency_ms:
366
- self._best_latency_ms = row.median_ms
367
- self._best_config_id = config_id
368
-
369
- def _observation_payload(
370
- self,
371
- kind: str,
372
- last_trial: Optional[Dict[str, Any]] = None,
373
- ) -> Dict[str, Any]:
374
- payload = {
375
- "type": kind,
376
- "task_id": self._task_id,
377
- "family": self._family,
378
- "M": self._task_rows[0].m if self._task_rows else None,
379
- "N": self._task_rows[0].n if self._task_rows else None,
380
- "dtype": "fp16",
381
- "tried_config_ids": list(self._observed_ids),
382
- "tried_latencies_ms": list(self._observed_latencies),
383
- "best_so_far_ms": self._best_latency_ms,
384
- "steps_remaining": self._steps_remaining,
385
- }
386
- if last_trial is not None:
387
- payload["last_trial"] = last_trial
388
- return payload
389
-
390
- def _extract_config_id(self, action: Any) -> int:
391
- if isinstance(action, (str, bytes)):
392
- action = json.loads(action)
393
- if isinstance(action, dict):
394
- if "config_id" in action:
395
- return int(action["config_id"])
396
- if "x" in action:
397
- normalized = self._extract_legacy_action(action["x"])
398
- config = self._map_legacy_action_to_config(normalized)
399
- return config
400
- if isinstance(action, (int, np.integer)):
401
- return int(action)
402
- raise TypeError("Action must be an int config_id or dict with config_id.")
403
-
404
- def _extract_legacy_action(self, action: Any) -> List[float]:
405
- arr = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
406
- if arr.shape != (3,):
407
- raise ValueError("Legacy action vector must have 3 values.")
408
- return arr.tolist()
409
-
410
- def _map_legacy_action_to_config(self, action: Sequence[float]) -> int:
411
- base = (
412
- _de_norm(float(action[0]), BLOCK_SIZES),
413
- _de_norm(float(action[1]), NUM_WARPS),
414
- _de_norm(float(action[2]), NUM_STAGES),
415
- )
416
- best_id = min(
417
- self.available_config_ids(),
418
- key=lambda config_id: (
419
- self._row_for_id(config_id).block_size - base[0]
420
- ) ** 2
421
- + (self._row_for_id(config_id).num_warps - base[1]) ** 2
422
- + (self._row_for_id(config_id).num_stages - base[2]) ** 2,
423
- )
424
- return int(best_id)
425
-
426
- def _row_for_id(self, config_id: int) -> Measurement:
427
- if config_id not in self._config_by_id:
428
- raise ValueError(f"Unknown config_id={config_id}")
429
- return self._config_by_id[int(config_id)]
430
-
431
- def _invalidate_surrogate(self) -> None:
432
- self._surrogate_version += 1
433
- self._surrogate_fitted_version = -1
434
- self._surrogate_x = None
435
- self._surrogate_y = None
436
- self._surrogate_alpha = None
437
- self._surrogate_k = None
438
-
439
- def _format_step_output(
440
- self,
441
- observation: Dict[str, Any],
442
- reward: float,
443
- done: bool,
444
- info: Optional[Dict[str, Any]] = None,
445
- ) -> Dict[str, Any]:
446
- return {
447
- "observation": observation,
448
- "reward": float(reward),
449
- "done": bool(done),
450
- "state": self.state(),
451
- "info": info or {},
452
- }
453
-
454
- def _load_measurements(self) -> Dict[str, List[Measurement]]:
455
- if not self.measurement_path.exists():
456
- raise FileNotFoundError(
457
- f"Missing measurement file at {self.measurement_path}. "
458
- "Run the measurement collectors first."
459
- )
460
-
461
- grouped: Dict[str, List[Measurement]] = {}
462
- with self.measurement_path.open("r", newline="", encoding="utf-8") as handle:
463
- reader = csv.DictReader(handle)
464
- fallback_config_ids: Dict[str, int] = {}
465
- for row in reader:
466
- family = row.get("family", "softmax")
467
- family_group = row.get("family_group", "A" if family in {"softmax", "layernorm"} else "B")
468
- task_id = row["task_id"]
469
- block_size = int(row["block_size"])
470
- num_warps = int(row["num_warps"])
471
- num_stages = int(row["num_stages"])
472
- config_id_raw = row.get("config_id")
473
- if config_id_raw in (None, ""):
474
- key = f"{task_id}|{block_size}|{num_warps}|{num_stages}"
475
- if key not in fallback_config_ids:
476
- fallback_config_ids[key] = len([k for k in fallback_config_ids if k.startswith(f"{task_id}|")])
477
- config_id = fallback_config_ids[key]
478
- else:
479
- config_id = int(config_id_raw)
480
-
481
- measurement = Measurement(
482
- family_group=family_group,
483
- family=family,
484
- task_id=task_id,
485
- m=int(row["m"]),
486
- n=int(row["n"]),
487
- k=int(row.get("k", 0) or 0),
488
- config_id=config_id,
489
- block_size=block_size,
490
- num_warps=num_warps,
491
- num_stages=num_stages,
492
- shape_json=row.get("shape_json", "{}"),
493
- config_json=row.get("config_json", "{}"),
494
- median_ms=float(row["median_ms"]),
495
- effective_gbps=float(row["effective_gbps"]),
496
- score=float(row["score"]),
497
- validation_error=float(row["validation_error"]),
498
- )
499
- grouped.setdefault(task_id, []).append(measurement)
500
-
501
- for task_id in grouped:
502
- grouped[task_id].sort(key=lambda row: row.config_id)
503
- return grouped
504
-
505
-
506
- def _config_to_vector(row: Measurement) -> np.ndarray:
507
- family_vec = np.zeros(len(FAMILY_INDEX), dtype=np.float32)
508
- if row.family in FAMILY_INDEX:
509
- family_vec[FAMILY_INDEX[row.family]] = 1.0
510
- shape_fields = {}
511
- try:
512
- shape_fields = json.loads(row.shape_json) if row.shape_json else {}
513
- except Exception:
514
- shape_fields = {}
515
- mode_val = _shape_scalar(shape_fields.get("mode_id"))
516
- role_val = _shape_scalar(shape_fields.get("role_id"))
517
- seq_val = _shape_scalar(shape_fields.get("seq_len"))
518
- ctx_val = _shape_scalar(shape_fields.get("ctx_len"))
519
- group_val = _shape_scalar(shape_fields.get("group_count"))
520
- numeric = np.array(
521
- [
522
- math.log2(max(row.m, 1)) / 16.0,
523
- math.log2(max(row.n, 1)) / 16.0,
524
- math.log2(max(row.k, 1)) / 16.0 if row.k > 0 else 0.0,
525
- math.log2(max(seq_val, 1.0)) / 16.0 if seq_val > 0 else 0.0,
526
- math.log2(max(ctx_val, 1.0)) / 16.0 if ctx_val > 0 else 0.0,
527
- math.log2(max(group_val, 1.0)) / 8.0 if group_val > 0 else 0.0,
528
- mode_val / 8.0,
529
- role_val / 16.0,
530
- _normalize_discrete(BLOCK_SIZES, row.block_size),
531
- _normalize_discrete(NUM_WARPS, row.num_warps),
532
- _normalize_discrete(NUM_STAGES, row.num_stages),
533
- ],
534
- dtype=np.float32,
535
- )
536
- return np.concatenate([family_vec, numeric], axis=0)
537
-
538
-
539
- def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
540
- diff = X[:, None, :] - X[None, :, :]
541
- return np.sum(diff * diff, axis=2)
542
-
543
-
544
- def _rbf_kernel(X: np.ndarray, Y: np.ndarray, length_scale: float) -> np.ndarray:
545
- sigma2 = float(length_scale * length_scale)
546
- if sigma2 <= 0:
547
- sigma2 = 1e-6
548
- xy = X @ Y.T
549
- x2 = np.sum(X * X, axis=1)[:, None]
550
- y2 = np.sum(Y * Y, axis=1)[None, :]
551
- d2 = np.maximum(x2 - 2.0 * xy + y2, 0.0)
552
- return np.exp(-0.5 * d2 / sigma2).astype(np.float32)
553
-
554
-
555
- def _normal_pdf(z: float) -> float:
556
- inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi)
557
- return float(inv_sqrt_2pi * math.exp(-0.5 * z * z))
558
-
559
-
560
- def _normal_cdf(z: float) -> float:
561
- return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
562
-
563
-
564
- def _shape_scalar(value: Any) -> float:
565
- if value is None:
566
- return 0.0
567
- try:
568
- return float(value)
569
- except (TypeError, ValueError):
570
- return 0.0