Spaces: Configuration error

Update index.html #1
by SaiPranavSripathi - opened

This view is limited to 50 files because it contains too many changes. See the raw diff here.
- .gitattributes +35 -0
- .gitignore +0 -21
- BENCHMARK_REPORT.md +0 -439
- HACKATHON_SUMMARY.md +0 -179
- README.md +12 -129
- backend/main.py +0 -177
- backend/requirements.txt +0 -3
- backend/terminal_manager.py +0 -342
- client.py +0 -77
- frontend/.gitignore +0 -24
- frontend/README.md +0 -16
- frontend/eslint.config.js +0 -29
- frontend/index.html +0 -13
- frontend/package-lock.json +0 -0
- frontend/package.json +0 -27
- frontend/public/vite.svg +0 -1
- frontend/src/App.jsx +0 -195
- frontend/src/api/terminal.js +0 -57
- frontend/src/assets/react.svg +0 -1
- frontend/src/components/TerminalPane.jsx +0 -147
- frontend/src/hooks/useTerminalSession.js +0 -166
- frontend/src/index.css +0 -486
- frontend/src/main.jsx +0 -10
- frontend/vite.config.js +0 -7
- index.html +317 -0
- models.py +0 -27
- openenv.yaml +0 -15
- requirements.txt +0 -8
- scripts/__init__.py +0 -1
- scripts/benchmark_qwen_05b_runtime.py +0 -158
- scripts/benchmark_runtime_references.py +0 -222
- scripts/benchmark_surrogate_vs_triton_autotune.py +0 -361
- scripts/build_benchmark_splits.py +0 -81
- scripts/build_qwen_05b_splits.py +0 -56
- scripts/check_torch_cuda.py +0 -18
- scripts/collect_measurements.py +0 -398
- scripts/collect_multifamily_measurements.py +0 -659
- scripts/collect_qwen_05b_measurements.py +0 -449
- scripts/demo_compare.py +0 -200
- scripts/eval_generalization.py +0 -147
- scripts/qwen_05b_spec.py +0 -115
- scripts/report_task_hardness.py +0 -42
- scripts/run_full_pipeline.sh +0 -84
- scripts/run_qwen_05b_pipeline.sh +0 -90
- scripts/run_random_baseline.py +0 -187
- scripts/run_surrogate_baseline.py +0 -217
- scripts/smoke_test_client.py +0 -28
- server/__init__.py +0 -1
- server/app.py +0 -72
- server/softmax_surrogate_environment.py +0 -570
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
DELETED
@@ -1,21 +0,0 @@
__pycache__/
*.pyc
*.pyo
*.log


outputs/
data/softmax_measurements.csv
backend/build/
backend/dist/
backend/venv/
backend/__pycache__/
backend/*.pyc
backend/*.pyo
backend/*.log
backend/outputs/
backend/data/
backend/data/softmax_measurements.csv
backend/backend/build/
backend/backend/dist/
backend/backend/__pycache__/
BENCHMARK_REPORT.md
DELETED
@@ -1,439 +0,0 @@

# GPU Autotuning Benchmark Report

## Project Summary

This project implements a discrete-action, measured-oracle autotuning benchmark for GPU kernels on an H200-class machine. The optimizer is a surrogate-guided search policy that selects kernel configurations under a short evaluation budget. The benchmark is intentionally structured to answer two different questions:

1. How much more search-efficient is a surrogate-guided policy than a random baseline?
2. Do the kernel configurations found by the search process translate into materially better runtime than strong PyTorch baselines such as eager execution and `torch.compile`?

The current benchmark answers both questions with real measurements.

## What Was Built

The current system includes:

- A shared multi-family measurement cache:
  - `softmax`
  - `layernorm`
  - `grouped_gemm`
- A discrete-action environment with:
  - hidden oracle table
  - short-budget search episodes
  - regret-based metrics
  - train-task priors for cross-task generalization
- Split-based evaluation:
  - held-out shapes
  - held-out family
- Runtime reference benchmarking against:
  - eager PyTorch
  - `torch.compile`
  - live reruns of the best Triton configs found by `random`
  - live reruns of the best Triton configs found by `surrogate`

The key design choice is that search evaluation and runtime evaluation are separated:

- Search benchmark: measures how quickly each method finds good configs
- Runtime benchmark: measures how fast the final chosen kernels actually run

## Benchmark Structure

### Families

- Family A:
  - `softmax`
  - `layernorm`
- Family B:
  - `grouped_gemm`

### Shapes

Current task shapes use:

- fixed `M = 4096`
- `N ∈ {256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192}`

This yields:

- `9` softmax tasks
- `9` layernorm tasks
- `9` grouped GEMM tasks
- `27` tasks total

### Search Space

Current kernel config axes:

- `block_size`
- `num_warps`
- `num_stages`

Search is over discrete `config_id`s, not projected continuous actions. That makes the random baseline a true uniform baseline over legal configs.
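The discrete search space and its uniform baseline can be sketched as follows. The axis values here are illustrative assumptions (the report names the axes but not the exact grids); the point is that every legal combination gets a `config_id` and random search samples those ids uniformly.

```python
import itertools
import random

# Hypothetical axis values -- the report lists the axes
# (block_size, num_warps, num_stages) but not the exact grids.
AXES = {
    "block_size": [64, 128, 256, 512],
    "num_warps": [1, 2, 4, 8],
    "num_stages": [1, 2, 3],
}

# Enumerate the discrete space: the index of each combination is its config_id.
CONFIGS = [dict(zip(AXES, vals)) for vals in itertools.product(*AXES.values())]

def random_baseline(budget, rng):
    """Uniform sampling over config_ids -- a true uniform baseline over legal configs."""
    return [rng.randrange(len(CONFIGS)) for _ in range(budget)]

rng = random.Random(0)
ids = random_baseline(budget=6, rng=rng)
print(len(CONFIGS), ids)
```

With these assumed grids the space has 4 × 4 × 3 = 48 legal configs, and a six-step episode touches at most six of them, which is why search efficiency matters.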
### Evaluation Splits

The benchmark currently uses:

- `shape_generalization`
  - train on lower/intermediate shapes from each family
  - test on held-out larger shapes within the same families
- `family_holdout`
  - train on `softmax` + `layernorm`
  - test on all `grouped_gemm` tasks

At the time of the main run:

- shape-generalization test tasks: `6`
- family-holdout test tasks: `9`
- unique held-out runtime-reference tasks: `13`

## Metrics

### Search Metrics

The main optimization metrics are:

- `regret@k`
  - best-so-far latency relative to oracle best after `k` search steps
- best-so-far AUC
  - average regret across the whole search trajectory
  - lower is better
- final oracle hit rate
  - how often the optimizer reaches the oracle-best config by the end of the episode

These metrics intentionally emphasize search efficiency, not just the final endpoint.
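One plausible implementation of the two regret metrics above (the environment's exact code is not reproduced in this report; this sketch assumes regret is normalized as best-so-far latency over oracle-best latency, minus one, so `0.0` means the oracle config was found):

```python
def regret_at(latencies, oracle_best, k):
    """Relative regret of the best config seen in the first k search steps."""
    best = min(latencies[:k])
    return best / oracle_best - 1.0

def best_so_far_auc(latencies, oracle_best):
    """Average regret across the whole search trajectory (lower is better)."""
    regrets, best = [], float("inf")
    for lat in latencies:
        best = min(best, lat)               # best-so-far latency
        regrets.append(best / oracle_best - 1.0)
    return sum(regrets) / len(regrets)

# Toy trajectory: measured latencies (ms) over a 6-step budget.
traj = [0.050, 0.030, 0.024, 0.024, 0.021, 0.021]
oracle = 0.021
print(regret_at(traj, oracle, 1))    # high early regret
print(regret_at(traj, oracle, 6))    # 0.0 -- oracle config reached
print(best_so_far_auc(traj, oracle))
```

AUC rewards finding good configs early: two trajectories can end at the same `regret@6` yet differ widely in AUC, which is exactly the surrogate-vs-random distinction the report emphasizes.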
### Runtime Metrics

The runtime-reference benchmark records:

- eager PyTorch steady-state latency
- `torch.compile` compile-plus-first-call time
- `torch.compile` steady-state latency
- live rerun latency of best Triton config found by `random`
- live rerun latency of best Triton config found by `surrogate`
- Triton speedups vs eager and vs compiled
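A minimal sketch of a steady-state latency harness, assuming the usual warmup-then-median pattern (the repo's actual measurement scripts are not shown here, and the function name is hypothetical). On GPU the timing would instead use `torch.cuda.Event` pairs plus `torch.cuda.synchronize()` so asynchronous kernel launches are not undercounted; this CPU version only illustrates the structure:

```python
import time

def steady_state_latency_ms(fn, warmup=10, iters=100):
    """Median wall-clock latency of fn() in ms, after warmup iterations."""
    for _ in range(warmup):           # discard compile/caching cold-start effects
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    samples.sort()
    return samples[len(samples) // 2]  # median is robust to stray slow runs

lat = steady_state_latency_ms(lambda: sum(range(1000)))
print(f"{lat:.4f} ms")
```

Separating this steady-state number from the compile-plus-first-call time is what lets the report charge `torch.compile` and `triton.autotune` for their tuning cost without polluting the per-call latency comparison.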
## Main Search Results

The main search results came from the held-out-shape / held-out-family evaluation in `run13.log`.

### Held-out Shapes

Random:

- `mean_regret_at`:
  - `@1 = 0.31341859698295593`
  - `@3 = 0.13305269181728363`
  - `@5 = 0.1070360466837883`
  - `@6 = 0.06344400346279144`
- `mean_best_so_far_auc = 0.1483089178800583`
- `mean_oracle_hit_rate_final = 0.6749999523162842`

Surrogate:

- `mean_regret_at`:
  - `@1 = 0.002321675419807434`
  - `@3 = 0.002293013734742999`
  - `@5 = 0.002293013734742999`
  - `@6 = 0.002293013734742999`
- `mean_best_so_far_auc = 0.0023013732861727476`
- `mean_oracle_hit_rate_final = 0.8333333134651184`

Interpretation:

- Surrogate reduced best-so-far AUC by about `98.4%` versus random.
- Surrogate reduced final `regret@6` by about `96.4%`.
- Surrogate improved final oracle-hit rate by about `15.8` percentage points.

This is a very strong within-family / held-out-shape result.

### Held-out Family (`grouped_gemm`)

Random:

- `mean_regret_at`:
  - `@1 = 2.341181755065918`
  - `@3 = 0.8532703518867493`
  - `@5 = 0.3116174638271332`
  - `@6 = 0.21012252569198608`
- `mean_best_so_far_auc = 0.9102223515510559`
- `mean_oracle_hit_rate_final = 0.17777778208255768`

Surrogate:

- `mean_regret_at`:
  - `@1 = 0.4722703695297241`
  - `@3 = 0.29785311222076416`
  - `@5 = 0.20862582325935364`
  - `@6 = 0.17804712057113647`
- `mean_best_so_far_auc = 0.3014116585254669`
- `mean_oracle_hit_rate_final = 0.3444444239139557`

Interpretation:

- Surrogate reduced best-so-far AUC by about `66.9%` versus random.
- Surrogate reduced final `regret@6` by about `15.3%`.
- Surrogate improved final oracle-hit rate by about `16.7` percentage points.

This is a good cross-family transfer result. The gap is smaller than in the held-out-shape setting, which is expected.

## Main Runtime Results

The runtime-reference benchmark compares eager PyTorch, `torch.compile`, and the live reruns of the best Triton configs found by `random` and `surrogate`.

### Summary by Family

#### Softmax

- Average eager latency: `0.1103919968008995 ms`
- Average `torch.compile` steady-state latency: `0.11152799427509308 ms`
- Average compile-plus-first-call time: `529.9687260048813 ms`
- Average Triton speedup vs eager:
  - random-best: `3.362561387683493x`
  - surrogate-best: `3.286588301595338x`
- Average Triton speedup vs compiled:
  - random-best: `3.3985671575178635x`
  - surrogate-best: `3.321742054891467x`

Interpretation:

- `torch.compile` is effectively flat vs eager on this softmax set.
- Tuned Triton is substantially faster than both.
- Surrogate-best and random-best final kernels are very close in absolute runtime.

#### LayerNorm

- Average eager latency: `0.05939200147986412 ms`
- Average `torch.compile` steady-state latency: `0.09503999352455139 ms`
- Average compile-plus-first-call time: `440.1235789991915 ms`
- Average Triton speedup vs eager:
  - random-best: `1.8776593781360051x`
  - surrogate-best: `1.8364378273209185x`
- Average Triton speedup vs compiled:
  - random-best: `2.927484944635789x`
  - surrogate-best: `2.862647103483093x`

Interpretation:

- `torch.compile` is slower than eager on this LayerNorm set.
- Tuned Triton is materially faster than both eager and compiled.
- Again, surrogate-best and random-best final kernels are close in endpoint runtime.

#### Grouped GEMM

- Average eager latency: `0.05589688859052128 ms`
- Average `torch.compile` steady-state latency: `0.101806221736802 ms`
- Average compile-plus-first-call time: `102.45987688863858 ms`
- Average Triton speedup vs eager:
  - random-best: `1.2771213149737215x`
  - surrogate-best: `1.2644549628354071x`
- Average Triton speedup vs compiled:
  - random-best: `2.4414293463407355x`
  - surrogate-best: `2.4156697207038382x`

Interpretation:

- `torch.compile` is materially slower than eager on this grouped-GEMM set.
- Tuned Triton is faster than both eager and compiled.
- Endpoint difference between surrogate-best and random-best remains small.

## Triton Autotune Comparison

We also compared the surrogate search workflow directly against `triton.autotune` on three large softmax tasks:

- `softmax_m4096_n4096`
- `softmax_m4096_n6144`
- `softmax_m4096_n8192`

This comparison measures two things:

- search plus compile cost to the first usable tuned kernel
- steady-state runtime of the resulting tuned kernel

### Per-task Results

#### `softmax_m4096_n4096`

- oracle best: `0.02127999998629093 ms`
- surrogate:
  - decision time: `33.06370500649791 ms`
  - compile plus first call: `294.734695009538 ms`
  - steady-state: `0.02127999998629093 ms`
  - regret vs oracle: `0.0`
- `triton.autotune`:
  - first call: `8970.702438004082 ms`
  - steady-state: `0.021856000646948814 ms`
  - regret vs oracle: `0.0270677002363231`

#### `softmax_m4096_n6144`

- oracle best: `0.030719999223947525 ms`
- surrogate:
  - decision time: `15.47088599181734 ms`
  - compile plus first call: `0.9627069957787171 ms`
  - steady-state: `0.031007999554276466 ms`
  - regret vs oracle: `0.009375010989727928`
- `triton.autotune`:
  - first call: `1990.3547260037158 ms`
  - steady-state: `0.031039999797940254 ms`
  - regret vs oracle: `0.010416685614473398`

#### `softmax_m4096_n8192`

- oracle best: `0.03747199848294258 ms`
- surrogate:
  - decision time: `15.47144899086561 ms`
  - compile plus first call: `0.09679699724074453 ms`
  - steady-state: `0.03798399865627289 ms`
  - regret vs oracle: `0.013663540618560122`
- `triton.autotune`:
  - first call: `1988.913670007605 ms`
  - steady-state: `0.03747199848294258 ms`
  - regret vs oracle: `0.0`

### Summary

- mean surrogate compile plus first call: `98.59806633418582 ms`
- mean surrogate steady-state latency: `0.03009066606561343 ms`
- mean `triton.autotune` first call: `4316.656944671801 ms`
- mean `triton.autotune` steady-state latency: `0.030122666309277218 ms`

Interpretation:

- The surrogate reaches a competitive tuned kernel far faster than `triton.autotune` on these tasks.
- Steady-state performance is effectively the same on average:
  - surrogate mean steady-state: `0.0300907 ms`
  - `triton.autotune` mean steady-state: `0.0301227 ms`
- On one task the surrogate exactly matched the oracle best.
- On the other two tasks the surrogate was slightly off the oracle, but still close.
- `triton.autotune` won one task in steady-state quality, but paid a much larger first-call search cost.

This is an important result because it shows the surrogate is not only better than a random search baseline. It is also competitive with Triton's built-in autotuning in final kernel quality while being dramatically cheaper in tuning-time-to-first-good-kernel on these tested shapes.
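The tuning-cost gap in this summary can be checked directly from the quoted means:

```python
# Mean values reported in the Triton Autotune Comparison summary (ms).
surrogate_first = 98.59806633418582    # surrogate compile plus first call
autotune_first = 4316.656944671801     # triton.autotune first call
surrogate_steady = 0.03009066606561343
autotune_steady = 0.030122666309277218

# Surrogate reaches its first usable tuned kernel roughly 44x sooner...
tuning_speedup = autotune_first / surrogate_first
# ...while mean steady-state latency differs by only about a tenth of a percent.
steady_gap_pct = (autotune_steady / surrogate_steady - 1.0) * 100
print(f"{tuning_speedup:.1f}x faster to first tuned kernel")
print(f"{steady_gap_pct:.2f}% steady-state difference")
```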
## What The Results Mean

The results support the following conclusions:

1. The surrogate optimizer is genuinely useful as a search policy.
   - It is substantially more sample-efficient than random.
   - It reaches good kernels much earlier in the budget.
   - This effect is very strong on held-out shapes and still meaningful on held-out family transfer.

2. The resulting Triton kernels are genuinely useful as runtime implementations.
   - They are faster than eager PyTorch.
   - They are faster than `torch.compile`.
   - The advantage is strongest on `softmax`, then `layernorm`, then `grouped_gemm`.

3. The surrogate is also competitive with `triton.autotune` on final steady-state runtime while being much cheaper in first-call tuning cost on the tested softmax shapes.

4. The main value of the surrogate is search efficiency, not necessarily a dramatically better final endpoint than a lucky random search.
   - By the end of the short search budget, random and surrogate can still land on very similar endpoint kernels.
   - This is visible in the runtime benchmark, where random-best and surrogate-best final kernels are often close in ms.
   - The surrogate still wins decisively on regret and AUC.

This is a coherent and valuable result. The optimizer is improving how quickly good kernels are found, and the kernels it finds are fast in absolute runtime.

## Caveats and Professional Notes

This benchmark is strong for a hackathon project, but it should not be oversold.

Important caveats:

- The benchmark is not publication-grade yet.
  - Only one `M` value is used.
  - Family/task distributions are still narrow.
  - Grouped GEMM was added recently and should be validated more deeply.
- `torch.compile` compile time measurement should be interpreted carefully.
  - The measured compile-plus-first-call times vary across tasks.
  - A more publication-grade measurement would reset compiler state more aggressively and isolate cold-start behavior more carefully.
- The runtime benchmark uses live reruns.
  - That means small run-to-run variation is expected.
- The endpoint runtime gap between surrogate-best and random-best is small.
  - This is not a failure of the benchmark.
  - It means the benchmark currently demonstrates search-efficiency gains more strongly than final-endpoint gains.

These caveats do not invalidate the results. They define the proper scope of the claims.

## Recommended Project Framing

The most honest and compelling project framing is:

> We built a measured-oracle GPU autotuning benchmark with held-out-shape and held-out-family evaluation. Our surrogate-guided optimizer substantially outperforms a random baseline on short-budget search efficiency, and the Triton kernels it finds are materially faster than both eager PyTorch and `torch.compile`.

That statement matches the data.

## Recommended Next Steps

Highest-value next steps:

1. Expand the runtime benchmark:
   - more held-out tasks
   - more end-to-end summaries

2. Improve search without changing the measurement cache:
   - stronger acquisition strategies
   - family-aware priors
   - feature engineering before simply increasing parameter count

3. Make the collector more production-friendly:
   - incremental writes
   - progress logging
   - resumable measurement collection

4. If needed, increase benchmark rigor:
   - broader shape sets
   - more families
   - more careful cold-start `torch.compile` accounting

## Deliverable Status

At the current stage, this project is:

- hackathon-ready
- technically credible
- professionally explainable
- useful as a platform for next-stage kernel autotuning work

It is not yet:

- a finished research benchmark
- a final systems paper artifact

That is the correct level of rigor for the current results.
HACKATHON_SUMMARY.md
DELETED
@@ -1,179 +0,0 @@

# GPU Kernel Autotuning: Hackathon Summary

## What We Built

We built a measured-oracle GPU autotuning benchmark for Triton kernels on an H200-class machine.

The system has two layers:

- Search benchmark
  - compares `surrogate` vs `random`
  - measures how quickly each method finds good kernel configs
- Runtime benchmark
  - compares the selected Triton kernels against eager PyTorch and `torch.compile`
  - measures actual execution latency in milliseconds

The benchmark is discrete-action, uses real measured kernel runtimes, and supports held-out-shape and held-out-family evaluation.

## Kernel Families

Current benchmark families:

- Family A:
  - `softmax`
  - `layernorm`
- Family B:
  - `grouped_gemm`

Current shape set:

- `M = 4096`
- `N ∈ {256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192}`

Total tasks:

- `27` tasks

## Why This Matters

There are two distinct questions in GPU autotuning:

1. Can an optimizer find good kernels quickly under a short tuning budget?
2. Are the kernels it finds actually fast in real runtime?

This project answers both.

## Search Results

### Held-out Shapes

Against the `random` baseline, the `surrogate` optimizer achieved:

- `98.4%` lower best-so-far regret AUC
- `96.4%` lower `regret@6`
- `+15.8` percentage points higher final oracle-hit rate

Interpretation:

- On held-out shapes, the surrogate is dramatically more sample-efficient than random.
- It finds near-optimal kernels almost immediately.

### Held-out Family (`grouped_gemm`)

Against the `random` baseline, the `surrogate` optimizer achieved:

- `66.9%` lower best-so-far regret AUC
- `15.3%` lower `regret@6`
- `+16.7` percentage points higher final oracle-hit rate

Interpretation:

- The surrogate also transfers across kernel families.
- The cross-family problem is harder, but the optimizer still wins clearly.

## Runtime Results

We then reran the selected Triton kernels live and compared them to:

- eager PyTorch
- `torch.compile`
- `triton.autotune` on selected large softmax tasks

### Softmax

- Average surrogate-best Triton runtime: about `0.0336 ms`
- Average `torch.compile` runtime: about `0.1115 ms`
- Result: surrogate-selected Triton is about `3.32x` faster than `torch.compile`

### LayerNorm

- Average surrogate-best Triton runtime: about `0.0323 ms`
- Average `torch.compile` runtime: about `0.0950 ms`
- Result: surrogate-selected Triton is about `2.86x` faster than `torch.compile`

### Grouped GEMM

- Average surrogate-best Triton runtime: about `0.0442 ms`
- Average `torch.compile` runtime: about `0.1018 ms`
- Result: surrogate-selected Triton is about `2.42x` faster than `torch.compile`
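A quick arithmetic check on the averages above. One caveat, inferred from how such benchmarks are usually aggregated: the quoted speedups appear to be means of per-task ratios, so they need not exactly equal the ratio of the rounded mean latencies (softmax matches closely; LayerNorm and grouped GEMM differ by a few percent):

```python
# Average latencies quoted above (ms): (surrogate-best Triton, torch.compile).
families = {
    "softmax":      (0.0336, 0.1115),
    "layernorm":    (0.0323, 0.0950),
    "grouped_gemm": (0.0442, 0.1018),
}

for name, (triton_ms, compile_ms) in families.items():
    # Ratio of mean latencies; compare with the quoted mean per-task speedups
    # (3.32x, 2.86x, 2.42x) to see the mean-of-ratios vs ratio-of-means gap.
    print(f"{name}: {compile_ms / triton_ms:.2f}x faster than torch.compile")
```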
| 124 |
-
### Triton Autotune on Large Softmax
|
| 125 |
-
|
| 126 |
-
We also compared the surrogate directly against `triton.autotune` on:
|
| 127 |
-
|
| 128 |
-
- `softmax_m4096_n4096`
|
| 129 |
-
- `softmax_m4096_n6144`
|
| 130 |
-
- `softmax_m4096_n8192`
|
| 131 |
-
|
| 132 |
-
Result:
|
| 133 |
-
|
| 134 |
-
- mean surrogate compile plus first call: `98.6 ms`
|
| 135 |
-
- mean `triton.autotune` first call: `4316.7 ms`
|
| 136 |
-
- mean steady-state latency was effectively the same:
|
| 137 |
-
- surrogate: `0.03009 ms`
|
| 138 |
-
- `triton.autotune`: `0.03012 ms`
|
| 139 |
-
|
| 140 |
-
Interpretation:
|
| 141 |
-
|
| 142 |
-
- The surrogate reaches a strong tuned kernel far faster than `triton.autotune` on these tested softmax shapes.
|
| 143 |
-
- Final steady-state kernel quality is essentially matched on average.
|
| 144 |
-
|
| 145 |
-
## What The Results Mean
|
| 146 |
-
|
| 147 |
-
The current system shows two real advantages:
|
| 148 |
-
|
| 149 |
-
- The surrogate is much better than random at short-budget tuning.
|
| 150 |
-
- The resulting Triton kernels are materially faster than `torch.compile`.
|
| 151 |
-
- On selected large softmax tasks, the surrogate also reaches competitive tuned kernels far faster than `triton.autotune`.
|
| 152 |
-
|
| 153 |
-
The subtle point is that the surrogate's biggest gain is in search efficiency, not necessarily in a huge final-endpoint gap over a lucky random search. That is why regret/AUC is the right optimization metric here.

## What Is Strong

- Real GPU measurements, not synthetic rewards
- Hidden-oracle benchmark protocol
- Discrete action space
- Held-out-shape evaluation
- Held-out-family evaluation
- Runtime comparison against eager PyTorch and `torch.compile`

## What Is Not Final Yet

- This is not publication-grade yet
- Only one `M` dimension is used
- Grouped GEMM was added recently and should be stress-validated further
- `torch.compile` cold-start accounting could be measured even more rigorously

## Bottom Line

This project demonstrates a credible autotuning workflow:

- surrogate-guided search beats random on search quality
- tuned Triton kernels beat `torch.compile` on runtime
- the benchmark already shows cross-family transfer

That is a strong hackathon result with professional-grade measurement discipline.
README.md CHANGED
@@ -1,129 +1,12 @@
- smoke-tests the local OpenEnv-style environment
- evaluates `random` vs `surrogate` using `regret@k` and best-so-far AUC
- benchmarks eager PyTorch and `torch.compile` against best Triton configs

## Measurement cache

Main collector:

```bash
python3 scripts/collect_multifamily_measurements.py \
  --output data/autotune_measurements.csv \
  --families softmax layernorm grouped_gemm \
  --n-cols 256 512 1024 1536 2048 3072 4096 6144 8192 \
  --m 4096 \
  --repeats 200 \
  --warmup 25 \
  --seed 0 \
  --append
```

Current implemented families:
- `softmax`
- `layernorm`
- `grouped_gemm`

The shared CSV schema includes:
- `family_group`
- `family`
- `task_id`
- `m`, `n`
- `config_id`
- `block_size`, `num_warps`, `num_stages`
- `shape_json`, `config_json`
- `median_ms`, `effective_gbps`, `score`, `validation_error`
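Given that schema, the cache can be inspected with plain stdlib code. A hedged sketch (the column names come from the list above; the helper itself is not part of the repo):

```python
import csv
from typing import Any


def best_config_per_task(csv_path: str) -> dict[str, dict[str, Any]]:
    """For each task_id, return the measurement row with the lowest median_ms.

    Note: this sketch ignores `validation_error`; real tooling should drop
    configs that failed numerical validation before ranking them.
    """
    best: dict[str, dict[str, Any]] = {}
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            task = row["task_id"]
            if task not in best or float(row["median_ms"]) < float(best[task]["median_ms"]):
                best[task] = row
    return best
```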

## Splits and eval

Build split manifests:

```bash
python3 scripts/build_benchmark_splits.py \
  --measurement-path data/autotune_measurements.csv \
  --output data/benchmark_splits.json \
  --heldout-family grouped_gemm
```

Run split-based evaluation:

```bash
python3 scripts/eval_generalization.py \
  --measurement-path data/autotune_measurements.csv \
  --splits data/benchmark_splits.json \
  --episodes 20 \
  --budget 6 \
  --seed 2 \
  --acquisition ucb \
  --beta 2.0
```

Benchmark absolute runtime references:

```bash
python3 scripts/benchmark_runtime_references.py \
  --generalization-results outputs/generalization_eval.json \
  --repeats 100 \
  --warmup 10 \
  --seed 123
```

Metrics:
- `mean_regret_at`
- `median_regret_at`
- `mean_best_so_far_auc`
- `mean_oracle_hit_rate_final`
- `eager_latency_ms`
- `compile_plus_first_call_ms`
- `compiled_latency_ms`
- Triton speedups vs eager / compiled

## Environment

OpenEnv metadata is in:
- `openenv.yaml`

The environment server still uses:
- `server/app.py`
- `server/softmax_surrogate_environment.py`

Despite the filename, the env is now multi-family and supports a train-task prior for held-out-shape / held-out-family evaluation.

## Qwen2.5-0.5B exact-kernel pipeline

This repo now also includes a model-specific benchmark pipeline for the exact inference kernel roles needed by `Qwen/Qwen2.5-0.5B`.

Kernel roles covered:
- `rmsnorm`
- attention `softmax`
- `q_proj`, `k_proj`, `v_proj`, `o_proj`
- `gate_proj`, `up_proj`, `down_proj`

Run it with:

```bash
./scripts/run_qwen_05b_pipeline.sh
```

Key files:
- `scripts/qwen_05b_spec.py`
- `scripts/collect_qwen_05b_measurements.py`
- `scripts/build_qwen_05b_splits.py`
- `scripts/benchmark_qwen_05b_runtime.py`

Outputs:
- `data/qwen_05b_measurements.csv`
- `data/qwen_05b_splits.json`
- `outputs/qwen_05b_generalization_eval.json`
- `outputs/qwen_05b_runtime_references.json`

---
title: RL Surrogate ENV
emoji: 👀
colorFrom: gray
colorTo: blue
sdk: static
pinned: false
license: apache-2.0
short_description: This is a surrogate discovery environment to improve perform
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
backend/main.py DELETED
@@ -1,177 +0,0 @@
from __future__ import annotations

import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from backend.terminal_manager import ALLOWED_JOBS, TerminalManager

app = FastAPI(
    title="RL Autotuning Backend",
    description="Backend API for the multi-family GPU autotuning benchmark",
    version="0.1.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:4173",
        "http://127.0.0.1:4173",
    ],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

env = None
terminal_manager = TerminalManager()


def _get_env():
    global env
    if env is None:
        try:
            from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
            env = SoftmaxSurrogateEnvironment()
        except ImportError as exc:
            raise HTTPException(
                status_code=503,
                detail=f"Environment unavailable – missing dependency: {exc.name}",
            )
    return env


class ResetRequest(BaseModel):
    task: Optional[str] = None
    seed: Optional[int] = None


class StepRequest(BaseModel):
    config_id: Optional[int] = None
    x: Optional[List[float]] = None


class SessionRequest(BaseModel):
    job_id: str
    restart: bool = False


class SessionInputRequest(BaseModel):
    data: str
    append_newline: bool = True


class SessionResizeRequest(BaseModel):
    cols: int
    rows: int


@app.get("/health")
def health() -> Dict[str, str]:
    return {"status": "ok"}


@app.post("/reset")
def reset(payload: ResetRequest) -> Dict[str, Any]:
    return _get_env().reset(task=payload.task, seed=payload.seed)


@app.post("/step")
def step(payload: StepRequest) -> Dict[str, Any]:
    e = _get_env()
    if payload.config_id is not None:
        return e.step({"config_id": payload.config_id})
    if payload.x is not None:
        return e.step({"x": payload.x})
    raise HTTPException(status_code=400, detail="Provide config_id or x.")


@app.get("/state")
def state() -> Dict[str, Any]:
    return _get_env().state()


@app.get("/terminal/jobs")
def terminal_jobs() -> Dict[str, Any]:
    return {"jobs": terminal_manager.list_jobs()}


@app.post("/terminal/sessions")
async def create_terminal_session(payload: SessionRequest) -> Dict[str, Any]:
    if payload.job_id not in ALLOWED_JOBS:
        raise HTTPException(status_code=404, detail=f"Unknown job_id: {payload.job_id}")
    session = await terminal_manager.ensure_session(payload.job_id, restart=payload.restart)
    return session.snapshot()


@app.get("/terminal/sessions/{session_id}")
def terminal_session_snapshot(session_id: str) -> Dict[str, Any]:
    session = terminal_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return session.snapshot()


@app.post("/terminal/sessions/{session_id}/input")
def terminal_session_input(session_id: str, payload: SessionInputRequest) -> Dict[str, Any]:
    session = terminal_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if not session.is_active:
        raise HTTPException(status_code=409, detail="Session is not running")
    session.write(payload.data, append_newline=payload.append_newline)
    return {"ok": True}


@app.post("/terminal/sessions/{session_id}/resize")
def terminal_session_resize(session_id: str, payload: SessionResizeRequest) -> Dict[str, Any]:
    session = terminal_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    session.resize(payload.cols, payload.rows)
    return {"ok": True}


@app.post("/terminal/sessions/{session_id}/stop")
def terminal_session_stop(session_id: str) -> Dict[str, Any]:
    session = terminal_manager.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    session.interrupt()
    return {"ok": True}


@app.websocket("/terminal/sessions/{session_id}/stream")
async def terminal_session_stream(websocket: WebSocket, session_id: str) -> None:
    session = terminal_manager.get_session(session_id)
    if session is None:
        await websocket.close(code=4404)
        return

    await websocket.accept()
    queue = await session.subscribe()
    try:
        await websocket.send_json(session.snapshot())
        while True:
            event = await queue.get()
            await websocket.send_json(event)
    except WebSocketDisconnect:
        pass
    finally:
        session.unsubscribe(queue)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
backend/requirements.txt DELETED
@@ -1,3 +0,0 @@
fastapi
uvicorn[standard]
pydantic
backend/terminal_manager.py DELETED
@@ -1,342 +0,0 @@
from __future__ import annotations

import asyncio
import importlib.util
import json
import os
import pty
import signal
import struct
import subprocess
import termios
import threading
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any


REPO_ROOT = Path(__file__).resolve().parents[1]
BUFFER_LIMIT = 160_000
DEFAULT_COLS = 120
DEFAULT_ROWS = 36
PYTHON_CANDIDATES = (
    "/usr/local/bin/python3",
    "/opt/homebrew/bin/python3",
    "/Users/amannindra/miniconda3/bin/python3",
)


@dataclass(frozen=True)
class AllowedJob:
    job_id: str
    label: str
    description: str
    command: tuple[str, ...]
    cwd: Path

    def as_dict(self) -> dict[str, Any]:
        return {
            "job_id": self.job_id,
            "label": self.label,
            "description": self.description,
            "command": list(self.command),
            "cwd": str(self.cwd),
        }


ALLOWED_JOBS: dict[str, AllowedJob] = {
    "qwen": AllowedJob(
        job_id="qwen",
        label="Qwen Baseline",
        description="Runs the exact-kernel Qwen2.5-0.5B benchmark pipeline.",
        command=("bash", "scripts/run_qwen_05b_pipeline.sh"),
        cwd=REPO_ROOT,
    ),
    "rl-agent": AllowedJob(
        job_id="rl-agent",
        label="RL Agent",
        description="Runs the multi-family surrogate and runtime benchmark pipeline.",
        command=("bash", "scripts/run_full_pipeline.sh"),
        cwd=REPO_ROOT,
    ),
}


def _probe_python(path: str) -> dict[str, Any] | None:
    if not Path(path).exists():
        return None

    script = (
        "import importlib.util, json, sys; "
        "print(json.dumps({"
        "'executable': sys.executable, "
        "'torch': bool(importlib.util.find_spec('torch')), "
        "'triton': bool(importlib.util.find_spec('triton'))"
        "}))"
    )
    try:
        result = subprocess.run(
            [path, "-c", script],
            check=True,
            capture_output=True,
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return None

    try:
        payload = json.loads(result.stdout.strip())
    except json.JSONDecodeError:
        return None
    payload["path"] = path
    return payload


def _best_python_runtime() -> dict[str, Any] | None:
    explicit = os.environ.get("TERMINAL_PYTHON_BIN")
    if explicit:
        probe = _probe_python(explicit)
        if probe is not None:
            probe["score"] = int(probe["torch"]) + int(probe["triton"])
            probe["explicit"] = True
            return probe

    best: dict[str, Any] | None = None
    for candidate in PYTHON_CANDIDATES:
        probe = _probe_python(candidate)
        if probe is None:
            continue
        score = int(probe["torch"]) + int(probe["triton"])
        probe["score"] = score
        if best is None or score > best["score"]:
            best = probe
    return best


class TerminalSession:
    def __init__(self, job: AllowedJob, loop: asyncio.AbstractEventLoop) -> None:
        self.id = uuid.uuid4().hex
        self.job = job
        self.loop = loop
        self.created_at = time.time()
        self.started_at: float | None = None
        self.finished_at: float | None = None
        self.exit_code: int | None = None
        self.status = "starting"
        self.cols = DEFAULT_COLS
        self.rows = DEFAULT_ROWS
        self.python_runtime = _best_python_runtime()

        self._buffer = ""
        self._buffer_lock = threading.Lock()
        self._subscribers: set[asyncio.Queue[dict[str, Any]]] = set()
        self._subscriber_lock = threading.Lock()

        self._master_fd, slave_fd = pty.openpty()
        self._resize_fd(self.cols, self.rows)

        env = os.environ.copy()
        env.setdefault("TERM", "xterm-256color")
        env.setdefault("PYTHONUNBUFFERED", "1")
        env.setdefault("FORCE_COLOR", "1")
        if self.python_runtime is not None:
            python_dir = str(Path(self.python_runtime["path"]).parent)
            env["PATH"] = f"{python_dir}:{env.get('PATH', '')}"
            env["PYTHON_BIN"] = self.python_runtime["path"]

        self._append_buffer(self._launcher_banner())

        self.process = subprocess.Popen(
            self.job.command,
            cwd=str(self.job.cwd),
            stdin=slave_fd,
            stdout=slave_fd,
            stderr=slave_fd,
            env=env,
            preexec_fn=os.setsid,
            close_fds=True,
        )
        os.close(slave_fd)

        self.started_at = time.time()
        self.status = "running"

        self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._waiter_thread = threading.Thread(target=self._wait_loop, daemon=True)
        self._reader_thread.start()
        self._waiter_thread.start()

    @property
    def command_display(self) -> str:
        return " ".join(self.job.command)

    @property
    def is_active(self) -> bool:
        return self.process.poll() is None

    def snapshot(self) -> dict[str, Any]:
        with self._buffer_lock:
            buffer = self._buffer
        return {
            "type": "snapshot",
            "session": {
                "id": self.id,
                "job_id": self.job.job_id,
                "label": self.job.label,
                "description": self.job.description,
                "cwd": str(self.job.cwd),
                "command": self.command_display,
                "status": self.status,
                "created_at": self.created_at,
                "started_at": self.started_at,
                "finished_at": self.finished_at,
                "exit_code": self.exit_code,
                "cols": self.cols,
                "rows": self.rows,
            },
            "buffer": buffer,
        }

    async def subscribe(self) -> asyncio.Queue[dict[str, Any]]:
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        with self._subscriber_lock:
            self._subscribers.add(queue)
        return queue

    def unsubscribe(self, queue: asyncio.Queue[dict[str, Any]]) -> None:
        with self._subscriber_lock:
            self._subscribers.discard(queue)

    def write(self, data: str, append_newline: bool = True) -> None:
        if not data:
            return
        payload = data + ("\n" if append_newline else "")
        os.write(self._master_fd, payload.encode("utf-8", errors="replace"))

    def resize(self, cols: int, rows: int) -> None:
        self.cols = max(20, cols)
        self.rows = max(8, rows)
        self._resize_fd(self.cols, self.rows)

    def interrupt(self) -> None:
        if self.process.poll() is None:
            os.killpg(os.getpgid(self.process.pid), signal.SIGINT)

    def terminate(self) -> None:
        if self.process.poll() is None:
            os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)

    def _resize_fd(self, cols: int, rows: int) -> None:
        winsize = struct.pack("HHHH", rows, cols, 0, 0)
        try:
            termios.tcsetwinsize(self._master_fd, (rows, cols))
        except AttributeError:
            pass
        try:
            import fcntl

            fcntl.ioctl(self._master_fd, termios.TIOCSWINSZ, winsize)
        except OSError:
            pass

    def _append_buffer(self, chunk: str) -> None:
        with self._buffer_lock:
            self._buffer = (self._buffer + chunk)[-BUFFER_LIMIT:]

    def _launcher_banner(self) -> str:
        lines = [
            f"[launcher] job: {self.job.label}",
            f"[launcher] cwd: {self.job.cwd}",
            f"[launcher] command: {self.command_display}",
        ]
        if self.python_runtime is not None:
            modules = []
            modules.append(f"torch={'yes' if self.python_runtime['torch'] else 'no'}")
            modules.append(f"triton={'yes' if self.python_runtime['triton'] else 'no'}")
            lines.append(f"[launcher] python3: {self.python_runtime['path']} ({', '.join(modules)})")
            if self.python_runtime.get("explicit"):
                lines.append("[launcher] python3 source: TERMINAL_PYTHON_BIN")
            if not self.python_runtime["triton"]:
                lines.append("[launcher] warning: Triton is not installed in the selected Python runtime.")
        else:
            lines.append("[launcher] warning: no preferred Python runtime detected; falling back to PATH lookup.")
        return "\n".join(lines) + "\n\n"

    def _publish(self, event: dict[str, Any]) -> None:
        with self._subscriber_lock:
            subscribers = tuple(self._subscribers)
        for queue in subscribers:
            self.loop.call_soon_threadsafe(self._safe_put, queue, event)

    @staticmethod
    def _safe_put(queue: asyncio.Queue[dict[str, Any]], event: dict[str, Any]) -> None:
        try:
            queue.put_nowait(event)
        except asyncio.QueueFull:
            pass

    def _reader_loop(self) -> None:
        while True:
            try:
                data = os.read(self._master_fd, 4096)
            except OSError:
                break
            if not data:
                break
            text = data.decode("utf-8", errors="replace")
            self._append_buffer(text)
            self._publish({"type": "output", "data": text})

    def _wait_loop(self) -> None:
        exit_code = self.process.wait()
        self.exit_code = exit_code
        self.finished_at = time.time()
        self.status = "exited" if exit_code == 0 else "failed"
        self._publish(
            {
                "type": "exit",
                "exit_code": exit_code,
                "status": self.status,
                "finished_at": self.finished_at,
            }
        )
        try:
            os.close(self._master_fd)
        except OSError:
            pass


class TerminalManager:
    def __init__(self) -> None:
        self._sessions: dict[str, TerminalSession] = {}
        self._latest_by_job: dict[str, str] = {}
        self._lock = threading.Lock()

    def list_jobs(self) -> list[dict[str, Any]]:
        return [job.as_dict() for job in ALLOWED_JOBS.values()]

    def get_session(self, session_id: str) -> TerminalSession | None:
        with self._lock:
            return self._sessions.get(session_id)

    async def ensure_session(self, job_id: str, restart: bool = False) -> TerminalSession:
        if job_id not in ALLOWED_JOBS:
            raise KeyError(job_id)

        with self._lock:
            existing_id = self._latest_by_job.get(job_id)
            existing = self._sessions.get(existing_id) if existing_id else None

        if existing and existing.is_active and not restart:
            return existing

        if existing and restart:
            existing.interrupt()

        session = TerminalSession(ALLOWED_JOBS[job_id], asyncio.get_running_loop())
        with self._lock:
            self._sessions[session.id] = session
            self._latest_by_job[job_id] = session.id
        return session
client.py DELETED
@@ -1,77 +0,0 @@
from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Optional

import requests

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from server.softmax_surrogate_environment import DEFAULT_BUDGET, SoftmaxSurrogateEnvironment


class SoftmaxSurrogateEnvClient:
    def __init__(
        self,
        base_url: Optional[str] = None,
        measurement_path: str = "data/autotune_measurements.csv",
        budget: int = DEFAULT_BUDGET,
        seed: int = 0,
    ) -> None:
        self.base_url = base_url
        self._local_env = None
        if base_url is None:
            self._local_env = SoftmaxSurrogateEnvironment(
                measurement_path=measurement_path,
                budget=budget,
                seed=seed,
            )

    def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> dict:
        if self._local_env is not None:
            return self._local_env.reset(task=task, seed=seed)
        payload = {}
        if task is not None:
            payload["task"] = task
        if seed is not None:
            payload["seed"] = seed
        resp = requests.post(f"{self.base_url}/reset", json=payload, timeout=60)
        resp.raise_for_status()
        return resp.json()

    def step(self, action: Any) -> dict:
        if self._local_env is not None:
            return self._local_env.step(action)
        payload = action if isinstance(action, dict) else {"x": action}
        resp = requests.post(f"{self.base_url}/step", json=payload, timeout=120)
        resp.raise_for_status()
        return resp.json()

    def state(self) -> dict:
        if self._local_env is not None:
            return self._local_env.state()
        resp = requests.get(f"{self.base_url}/state", timeout=60)
        resp.raise_for_status()
        return resp.json()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--remote", default=None, help="Optional base URL (e.g. http://127.0.0.1:8000)")
    parser.add_argument("--task", default=None)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    client = SoftmaxSurrogateEnvClient(base_url=args.remote, seed=args.seed)
    print(client.reset(task=args.task))


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
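The client's `step` method normalizes its action before posting: a dict passes through unchanged, while any other value is wrapped under the key `"x"`. A minimal standalone Python sketch of just that normalization, mirroring the line in `SoftmaxSurrogateEnvClient.step`:

```python
def build_step_payload(action):
    """Mirror of SoftmaxSurrogateEnvClient.step's payload handling:
    dict actions are sent as-is; anything else is wrapped under "x"."""
    return action if isinstance(action, dict) else {"x": action}
```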
frontend/.gitignore
DELETED
```gitignore
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
```
frontend/README.md
DELETED
```markdown
# React + Vite

This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.

Currently, two official plugins are available:

- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh

## React Compiler

The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).

## Expanding the ESLint configuration

If you are developing a production application, we recommend using TypeScript with type-aware lint rules enabled. Check out the [TS template](https://github.com/vitejs/vite/tree/main/packages/create-vite/template-react-ts) for information on how to integrate TypeScript and [`typescript-eslint`](https://typescript-eslint.io) in your project.
```
frontend/eslint.config.js
DELETED
```js
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import { defineConfig, globalIgnores } from 'eslint/config'

export default defineConfig([
  globalIgnores(['dist']),
  {
    files: ['**/*.{js,jsx}'],
    extends: [
      js.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
      parserOptions: {
        ecmaVersion: 'latest',
        ecmaFeatures: { jsx: true },
        sourceType: 'module',
      },
    },
    rules: {
      'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
    },
  },
])
```
frontend/index.html
DELETED
```html
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>frontend</title>
  </head>
  <body>
    <div id="root"></div>
    <script type="module" src="/src/main.jsx"></script>
  </body>
</html>
```
frontend/package-lock.json
DELETED
The diff for this file is too large to render; see the raw diff.
frontend/package.json
DELETED
```json
{
  "name": "frontend",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "vite build",
    "lint": "eslint .",
    "preview": "vite preview"
  },
  "dependencies": {
    "react": "^19.2.0",
    "react-dom": "^19.2.0"
  },
  "devDependencies": {
    "@eslint/js": "^9.39.1",
    "@types/react": "^19.2.7",
    "@types/react-dom": "^19.2.3",
    "@vitejs/plugin-react": "^5.1.1",
    "eslint": "^9.39.1",
    "eslint-plugin-react-hooks": "^7.0.1",
    "eslint-plugin-react-refresh": "^0.4.24",
    "globals": "^16.5.0",
    "vite": "^7.3.1"
  }
}
```
frontend/public/vite.svg
DELETED
frontend/src/App.jsx
DELETED
```jsx
import { useEffect, useRef, useState } from 'react'
import TerminalPane from './components/TerminalPane'

const panes = [
  { jobId: 'qwen', title: 'Qwen Baseline', tone: 'cyan' },
  { jobId: 'rl-agent', title: 'RL Agent', tone: 'green' },
]

function formatMs(value) {
  if (value == null) {
    return '--'
  }
  if (value < 1000) {
    return `${Math.round(value)} ms`
  }
  return `${(value / 1000).toFixed(2)} s`
}

function buildRunStats(telemetry, issuedAt) {
  if (!telemetry || !issuedAt) {
    return {
      responseMs: null,
      completionMs: null,
      waiting: true,
    }
  }

  const responseMs = telemetry.lastOutputAt && telemetry.lastOutputAt >= issuedAt ? telemetry.lastOutputAt - issuedAt : null
  const finishedAt = telemetry.session?.finished_at ? telemetry.session.finished_at * 1000 : null
  const completionMs = finishedAt && finishedAt >= issuedAt ? finishedAt - issuedAt : null

  return {
    responseMs,
    completionMs,
    waiting: responseMs == null && completionMs == null,
  }
}

function App() {
  const [split, setSplit] = useState(50)
  const [dragging, setDragging] = useState(false)
  const [command, setCommand] = useState('')
  const [comparisonRun, setComparisonRun] = useState(null)
  const [telemetry, setTelemetry] = useState({
    qwen: null,
    'rl-agent': null,
  })
  const workspaceRef = useRef(null)
  const leftPaneRef = useRef(null)
  const rightPaneRef = useRef(null)

  useEffect(() => {
    if (!dragging) {
      return undefined
    }

    const handlePointerMove = (event) => {
      const bounds = workspaceRef.current?.getBoundingClientRect()
      if (!bounds) {
        return
      }

      const next = ((event.clientX - bounds.left) / bounds.width) * 100
      const clamped = Math.min(75, Math.max(25, next))
      setSplit(clamped)
    }

    const handlePointerUp = () => {
      setDragging(false)
    }

    window.addEventListener('pointermove', handlePointerMove)
    window.addEventListener('pointerup', handlePointerUp)

    return () => {
      window.removeEventListener('pointermove', handlePointerMove)
      window.removeEventListener('pointerup', handlePointerUp)
    }
  }, [dragging])

  const handleBroadcast = async (event) => {
    event.preventDefault()
    const value = command.trim()
    if (!value) {
      return
    }

    const issuedAt = Date.now()
    setComparisonRun({
      command: value,
      issuedAt,
    })

    await Promise.allSettled([
      leftPaneRef.current?.submit(value),
      rightPaneRef.current?.submit(value),
    ])
    setCommand('')
  }

  const handleTelemetryChange = (payload) => {
    setTelemetry((previous) => ({
      ...previous,
      [payload.jobId]: payload,
    }))
  }

  const qwenStats = buildRunStats(telemetry.qwen, comparisonRun?.issuedAt)
  const agentStats = buildRunStats(telemetry['rl-agent'], comparisonRun?.issuedAt)

  let comparisonHeadline = 'Send a shared command to compare runtime.'
  if (comparisonRun) {
    if (qwenStats.completionMs != null && agentStats.completionMs != null) {
      const fasterJob = qwenStats.completionMs <= agentStats.completionMs ? panes[0].title : panes[1].title
      const delta = Math.abs(qwenStats.completionMs - agentStats.completionMs)
      comparisonHeadline = `${fasterJob} finished ${formatMs(delta)} faster.`
    } else if (qwenStats.responseMs != null && agentStats.responseMs != null) {
      const fasterJob = qwenStats.responseMs <= agentStats.responseMs ? panes[0].title : panes[1].title
      const delta = Math.abs(qwenStats.responseMs - agentStats.responseMs)
      comparisonHeadline = `${fasterJob} responded ${formatMs(delta)} faster.`
    } else {
      comparisonHeadline = `Running shared command: ${comparisonRun.command}`
    }
  }

  return (
    <main className="desktop">
      <div className="desktop__glow" />

      <section className="comparison-bar">
        <div className="comparison-bar__copy">
          <span className="comparison-bar__eyebrow">Runtime compare</span>
          <strong>{comparisonHeadline}</strong>
          <small>{comparisonRun ? `Command: ${comparisonRun.command}` : 'Broadcast one command to both panes.'}</small>
        </div>

        <div className="comparison-bar__stats">
          <article className="comparison-card comparison-card--cyan">
            <span>{panes[0].title}</span>
            <strong>{formatMs(qwenStats.completionMs ?? qwenStats.responseMs)}</strong>
            <small>{qwenStats.completionMs != null ? 'completion time' : 'first output latency'}</small>
          </article>

          <article className="comparison-card comparison-card--green">
            <span>{panes[1].title}</span>
            <strong>{formatMs(agentStats.completionMs ?? agentStats.responseMs)}</strong>
            <small>{agentStats.completionMs != null ? 'completion time' : 'first output latency'}</small>
          </article>
        </div>
      </section>

      <section ref={workspaceRef} className="workspace">
        <div className="workspace__pane" style={{ width: `${split}%` }}>
          <TerminalPane ref={leftPaneRef} {...panes[0]} onTelemetryChange={handleTelemetryChange} />
        </div>

        <button
          type="button"
          className={dragging ? 'workspace__divider is-dragging' : 'workspace__divider'}
          onPointerDown={() => setDragging(true)}
          aria-label="Resize terminal panes"
          aria-valuemin={25}
          aria-valuemax={75}
          aria-valuenow={Math.round(split)}
          aria-orientation="vertical"
        >
          <span />
        </button>

        <div className="workspace__pane" style={{ width: `${100 - split}%` }}>
          <TerminalPane ref={rightPaneRef} {...panes[1]} onTelemetryChange={handleTelemetryChange} />
        </div>
      </section>

      <form className="broadcast-bar" onSubmit={handleBroadcast}>
        <label className="broadcast-bar__label" htmlFor="broadcast-input">
          Shared input
        </label>
        <div className="broadcast-bar__field">
          <span className="broadcast-bar__prompt">$</span>
          <input
            id="broadcast-input"
            value={command}
            onChange={(event) => setCommand(event.target.value)}
            placeholder="Send the same command to both terminals"
            spellCheck="false"
          />
          <button type="submit">Send to both</button>
        </div>
      </form>
    </main>
  )
}

export default App
```
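App.jsx's `formatMs` helper switches from milliseconds to seconds at the one-second mark. The same logic as a small Python sketch (a hedged mirror, not the shipped code; note Python's `round` uses banker's rounding, so it matches `Math.round` only approximately at exact halves):

```python
def format_ms(value):
    """Format a millisecond duration like App.jsx's formatMs helper."""
    if value is None:
        return "--"          # no measurement yet
    if value < 1000:
        return f"{round(value)} ms"
    return f"{value / 1000:.2f} s"
```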
frontend/src/api/terminal.js
DELETED
```js
const RAW_API_BASE = import.meta.env.VITE_API_BASE_URL || 'http://127.0.0.1:8000'
const API_BASE = RAW_API_BASE.replace(/\/+$/, '')
const WS_BASE = API_BASE.replace(/^http/, 'ws')

async function request(path, options = {}) {
  const response = await fetch(`${API_BASE}${path}`, {
    headers: {
      'Content-Type': 'application/json',
      ...(options.headers || {}),
    },
    ...options,
  })

  if (!response.ok) {
    let message = `Request failed with status ${response.status}`
    try {
      const payload = await response.json()
      message = payload.detail || message
    } catch {
      // Keep default error message when the payload is not JSON.
    }
    throw new Error(message)
  }

  return response.json()
}

export function createOrAttachSession(jobId, { restart = false } = {}) {
  return request('/terminal/sessions', {
    method: 'POST',
    body: JSON.stringify({ job_id: jobId, restart }),
  })
}

export function stopTerminalSession(sessionId) {
  return request(`/terminal/sessions/${sessionId}/stop`, {
    method: 'POST',
  })
}

export function sendTerminalInput(sessionId, data, appendNewline = true) {
  return request(`/terminal/sessions/${sessionId}/input`, {
    method: 'POST',
    body: JSON.stringify({ data, append_newline: appendNewline }),
  })
}

export function resizeTerminalSession(sessionId, cols, rows) {
  return request(`/terminal/sessions/${sessionId}/resize`, {
    method: 'POST',
    body: JSON.stringify({ cols, rows }),
  })
}

export function openTerminalSocket(sessionId) {
  return new WebSocket(`${WS_BASE}/terminal/sessions/${sessionId}/stream`)
}
```
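terminal.js derives its WebSocket base by rewriting the HTTP base URL: strip trailing slashes, then swap the leading `http` for `ws` (which also maps `https` to `wss`). A Python sketch of those two rewrites:

```python
import re

def derive_endpoints(raw_base):
    """Mirror terminal.js: strip trailing slashes, then map http(s) to ws(s)."""
    api_base = re.sub(r"/+$", "", raw_base)       # RAW_API_BASE -> API_BASE
    ws_base = re.sub(r"^http", "ws", api_base)    # API_BASE -> WS_BASE
    return api_base, ws_base
```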
frontend/src/assets/react.svg
DELETED
frontend/src/components/TerminalPane.jsx
DELETED
```jsx
import { forwardRef, useEffect, useImperativeHandle, useMemo, useRef } from 'react'
import { useTerminalSession } from '../hooks/useTerminalSession'

function formatTime(timestamp) {
  if (!timestamp) {
    return 'Idle'
  }
  return new Date(timestamp * 1000).toLocaleTimeString([], {
    hour: '2-digit',
    minute: '2-digit',
    second: '2-digit',
  })
}

function statusLabel(status) {
  if (status === 'running') {
    return 'Running'
  }
  if (status === 'failed') {
    return 'Failed'
  }
  if (status === 'exited') {
    return 'Completed'
  }
  return 'Starting'
}

const TerminalPane = forwardRef(function TerminalPane({ jobId, title, tone, onTelemetryChange }, ref) {
  const { session, buffer, connectionState, error, lastOutputAt, restart, resize, sendInput, start, stop } =
    useTerminalSession(jobId)
  const viewportRef = useRef(null)
  const scrollRef = useRef(null)

  useImperativeHandle(
    ref,
    () => ({
      submit: async (value) => {
        await sendInput(value, true)
      },
    }),
    [sendInput],
  )

  useEffect(() => {
    onTelemetryChange?.({
      jobId,
      session,
      connectionState,
      error,
      lastOutputAt,
    })
  }, [connectionState, error, jobId, lastOutputAt, onTelemetryChange, session])

  useEffect(() => {
    const container = scrollRef.current
    if (container) {
      container.scrollTop = container.scrollHeight
    }
  }, [buffer])

  useEffect(() => {
    const element = viewportRef.current
    if (!element) {
      return undefined
    }

    let frameId = 0
    const measure = () => {
      cancelAnimationFrame(frameId)
      frameId = requestAnimationFrame(() => {
        const style = getComputedStyle(element)
        const fontSize = parseFloat(style.fontSize) || 15
        const lineHeight = parseFloat(style.lineHeight) || 24
        const cols = Math.max(48, Math.floor(element.clientWidth / (fontSize * 0.61)))
        const rows = Math.max(14, Math.floor(element.clientHeight / lineHeight))
        resize(cols, rows)
      })
    }

    measure()
    const observer = new ResizeObserver(measure)
    observer.observe(element)

    return () => {
      cancelAnimationFrame(frameId)
      observer.disconnect()
    }
  }, [resize])

  const footerMeta = useMemo(
    () => [
      session?.status ? statusLabel(session.status) : 'Connecting',
      session?.started_at ? `Started ${formatTime(session.started_at)}` : null,
      session?.exit_code != null ? `Exit ${session.exit_code}` : null,
      connectionState === 'connected' ? 'WS live' : connectionState,
    ].filter(Boolean),
    [connectionState, session],
  )

  return (
    <article className={`terminal-pane terminal-pane--${tone}`}>
      <header className="terminal-pane__header">
        <div className="terminal-pane__heading">
          <div className="terminal-pane__title-row">
            <span className="terminal-pane__dot" />
            <h2>{title}</h2>
            <span className={`status-chip status-chip--${session?.status || 'starting'}`}>
              {statusLabel(session?.status)}
            </span>
          </div>
          <p>{session?.command || 'Waiting for backend session...'}</p>
          <small>{session?.cwd || 'No working directory available yet.'}</small>
        </div>

        <div className="terminal-pane__actions">
          <button type="button" onClick={start}>
            Attach
          </button>
          <button type="button" onClick={restart}>
            Restart
          </button>
          <button type="button" onClick={stop}>
            Stop
          </button>
        </div>
      </header>

      <div ref={viewportRef} className="terminal-pane__viewport">
        <div ref={scrollRef} className="terminal-pane__scroll">
          <pre className="terminal-pane__buffer">{buffer || 'Starting session...\n'}</pre>
          {session?.status === 'running' ? <span className="terminal-pane__cursor" aria-hidden="true" /> : null}
        </div>
      </div>

      <footer className="terminal-pane__footer">
        <div className="terminal-pane__meta">
          {footerMeta.map((item) => (
            <span key={item}>{item}</span>
          ))}
          {error ? <span className="terminal-pane__error">{error}</span> : null}
        </div>
      </footer>
    </article>
  )
})

export default TerminalPane
```
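TerminalPane's `measure()` sizes the pseudo-terminal from the viewport's pixel dimensions, using a 0.61-of-font-size character-width heuristic with floors of 48 columns and 14 rows. The arithmetic as a Python sketch (a mirror for illustration; the real values come from `getComputedStyle`):

```python
import math

def grid_size(width_px, height_px, font_size=15.0, line_height=24.0):
    """Mirror TerminalPane's measure(): estimate cols/rows from pixel size,
    clamped to minimums of 48 columns and 14 rows."""
    cols = max(48, math.floor(width_px / (font_size * 0.61)))
    rows = max(14, math.floor(height_px / line_height))
    return cols, rows
```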
frontend/src/hooks/useTerminalSession.js
DELETED
```js
import { useCallback, useEffect, useRef, useState } from 'react'
import {
  createOrAttachSession,
  openTerminalSocket,
  resizeTerminalSession,
  sendTerminalInput,
  stopTerminalSession,
} from '../api/terminal'

const BUFFER_LIMIT = 160000

function trimBuffer(text) {
  return text.length > BUFFER_LIMIT ? text.slice(-BUFFER_LIMIT) : text
}

export function useTerminalSession(jobId) {
  const [session, setSession] = useState(null)
  const [buffer, setBuffer] = useState('')
  const [connectionState, setConnectionState] = useState('connecting')
  const [error, setError] = useState('')
  const [lastOutputAt, setLastOutputAt] = useState(null)

  const socketRef = useRef(null)
  const resizeRef = useRef({ cols: null, rows: null })

  const attachSocket = useCallback((sessionId) => {
    if (socketRef.current) {
      socketRef.current.close()
    }

    const socket = openTerminalSocket(sessionId)
    socketRef.current = socket
    setConnectionState('connecting')

    socket.addEventListener('open', () => {
      setConnectionState('connected')
    })

    socket.addEventListener('message', (event) => {
      const payload = JSON.parse(event.data)

      if (payload.type === 'snapshot') {
        setSession(payload.session)
        setBuffer(payload.buffer || '')
        return
      }

      if (payload.type === 'output') {
        setLastOutputAt(Date.now())
        setBuffer((previous) => trimBuffer(previous + payload.data))
        return
      }

      if (payload.type === 'exit') {
        setSession((previous) =>
          previous
            ? {
                ...previous,
                status: payload.status,
                exit_code: payload.exit_code,
                finished_at: payload.finished_at,
              }
            : previous,
        )
      }
    })

    socket.addEventListener('close', () => {
      setConnectionState('disconnected')
    })

    socket.addEventListener('error', () => {
      setConnectionState('error')
    })
  }, [])

  const bootSession = useCallback(
    async (restart = false) => {
      try {
        setError('')
        const payload = await createOrAttachSession(jobId, { restart })
        setSession(payload.session)
        setBuffer(payload.buffer || '')
        attachSocket(payload.session.id)
      } catch (caughtError) {
        setError(caughtError.message)
        setConnectionState('error')
      }
    },
    [attachSocket, jobId],
  )

  useEffect(() => {
    const timeoutId = window.setTimeout(() => {
      void bootSession(false)
    }, 0)

    return () => {
      window.clearTimeout(timeoutId)
      if (socketRef.current) {
        socketRef.current.close()
      }
    }
  }, [bootSession])

  const restart = useCallback(() => bootSession(true), [bootSession])

  const stop = useCallback(async () => {
    if (!session?.id) {
      return
    }
    try {
      await stopTerminalSession(session.id)
    } catch (caughtError) {
      setError(caughtError.message)
    }
  }, [session])

  const sendInput = useCallback(
    async (value, appendNewline = true) => {
      if (!session?.id || !value.trim()) {
        return
      }
      try {
        await sendTerminalInput(session.id, value, appendNewline)
      } catch (caughtError) {
        setError(caughtError.message)
      }
    },
    [session],
  )

  const resize = useCallback(
    async (cols, rows) => {
      if (!session?.id) {
        return
      }

      const previous = resizeRef.current
      if (previous.cols === cols && previous.rows === rows) {
        return
      }
      resizeRef.current = { cols, rows }

      try {
        await resizeTerminalSession(session.id, cols, rows)
      } catch {
        // Ignore resize errors so rendering stays responsive.
      }
    },
    [session],
  )

  return {
    buffer,
    connectionState,
    error,
    lastOutputAt,
    restart,
    resize,
    sendInput,
    session,
    start: () => bootSession(false),
    stop,
  }
}
```
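The hook caps the scrollback at `BUFFER_LIMIT` (160,000 characters), keeping only the tail so long-running jobs can't grow the buffer without bound. The same trim as a one-liner Python sketch:

```python
BUFFER_LIMIT = 160_000

def trim_buffer(text, limit=BUFFER_LIMIT):
    """Mirror trimBuffer: drop everything but the last `limit` characters."""
    return text[-limit:] if len(text) > limit else text
```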
frontend/src/index.css DELETED
@@ -1,486 +0,0 @@
-:root {
-  color: #ebeff5;
-  background:
-    radial-gradient(circle at top, rgba(88, 108, 146, 0.16), transparent 24%),
-    linear-gradient(180deg, #0d1015 0%, #090b0f 100%);
-  font-synthesis: none;
-  text-rendering: optimizeLegibility;
-  -webkit-font-smoothing: antialiased;
-  -moz-osx-font-smoothing: grayscale;
-  --page: #090b0f;
-  --pane: rgba(14, 17, 22, 0.98);
-  --pane-border: rgba(255, 255, 255, 0.07);
-  --pane-soft: rgba(255, 255, 255, 0.03);
-  --text: #ebeff5;
-  --muted: #8f97a6;
-  --cyan: #8fd1ff;
-  --green: #93e09f;
-  --red: #ff7f8c;
-  --amber: #e6c171;
-}
-
-* {
-  box-sizing: border-box;
-}
-
-html,
-body,
-#root {
-  min-height: 100%;
-}
-
-body {
-  margin: 0;
-  min-width: 320px;
-  background: var(--page);
-  color: var(--text);
-  font-family: "IBM Plex Sans", "SF Pro Display", "Segoe UI", sans-serif;
-}
-
-button,
-input {
-  font: inherit;
-}
-
-button {
-  cursor: pointer;
-}
-
-.desktop {
-  position: relative;
-  min-height: 100vh;
-  overflow: hidden;
-}
-
-.desktop__glow {
-  position: absolute;
-  inset: 0;
-  background:
-    radial-gradient(circle at 30% 0%, rgba(109, 132, 180, 0.14), transparent 20%),
-    radial-gradient(circle at 100% 0%, rgba(70, 110, 168, 0.12), transparent 16%);
-  pointer-events: none;
-}
-
-.workspace {
-  position: relative;
-  z-index: 1;
-  display: flex;
-  width: 100vw;
-  height: calc(100vh - 194px);
-  background: rgba(7, 10, 13, 0.98);
-}
-
-.comparison-bar {
-  position: relative;
-  z-index: 1;
-  display: flex;
-  align-items: stretch;
-  justify-content: space-between;
-  gap: 18px;
-  padding: 14px 18px;
-  border-bottom: 1px solid var(--pane-border);
-  background:
-    linear-gradient(180deg, rgba(16, 19, 25, 0.98), rgba(11, 13, 18, 0.98));
-}
-
-.comparison-bar__copy {
-  display: grid;
-  gap: 5px;
-  min-width: 0;
-}
-
-.comparison-bar__copy strong,
-.comparison-bar__copy small {
-  overflow: hidden;
-  text-overflow: ellipsis;
-  white-space: nowrap;
-}
-
-.comparison-bar__eyebrow {
-  color: var(--muted);
-  font:
-    500 0.74rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-  text-transform: uppercase;
-  letter-spacing: 0.14em;
-}
-
-.comparison-bar__copy strong {
-  font-size: 1rem;
-  font-weight: 600;
-}
-
-.comparison-bar__copy small {
-  color: var(--muted);
-}
-
-.comparison-bar__stats {
-  display: flex;
-  gap: 12px;
-}
-
-.comparison-card {
-  display: grid;
-  gap: 4px;
-  min-width: 170px;
-  padding: 12px 14px;
-  border: 1px solid var(--pane-border);
-  border-radius: 14px;
-  background: rgba(255, 255, 255, 0.03);
-}
-
-.comparison-card span,
-.comparison-card small {
-  color: var(--muted);
-}
-
-.comparison-card strong {
-  font:
-    600 1.2rem/1.1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.comparison-card--cyan strong {
-  color: var(--cyan);
-}
-
-.comparison-card--green strong {
-  color: var(--green);
-}
-
-.workspace__pane {
-  min-width: 0;
-  height: 100%;
-}
-
-.workspace__divider {
-  position: relative;
-  flex: 0 0 12px;
-  border: 0;
-  padding: 0;
-  background: linear-gradient(180deg, rgba(255, 255, 255, 0.04), rgba(255, 255, 255, 0.015));
-  cursor: col-resize;
-}
-
-.workspace__divider span {
-  position: absolute;
-  top: 50%;
-  left: 50%;
-  width: 3px;
-  height: 72px;
-  border-radius: 999px;
-  background: rgba(255, 255, 255, 0.22);
-  transform: translate(-50%, -50%);
-}
-
-.workspace__divider:hover span,
-.workspace__divider.is-dragging span {
-  background: rgba(143, 209, 255, 0.8);
-}
-
-.workspace__divider:focus-visible {
-  outline: none;
-  box-shadow: inset 0 0 0 1px rgba(143, 209, 255, 0.5);
-}
-
-.terminal-pane {
-  display: grid;
-  grid-template-rows: auto minmax(0, 1fr) auto;
-  width: 100%;
-  height: 100%;
-  background:
-    linear-gradient(180deg, rgba(255, 255, 255, 0.015), transparent 10%),
-    var(--pane);
-}
-
-.terminal-pane + .terminal-pane {
-  border-left: 1px solid var(--pane-border);
-}
-
-.terminal-pane__header {
-  display: flex;
-  align-items: flex-start;
-  justify-content: space-between;
-  gap: 18px;
-  padding: 18px 18px 14px;
-  border-bottom: 1px solid var(--pane-border);
-  background: rgba(255, 255, 255, 0.02);
-}
-
-.terminal-pane__heading {
-  min-width: 0;
-}
-
-.terminal-pane__title-row {
-  display: flex;
-  align-items: center;
-  gap: 10px;
-  margin-bottom: 6px;
-}
-
-.terminal-pane__title-row h2 {
-  margin: 0;
-  font-size: 1rem;
-  font-weight: 600;
-  letter-spacing: 0.01em;
-}
-
-.terminal-pane__heading p,
-.terminal-pane__heading small {
-  display: block;
-  margin: 0;
-  overflow: hidden;
-  text-overflow: ellipsis;
-  white-space: nowrap;
-}
-
-.terminal-pane__heading p {
-  color: var(--text);
-  font:
-    500 0.84rem/1.4 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.terminal-pane__heading small {
-  margin-top: 4px;
-  color: var(--muted);
-}
-
-.terminal-pane__dot {
-  width: 9px;
-  height: 9px;
-  border-radius: 999px;
-  background: var(--cyan);
-  box-shadow: 0 0 24px rgba(143, 209, 255, 0.35);
-}
-
-.terminal-pane--green .terminal-pane__dot {
-  background: var(--green);
-  box-shadow: 0 0 24px rgba(147, 224, 159, 0.35);
-}
-
-.terminal-pane__actions {
-  display: flex;
-  gap: 8px;
-}
-
-.terminal-pane__actions button {
-  padding: 8px 10px;
-  border: 1px solid var(--pane-border);
-  border-radius: 10px;
-  background: rgba(255, 255, 255, 0.03);
-  color: var(--text);
-}
-
-.terminal-pane__actions button:hover {
-  background: rgba(255, 255, 255, 0.06);
-}
-
-.status-chip {
-  padding: 4px 8px;
-  border-radius: 999px;
-  font-size: 0.7rem;
-  text-transform: uppercase;
-  letter-spacing: 0.12em;
-  border: 1px solid transparent;
-}
-
-.status-chip--running {
-  color: var(--green);
-  border-color: rgba(147, 224, 159, 0.24);
-  background: rgba(147, 224, 159, 0.08);
-}
-
-.status-chip--failed {
-  color: var(--red);
-  border-color: rgba(255, 127, 140, 0.25);
-  background: rgba(255, 127, 140, 0.08);
-}
-
-.status-chip--exited {
-  color: var(--amber);
-  border-color: rgba(230, 193, 113, 0.24);
-  background: rgba(230, 193, 113, 0.08);
-}
-
-.status-chip--starting {
-  color: var(--cyan);
-  border-color: rgba(143, 209, 255, 0.24);
-  background: rgba(143, 209, 255, 0.08);
-}
-
-.terminal-pane__viewport {
-  min-height: 0;
-  overflow: hidden;
-  font:
-    500 0.95rem/1.72 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.terminal-pane__scroll {
-  height: 100%;
-  overflow: auto;
-  padding: 18px 18px 20px;
-}
-
-.terminal-pane__buffer {
-  margin: 0;
-  white-space: pre-wrap;
-  word-break: break-word;
-  color: #edf2fa;
-}
-
-.terminal-pane__cursor {
-  display: inline-block;
-  width: 0.62em;
-  height: 1.05em;
-  margin-top: 4px;
-  background: rgba(237, 242, 250, 0.9);
-  animation: blink 1s steps(1) infinite;
-}
-
-.terminal-pane__footer {
-  display: grid;
-  padding: 14px 18px 16px;
-  border-top: 1px solid var(--pane-border);
-  background: rgba(255, 255, 255, 0.02);
-}
-
-.terminal-pane__meta {
-  display: flex;
-  flex-wrap: wrap;
-  gap: 10px 14px;
-  color: var(--muted);
-  font:
-    500 0.76rem/1.3 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.terminal-pane__error {
-  color: var(--red);
-}
-
-.broadcast-bar {
-  position: relative;
-  z-index: 1;
-  display: grid;
-  gap: 8px;
-  padding: 14px 18px 18px;
-  border-top: 1px solid var(--pane-border);
-  background:
-    linear-gradient(180deg, rgba(18, 21, 27, 0.98), rgba(10, 12, 16, 0.98));
-}
-
-.broadcast-bar__label {
-  color: var(--muted);
-  font:
-    500 0.78rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-  text-transform: uppercase;
-  letter-spacing: 0.12em;
-}
-
-.broadcast-bar__field {
-  display: flex;
-  align-items: center;
-  gap: 12px;
-  padding: 14px 16px;
-  border: 1px solid var(--pane-border);
-  border-radius: 14px;
-  background: rgba(255, 255, 255, 0.03);
-  box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.02);
-}
-
-.broadcast-bar__prompt {
-  color: var(--cyan);
-  font:
-    600 0.92rem/1 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.broadcast-bar__field input {
-  flex: 1;
-  min-width: 0;
-  border: 0;
-  background: transparent;
-  color: var(--text);
-  outline: none;
-  font:
-    500 0.95rem/1.2 "JetBrains Mono", "SFMono-Regular", Consolas, monospace;
-}
-
-.broadcast-bar__field input::placeholder {
-  color: var(--muted);
-}
-
-.broadcast-bar__field button {
-  padding: 10px 14px;
-  border: 1px solid rgba(143, 209, 255, 0.18);
-  border-radius: 10px;
-  background: rgba(143, 209, 255, 0.08);
-  color: var(--text);
-}
-
-.broadcast-bar__field button:hover {
-  background: rgba(143, 209, 255, 0.14);
-}
-
-@keyframes blink {
-  50% {
-    opacity: 0;
-  }
-}
-
-@media (max-width: 980px) {
-  .comparison-bar {
-    flex-direction: column;
-  }
-
-  .comparison-bar__stats {
-    width: 100%;
-  }
-
-  .comparison-card {
-    flex: 1;
-    min-width: 0;
-  }
-
-  .terminal-pane__header {
-    flex-direction: column;
-  }
-
-  .terminal-pane__actions {
-    width: 100%;
-  }
-
-  .terminal-pane__actions button {
-    flex: 1;
-  }
-}
-
-@media (max-width: 780px) {
-  .workspace {
-    height: calc(100vh - 244px);
-  }
-
-  .workspace__divider {
-    flex-basis: 10px;
-  }
-
-  .terminal-pane__scroll {
-    padding: 16px;
-  }
-
-  .terminal-pane__viewport {
-    font-size: 0.88rem;
-  }
-
-  .terminal-pane__footer {
-    padding: 12px 14px 14px;
-  }
-
-  .broadcast-bar {
-    padding: 12px 14px 14px;
-  }
-
-  .broadcast-bar__field {
-    padding: 12px 14px;
-  }
-
-  .comparison-bar {
-    padding: 12px 14px;
-  }
-}
frontend/src/main.jsx DELETED
@@ -1,10 +0,0 @@
-import { StrictMode } from 'react'
-import { createRoot } from 'react-dom/client'
-import './index.css'
-import App from './App.jsx'
-
-createRoot(document.getElementById('root')).render(
-  <StrictMode>
-    <App />
-  </StrictMode>,
-)
frontend/vite.config.js DELETED
@@ -1,7 +0,0 @@
-import { defineConfig } from 'vite'
-import react from '@vitejs/plugin-react'
-
-// https://vite.dev/config/
-export default defineConfig({
-  plugins: [react()],
-})
index.html ADDED
@@ -0,0 +1,317 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <title>RL Surrogate Kernel Autotuning</title>
+  <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap">
+  <style>
+    :root {
+      --primary: #2563eb;
+      --primary-dark: #1d4ed8;
+      --secondary: #10b981;
+      --accent: #f59e0b;
+      --bg: #0f172a;
+      --bg-card: #1e293b;
+      --text: #f1f5f9;
+      --text-muted: #94a3b8;
+      --border: #334155;
+    }
+    * {
+      margin: 0;
+      padding: 0;
+      box-sizing: border-box;
+    }
+    body {
+      font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
+      background: var(--bg);
+      color: var(--text);
+      line-height: 1.6;
+      min-height: 100vh;
+    }
+    .container {
+      max-width: 1200px;
+      margin: 0 auto;
+      padding: 2rem;
+    }
+    header {
+      text-align: center;
+      padding: 4rem 0;
+      position: relative;
+    }
+    header::before {
+      content: '';
+      position: absolute;
+      top: 0;
+      left: 50%;
+      transform: translateX(-50%);
+      width: 600px;
+      height: 600px;
+      background: radial-gradient(circle, rgba(37, 99, 235, 0.15) 0%, transparent 70%);
+      pointer-events: none;
+    }
+    h1 {
+      font-size: 3rem;
+      font-weight: 700;
+      margin-bottom: 1rem;
+      background: linear-gradient(135deg, #60a5fa 0%, #a78bfa 100%);
+      -webkit-background-clip: text;
+      -webkit-text-fill-color: transparent;
+      background-clip: text;
+    }
+    .subtitle {
+      font-size: 1.25rem;
+      color: var(--text-muted);
+      max-width: 600px;
+      margin: 0 auto;
+    }
+    .badge {
+      display: inline-block;
+      padding: 0.5rem 1rem;
+      background: var(--primary);
+      border-radius: 9999px;
+      font-size: 0.875rem;
+      font-weight: 500;
+      margin-top: 1.5rem;
+    }
+    .grid {
+      display: grid;
+      grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
+      gap: 1.5rem;
+      margin: 3rem 0;
+    }
+    .card {
+      background: var(--bg-card);
+      border: 1px solid var(--border);
+      border-radius: 1rem;
+      padding: 1.5rem;
+      transition: transform 0.2s, box-shadow 0.2s;
+    }
+    .card:hover {
+      transform: translateY(-2px);
+      box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
+    }
+    .card-icon {
+      width: 48px;
+      height: 48px;
+      border-radius: 0.75rem;
+      display: flex;
+      align-items: center;
+      justify-content: center;
+      font-size: 1.5rem;
+      margin-bottom: 1rem;
+    }
+    .card-icon.blue { background: rgba(37, 99, 235, 0.2); }
+    .card-icon.green { background: rgba(16, 185, 129, 0.2); }
+    .card-icon.amber { background: rgba(245, 158, 11, 0.2); }
+    .card-icon.purple { background: rgba(139, 92, 246, 0.2); }
+    .card h3 {
+      font-size: 1.125rem;
+      font-weight: 600;
+      margin-bottom: 0.5rem;
+    }
+    .card p {
+      color: var(--text-muted);
+      font-size: 0.875rem;
+    }
+    .features {
+      margin: 3rem 0;
+    }
+    .features h2 {
+      font-size: 1.75rem;
+      margin-bottom: 1.5rem;
+      text-align: center;
+    }
+    .feature-list {
+      display: grid;
+      grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
+      gap: 1rem;
+    }
+    .feature-item {
+      display: flex;
+      align-items: center;
+      gap: 0.75rem;
+      padding: 1rem;
+      background: var(--bg-card);
+      border-radius: 0.75rem;
+      border: 1px solid var(--border);
+    }
+    .feature-item span {
+      color: var(--secondary);
+      font-size: 1.25rem;
+    }
+    .code-block {
+      background: #0d1117;
+      border: 1px solid var(--border);
+      border-radius: 0.75rem;
+      padding: 1.5rem;
+      margin: 2rem 0;
+      overflow-x: auto;
+    }
+    .code-block pre {
+      font-family: 'JetBrains Mono', monospace;
+      font-size: 0.875rem;
+      color: #e2e8f0;
+      white-space: pre;
+    }
+    .code-block .comment { color: #6b7280; }
+    .code-block .keyword { color: #c084fc; }
+    .code-block .string { color: #34d399; }
+    .code-block .number { color: #f472b6; }
+    .cta {
+      text-align: center;
+      padding: 3rem;
+      background: var(--bg-card);
+      border-radius: 1rem;
+      border: 1px solid var(--border);
+      margin-top: 3rem;
+    }
+    .cta h2 {
+      font-size: 1.5rem;
+      margin-bottom: 1rem;
+    }
+    .cta p {
+      color: var(--text-muted);
+      margin-bottom: 1.5rem;
+    }
+    .btn {
+      display: inline-block;
+      padding: 0.75rem 1.5rem;
+      background: var(--primary);
+      color: white;
+      text-decoration: none;
+      border-radius: 0.5rem;
+      font-weight: 500;
+      transition: background 0.2s;
+    }
+    .btn:hover {
+      background: var(--primary-dark);
+    }
+    footer {
+      text-align: center;
+      padding: 2rem;
+      color: var(--text-muted);
+      font-size: 0.875rem;
+      border-top: 1px solid var(--border);
+      margin-top: 4rem;
+    }
+    .kernel-families {
+      display: flex;
+      justify-content: center;
+      gap: 1rem;
+      margin-top: 2rem;
+      flex-wrap: wrap;
+    }
+    .kernel-tag {
+      padding: 0.5rem 1rem;
+      background: rgba(139, 92, 246, 0.2);
+      border: 1px solid rgba(139, 92, 246, 0.3);
+      border-radius: 0.5rem;
+      font-family: 'JetBrains Mono', monospace;
+      font-size: 0.875rem;
+    }
+  </style>
+</head>
+<body>
+  <div class="container">
+    <header>
+      <h1>RL Surrogate Kernel Autotuning</h1>
+      <p class="subtitle">
+        A reinforcement learning environment for optimizing GPU kernels using surrogate models and self-improving AI agents
+      </p>
+      <span class="badge">🚀 HuggingFace Space</span>
+    </header>
+    <div class="grid">
+      <div class="card">
+        <div class="card-icon blue">🎯</div>
+        <h3>Multi-Family Support</h3>
+        <p>Optimizes three kernel families: Softmax, LayerNorm, and Grouped GEMM for NVIDIA GPUs</p>
+      </div>
+      <div class="card">
+        <div class="card-icon green">📊</div>
+        <h3>Surrogate Modeling</h3>
+        <p>Uses Gaussian Process surrogates to predict kernel performance and guide exploration</p>
+      </div>
+      <div class="card">
+        <div class="card-icon amber">🤖</div>
+        <h3>Self-Improving AI</h3>
+        <p>Employs student-teacher LLM architecture for autonomous kernel optimization</p>
+      </div>
+      <div class="card">
+        <div class="card-icon purple">⚡</div>
+        <h3>Triton Backend</h3>
+        <p>Generates optimized Triton kernels with configurable block sizes, warps, and stages</p>
+      </div>
+    </div>
+    <div class="features">
+      <h2>Supported Kernel Families</h2>
+      <div class="kernel-families">
+        <span class="kernel-tag">softmax</span>
+        <span class="kernel-tag">layernorm</span>
+        <span class="kernel-tag">grouped_gemm</span>
+      </div>
+    </div>
+    <div class="features">
+      <h2>How It Works</h2>
+      <div class="feature-list">
+        <div class="feature-item">
+          <span>1.</span>
+          <div><strong>Reset</strong> - Initialize environment with a kernel task (e.g., softmax_m4096_n256)</div>
+        </div>
+        <div class="feature-item">
+          <span>2.</span>
+          <div><strong>Configure</strong> - Select Triton config (block_size, num_warps, num_stages)</div>
+        </div>
+        <div class="feature-item">
+          <span>3.</span>
+          <div><strong>Step</strong> - Benchmark configuration and receive reward based on latency</div>
+        </div>
+        <div class="feature-item">
+          <span>4.</span>
+          <div><strong>Optimize</strong> - Use LLM to generate better kernel variants automatically</div>
+        </div>
+        <div class="feature-item">
+          <span>5.</span>
+          <div><strong>Learn</strong> - Surrogate model improves with more observations</div>
+        </div>
+        <div class="feature-item">
+          <span>6.</span>
+          <div><strong>Compare</strong> - Student vs Teacher LLM competition for best optimization</div>
+        </div>
+      </div>
+    </div>
+    <div class="code-block">
+      <pre><span class="comment"># Example: Using the RL Environment</span>
+<span class="comment"># 1. Reset with a task</span>
+POST /reset
+{
+  <span class="string">"task"</span>: <span class="string">"softmax_m4096_n256"</span>,
+  <span class="string">"seed"</span>: <span class="number">0</span>
+}
+<span class="comment"># 2. Take a step with a config</span>
+POST /step
+{
+  <span class="string">"config_id"</span>: <span class="number">0</span>,
+  <span class="string">"source"</span>: <span class="string">"def benchmark_generated_kernel..."</span> <span class="comment"># optional kernel code</span>
+}
+<span class="comment"># Response includes reward based on latency improvement</span>
+{
+  <span class="string">"reward"</span>: <span class="number">0.093</span>,
+  <span class="string">"observation"</span>: {
+    <span class="string">"best_so_far_ms"</span>: <span class="number">0.41</span>,
+    <span class="string">"oracle_best_ms"</span>: <span class="number">0.38</span>
+  }
+}</pre>
+    </div>
+    <div class="cta">
+      <h2>Try It Now</h2>
+      <p>Use the interactive Gradio interface above to benchmark kernel configurations and explore the RL environment.</p>
+      <a href="/" class="btn">Launch Interface →</a>
+    </div>
+    <footer>
+      <p>RL Surrogate Kernel Autotuning Environment</p>
+      <p>Built with FastAPI, Gradio, Triton, and Transformers</p>
+    </footer>
+  </div>
+</body>
+</html>
models.py DELETED
@@ -1,27 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any, Dict, List, Sequence
-
-
-@dataclass
-class Action:
-    config_id: int
-
-
-@dataclass
-class StepResult:
-    observation: Dict[str, Any]
-    reward: float
-    done: bool
-    state: Dict[str, Any]
-    info: Dict[str, Any]
-
-
-@dataclass
-class ResetResult:
-    observation: Dict[str, Any]
-    reward: float
-    done: bool
-    state: Dict[str, Any]
-    info: Dict[str, Any]
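The deleted models.py above defined the environment's transition containers. As a self-contained sketch, the `StepResult` shape is re-declared locally here and populated with the example /step values from the repo's documentation (the `done`/`state`/`info` values are illustrative assumptions):

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class StepResult:
    # Field names and types copied from the deleted models.py.
    observation: Dict[str, Any]
    reward: float
    done: bool
    state: Dict[str, Any]
    info: Dict[str, Any]


# Values taken from the documented /step example; done/state/info are placeholders.
result = StepResult(
    observation={"best_so_far_ms": 0.41, "oracle_best_ms": 0.38},
    reward=0.093,
    done=False,
    state={},
    info={},
)
```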
openenv.yaml DELETED
@@ -1,15 +0,0 @@
-name: softmax_surrogate_env
-version: 0.1.0
-description: >
-  OpenEnv scaffold for discrete-action multi-family GPU autotuning with real
-  H100 measurements and held-out-shape / held-out-family evaluation.
-python:
-  package: .
-server:
-  entrypoint: server.app:app
-environment:
-  class: server.softmax_surrogate_environment.SoftmaxSurrogateEnvironment
-  init_kwargs:
-    measurement_path: data/autotune_measurements.csv
-    budget: 6
-    seed: 0
requirements.txt DELETED
@@ -1,8 +0,0 @@
-torch>=2.4
-triton>=2.1
-numpy>=1.23
-fastapi>=0.111
-uvicorn>=0.30
-pydantic>=2.8
-requests>=2.32
-
scripts/__init__.py DELETED
@@ -1 +0,0 @@
-# Package marker for direct imports from the scripts folder.
scripts/benchmark_qwen_05b_runtime.py DELETED
@@ -1,158 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import sys
-import time
-from pathlib import Path
-from typing import Any, Dict, Tuple
-
-import numpy as np
-import torch
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from scripts.collect_qwen_05b_measurements import EPS, benchmark_qwen_task
-from scripts.qwen_05b_spec import QwenKernelTask, qwen_05b_tasks
-
-
-TASK_BY_ID = {task.task_id: task for task in qwen_05b_tasks()}
-
-
-def _bench_callable(fn, args: Tuple[Any, ...], repeats: int, warmup: int) -> float:
-    for _ in range(max(1, warmup)):
-        fn(*args)
-    torch.cuda.synchronize()
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-    durations = []
-    for _ in range(max(1, repeats)):
-        torch.cuda.synchronize()
-        start.record()
-        fn(*args)
-        end.record()
-        end.synchronize()
-        durations.append(start.elapsed_time(end))
-    return float(np.median(np.asarray(durations, dtype=np.float32)))
-
-
-def _build_qwen_callable(task: QwenKernelTask, seed: int):
-    torch.manual_seed(seed)
-    if task.family == "softmax":
-        x = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
-
-        def fn(inp: torch.Tensor):
-            return torch.softmax(inp, dim=-1)
-
-        return fn, (x,)
-
-    if task.family == "rmsnorm":
-        x = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
-
-        def fn(inp: torch.Tensor):
-            return inp.float() * torch.rsqrt(inp.float().pow(2).mean(dim=-1, keepdim=True) + EPS)
-
-        return fn, (x,)
-
-    if task.family == "gemm":
-        a = torch.randn((task.m, task.k), device="cuda", dtype=torch.float16)
-        b = torch.randn((task.k, task.n), device="cuda", dtype=torch.float16)
-
-        def fn(lhs: torch.Tensor, rhs: torch.Tensor):
-            return torch.matmul(lhs, rhs)
-
-        return fn, (a, b)
-
-    raise ValueError(f"Unsupported family: {task.family}")
-
-
-def _benchmark_torch(task: QwenKernelTask, seed: int, repeats: int, warmup: int) -> Dict[str, float]:
-    eager_fn, args = _build_qwen_callable(task, seed)
-    eager_latency_ms = _bench_callable(eager_fn, args, repeats=repeats, warmup=warmup)
-
-    compiled_fn = torch.compile(eager_fn)
-    torch.cuda.synchronize()
-    start = time.perf_counter()
-    compiled_fn(*args)
-    torch.cuda.synchronize()
-    compile_plus_first_call_ms = float((time.perf_counter() - start) * 1000.0)
-    compiled_latency_ms = _bench_callable(compiled_fn, args, repeats=repeats, warmup=warmup)
-    return {
-        "eager_latency_ms": eager_latency_ms,
-        "compile_plus_first_call_ms": compile_plus_first_call_ms,
-        "compiled_latency_ms": compiled_latency_ms,
-    }
-
-
-def _task_best_configs(eval_results: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, Any]]]:
-    task_map: Dict[str, Dict[str, Dict[str, Any]]] = {}
-    for section in eval_results["results"].values():
-        for method in ("random", "surrogate"):
-            for run in section["task_runs"][method]:
-                task_map.setdefault(run["task"], {})[method] = run["best_overall"]["config"]
-    return task_map
-
-
-def main() -> None:
-    parser = argparse.ArgumentParser(description="Benchmark eager/torch.compile and best Triton configs for Qwen2.5-0.5B exact kernels.")
-    parser.add_argument("--generalization-results", type=Path, default=Path("outputs/qwen_05b_generalization_eval.json"))
-    parser.add_argument("--repeats", type=int, default=100)
-    parser.add_argument("--warmup", type=int, default=10)
-    parser.add_argument("--seed", type=int, default=123)
-    parser.add_argument("--output", type=Path, default=Path("outputs/qwen_05b_runtime_references.json"))
-    args = parser.parse_args()
-
-    generalization_results = json.loads(args.generalization_results.read_text(encoding="utf-8"))
-    task_configs = _task_best_configs(generalization_results)
-
-    results = {}
-    for idx, task_id in enumerate(sorted(task_configs.keys())):
-        task = TASK_BY_ID[task_id]
-        seed = args.seed + idx
-        torch_metrics = _benchmark_torch(task, seed=seed, repeats=args.repeats, warmup=args.warmup)
-        triton_results = {
-            method: benchmark_qwen_task(
-                task=task,
-                block_size=int(config["block_size"]),
-                num_warps=int(config["num_warps"]),
-                num_stages=int(config["num_stages"]),
-                repeats=args.repeats,
-                warmup=args.warmup,
-                seed=seed,
-            ).__dict__
-            for method, config in task_configs[task_id].items()
-        }
-        results[task_id] = {
-            "family": task.family,
-            "role": task.role,
-            "mode": task.mode,
-            "torch": torch_metrics,
-            "triton": triton_results,
-            "speedups": {
-                method: {
-                    "vs_eager": float(torch_metrics["eager_latency_ms"] / row["median_ms"]),
-                    "vs_compiled": float(torch_metrics["compiled_latency_ms"] / row["median_ms"]),
-                }
-                for method, row in triton_results.items()
-            },
-        }
-
-    summary = {
-        "generalization_results": str(args.generalization_results),
-        "repeats": args.repeats,
-        "warmup": args.warmup,
-        "seed": args.seed,
-        "task_count": len(results),
-        "results": results,
-    }
-    args.output.parent.mkdir(parents=True, exist_ok=True)
-    with args.output.open("w", encoding="utf-8") as handle:
-        json.dump(summary, handle, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
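The deleted benchmark script times each kernel as a median over repeated launches, after a warmup pass, using CUDA events. The same median-of-repeats pattern can be sketched on CPU with `time.perf_counter` standing in for CUDA events (`bench_callable_cpu` is a hypothetical helper, not from the repo):

```python
import time
from statistics import median

def bench_callable_cpu(fn, args, repeats=5, warmup=1):
    # Warmup runs exclude one-time costs (allocations, caches, JIT).
    for _ in range(max(1, warmup)):
        fn(*args)
    durations_ms = []
    for _ in range(max(1, repeats)):
        start = time.perf_counter()
        fn(*args)
        durations_ms.append((time.perf_counter() - start) * 1000.0)
    # Median is robust to occasional slow outliers (scheduler noise, GC).
    return float(median(durations_ms))

latency_ms = bench_callable_cpu(lambda n: sum(range(n)), (10_000,))
```

On GPU the original script uses `torch.cuda.Event(enable_timing=True)` pairs instead, since kernel launches are asynchronous and wall-clock timing without synchronization would measure only launch overhead.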
scripts/benchmark_runtime_references.py DELETED
@@ -1,222 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import re
-import sys
-import time
-from pathlib import Path
-from typing import Any, Dict, Iterable, List, Tuple
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from scripts.collect_measurements import benchmark_single_config as benchmark_softmax_config
-from scripts.collect_multifamily_measurements import (
-    EPS,
-    GROUPED_GEMM_GROUP_COUNT,
-    GROUPED_GEMM_K,
-    benchmark_grouped_gemm_config,
-    benchmark_layernorm_config,
-)
-
-
-TASK_RE = re.compile(
-    r"^(?P<family>[a-z_]+?)(?:_g(?P<g>\d+)_k(?P<k>\d+))?_m(?P<m>\d+)_n(?P<n>\d+)$"
-)
-
-
-def _parse_task(task_id: str) -> Dict[str, int | str | None]:
-    match = TASK_RE.match(task_id)
-    if not match:
-        raise ValueError(f"Cannot parse task id: {task_id}")
-    data = match.groupdict()
-    return {
-        "family": data["family"],
-        "m": int(data["m"]),
-        "n": int(data["n"]),
-        "g": int(data["g"]) if data["g"] is not None else None,
-        "k": int(data["k"]) if data["k"] is not None else None,
-    }
-
-
-def _bench_callable(fn, args: Tuple[Any, ...], repeats: int, warmup: int) -> float:
-    for _ in range(max(1, warmup)):
-        fn(*args)
-    torch.cuda.synchronize()
-
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-    latencies_ms: List[float] = []
-    for _ in range(max(1, repeats)):
-        torch.cuda.synchronize()
-        start.record()
-        fn(*args)
-        end.record()
-        end.synchronize()
-        latencies_ms.append(start.elapsed_time(end))
-    return float(np.median(np.asarray(latencies_ms, dtype=np.float32)))
-
-
-def _build_family_callable(task_meta: Dict[str, Any], seed: int) -> Tuple[Any, Tuple[Any, ...]]:
-    family = str(task_meta["family"])
-    m = int(task_meta["m"])
-    n = int(task_meta["n"])
-    torch.manual_seed(seed)
-
-    if family == "softmax":
-        x = torch.randn((m, n), device="cuda", dtype=torch.float16)
-
-        def fn(inp: torch.Tensor) -> torch.Tensor:
-            return torch.softmax(inp, dim=-1)
-
-        return fn, (x,)
-
-    if family == "layernorm":
-        x = torch.randn((m, n), device="cuda", dtype=torch.float16)
-
-        def fn(inp: torch.Tensor) -> torch.Tensor:
-            return F.layer_norm(inp, (inp.shape[-1],), eps=EPS)
-
-        return fn, (x,)
-
-    if family == "grouped_gemm":
-        group_count = int(task_meta.get("g") or GROUPED_GEMM_GROUP_COUNT)
-        k_dim = int(task_meta.get("k") or GROUPED_GEMM_K)
-        group_m = max(64, m // group_count)
-        a_groups = [torch.randn((group_m, k_dim), device="cuda", dtype=torch.float16) for _ in range(group_count)]
-        b_groups = [torch.randn((k_dim, n), device="cuda", dtype=torch.float16) for _ in range(group_count)]
-
-        def fn(*inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-            mid = len(inputs) // 2
-            a_list = inputs[:mid]
-            b_list = inputs[mid:]
-            return tuple(torch.matmul(a, b) for a, b in zip(a_list, b_list))
-
-        return fn, tuple(a_groups + b_groups)
-
-    raise ValueError(f"Unsupported family: {family}")
-
-
-def _benchmark_torch_compile(task_meta: Dict[str, Any], seed: int, repeats: int, warmup: int) -> Dict[str, float]:
-    eager_fn, args = _build_family_callable(task_meta, seed=seed)
-    eager_latency_ms = _bench_callable(eager_fn, args, repeats=repeats, warmup=warmup)
-
-    compiled_fn = torch.compile(eager_fn)
-    torch.cuda.synchronize()
-    start = time.perf_counter()
-    compiled_fn(*args)
-    torch.cuda.synchronize()
-    compile_plus_first_call_ms = float((time.perf_counter() - start) * 1000.0)
-    compiled_latency_ms = _bench_callable(compiled_fn, args, repeats=repeats, warmup=warmup)
-    return {
-        "eager_latency_ms": eager_latency_ms,
-        "compile_plus_first_call_ms": compile_plus_first_call_ms,
-        "compiled_latency_ms": compiled_latency_ms,
-    }
-
-
-def _benchmark_triton_best(task_id: str, config: Dict[str, Any], repeats: int, warmup: int, seed: int) -> Dict[str, Any]:
-    task = _parse_task(task_id)
-    family = str(task["family"])
-    common = {
-        "n": int(task["n"]),
-        "block_size": int(config["block_size"]),
-        "num_warps": int(config["num_warps"]),
-        "num_stages": int(config["num_stages"]),
-        "m": int(task["m"]),
-        "repeats": repeats,
-        "warmup": warmup,
-        "seed": seed,
-    }
-    if family == "softmax":
-        row = benchmark_softmax_config(**common)
-        return row.__dict__
-    if family == "layernorm":
-        row = benchmark_layernorm_config(**common)
-        return row.__dict__
-    if family == "grouped_gemm":
-        row = benchmark_grouped_gemm_config(**common)
-        return row.__dict__
-    raise ValueError(f"Unsupported family: {family}")
-
-
-def _collect_task_best_configs(generalization_results: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, Any]]]:
-    task_map: Dict[str, Dict[str, Dict[str, Any]]] = {}
-    for section in generalization_results["results"].values():
-        for method in ("random", "surrogate"):
-            for run in section["task_runs"][method]:
-                task_map.setdefault(run["task"], {})[method] = run["best_overall"]["config"]
-    return task_map
-
-
-def main() -> None:
-    parser = argparse.ArgumentParser(description="Benchmark eager/torch.compile and live Triton configs on held-out tasks.")
-    parser.add_argument(
-        "--generalization-results",
-        type=Path,
-        default=Path("outputs/generalization_eval.json"),
-    )
-    parser.add_argument("--repeats", type=int, default=100)
-    parser.add_argument("--warmup", type=int, default=10)
-    parser.add_argument("--seed", type=int, default=123)
-    parser.add_argument(
-        "--output",
-        type=Path,
-        default=Path("outputs/runtime_references.json"),
-    )
-    args = parser.parse_args()
-
-    generalization_results = json.loads(args.generalization_results.read_text(encoding="utf-8"))
-    task_best_configs = _collect_task_best_configs(generalization_results)
-
-    results: Dict[str, Any] = {}
-    for idx, task_id in enumerate(sorted(task_best_configs.keys())):
-        task_seed = args.seed + idx
-        task_meta = _parse_task(task_id)
-        torch_metrics = _benchmark_torch_compile(task_meta, seed=task_seed, repeats=args.repeats, warmup=args.warmup)
-        method_results = {
-            method: _benchmark_triton_best(
-                task_id=task_id,
-                config=config,
-                repeats=args.repeats,
-                warmup=args.warmup,
-                seed=task_seed,
-            )
-            for method, config in task_best_configs[task_id].items()
-        }
-        results[task_id] = {
-            "family": task_meta["family"],
-            "torch": torch_metrics,
-            "triton": method_results,
-            "speedups": {
-                method: {
-                    "vs_eager": float(torch_metrics["eager_latency_ms"] / row["median_ms"]),
-                    "vs_compiled": float(torch_metrics["compiled_latency_ms"] / row["median_ms"]),
-                }
-                for method, row in method_results.items()
-            },
-        }
-
-    summary = {
-        "generalization_results": str(args.generalization_results),
-        "repeats": args.repeats,
-        "warmup": args.warmup,
-        "seed": args.seed,
-        "task_count": len(results),
-        "results": results,
-    }
-    args.output.parent.mkdir(parents=True, exist_ok=True)
-    with args.output.open("w", encoding="utf-8") as handle:
-        json.dump(summary, handle, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
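The deleted `_parse_task` recovers a task's kernel family and shape from a structured id via a single regex with an optional grouped-GEMM segment. The pattern can be exercised standalone (same regex as the deleted `TASK_RE`; the example task ids are illustrative):

```python
import re

# Task id grammar: <family>[_g<groups>_k<K>]_m<M>_n<N>
TASK_RE = re.compile(
    r"^(?P<family>[a-z_]+?)(?:_g(?P<g>\d+)_k(?P<k>\d+))?_m(?P<m>\d+)_n(?P<n>\d+)$"
)

def parse_task(task_id):
    match = TASK_RE.match(task_id)
    if not match:
        raise ValueError(f"Cannot parse task id: {task_id}")
    d = match.groupdict()
    return {
        "family": d["family"],
        "m": int(d["m"]),
        "n": int(d["n"]),
        # g/k only appear for grouped GEMM tasks.
        "g": int(d["g"]) if d["g"] is not None else None,
        "k": int(d["k"]) if d["k"] is not None else None,
    }

softmax = parse_task("softmax_m4096_n8192")
grouped = parse_task("grouped_gemm_g8_k1024_m4096_n4096")
```

The lazy `[a-z_]+?` family group is what lets multi-word families like `grouped_gemm` parse correctly without swallowing the `_g8_k1024` segment.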
scripts/benchmark_surrogate_vs_triton_autotune.py DELETED
@@ -1,361 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import sys
-import time
-from pathlib import Path
-from typing import Any, Dict, List
-
-import numpy as np
-import torch
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-import triton
-import triton.language as tl
-import triton.testing as ttesting
-
-from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS, fused_rowwise_softmax_kernel
-from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
-
-
-DEFAULT_TASKS = [
-    "softmax_m4096_n4096",
-    "softmax_m4096_n6144",
-    "softmax_m4096_n8192",
-]
-
-
-def _choose_surrogate_config_id(env: SoftmaxSurrogateEnvironment, acquisition: str, beta: float, xi: float) -> int:
-    seen = set(env.seen_config_ids())
-    best_config_id = -1
-    best_score = float("-inf")
-    for config_id in env.available_config_ids():
-        if config_id in seen and len(seen) < len(env.available_config_ids()):
-            continue
-        score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
-        if score > best_score:
-            best_score = score
-            best_config_id = config_id
-    if best_config_id < 0:
-        raise RuntimeError("Failed to select surrogate config.")
-    return best_config_id
-
-
-def _valid_autotune_configs(n_cols: int) -> List[triton.Config]:
-    configs = []
-    for block_size in BLOCK_SIZES:
-        if block_size < n_cols:
-            continue
-        for num_warps in NUM_WARPS:
-            for num_stages in NUM_STAGES:
-                configs.append(
-                    triton.Config(
-                        {"BLOCK_SIZE": block_size},
-                        num_warps=num_warps,
-                        num_stages=num_stages,
-                    )
-                )
-    return configs
-
-
-def _compile_plus_first_call_fixed(sample: torch.Tensor, config: Dict[str, int]) -> float:
-    output = torch.empty_like(sample)
-    grid = (sample.shape[0],)
-    torch.cuda.synchronize()
-    start = time.perf_counter()
-    fused_rowwise_softmax_kernel[grid](
-        sample,
-        output,
-        sample.stride(0),
-        sample.stride(1),
-        output.stride(0),
-        output.stride(1),
-        sample.shape[1],
-        BLOCK_SIZE=int(config["block_size"]),
-        num_warps=int(config["num_warps"]),
-        num_stages=int(config["num_stages"]),
-    )
-    torch.cuda.synchronize()
-    return float((time.perf_counter() - start) * 1000.0)
-
-
-def _steady_state_fixed(sample: torch.Tensor, config: Dict[str, int], repeats: int, warmup: int) -> float:
-    output = torch.empty_like(sample)
-    grid = (sample.shape[0],)
-
-    def launch() -> None:
-        fused_rowwise_softmax_kernel[grid](
-            sample,
-            output,
-            sample.stride(0),
-            sample.stride(1),
-            output.stride(0),
-            output.stride(1),
-            sample.shape[1],
-            BLOCK_SIZE=int(config["block_size"]),
-            num_warps=int(config["num_warps"]),
-            num_stages=int(config["num_stages"]),
-        )
-
-    return float(
-        ttesting.do_bench(
-            launch,
-            warmup=warmup,
-            rep=repeats,
-            quantiles=[0.5],
-            return_mode="median",
-        )
-    )
-
-
-def _benchmark_triton_autotune(sample: torch.Tensor, repeats: int, warmup: int) -> Dict[str, float]:
-    output = torch.empty_like(sample)
-    n_cols = sample.shape[1]
-    configs = _valid_autotune_configs(n_cols)
-
-    @triton.autotune(configs=configs, key=["n_cols"])
-    @triton.jit
-    def autotuned_softmax_kernel(
-        X_ptr,
-        Y_ptr,
-        stride_xm,
-        stride_xn,
-        stride_ym,
-        stride_yn,
-        n_cols,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        row_idx = tl.program_id(0)
-        col_offsets = tl.arange(0, BLOCK_SIZE)
-        x_ptr = X_ptr + row_idx * stride_xm + col_offsets
-        y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
-        mask = col_offsets < n_cols
-
-        x = tl.load(x_ptr, mask=mask, other=-float("inf"))
-        x = x - tl.max(x, axis=0)
-        numerator = tl.exp(x)
-        denominator = tl.sum(numerator, axis=0)
-        y = numerator / denominator
-        tl.store(y_ptr, y, mask=mask)
-
-    grid = (sample.shape[0],)
-    torch.cuda.synchronize()
-    start = time.perf_counter()
-    autotuned_softmax_kernel[grid](
-        sample,
-        output,
-        sample.stride(0),
-        sample.stride(1),
-        output.stride(0),
-        output.stride(1),
-        n_cols,
-    )
-    torch.cuda.synchronize()
-    first_call_ms = float((time.perf_counter() - start) * 1000.0)
-
-    def launch() -> None:
-        autotuned_softmax_kernel[grid](
-            sample,
-            output,
-            sample.stride(0),
-            sample.stride(1),
-            output.stride(0),
-            output.stride(1),
-            n_cols,
-        )
-
-    steady_ms = float(
-        ttesting.do_bench(
-            launch,
-            warmup=warmup,
-            rep=repeats,
-            quantiles=[0.5],
-            return_mode="median",
-        )
-    )
-    return {
-        "autotune_first_call_ms": first_call_ms,
-        "autotune_steady_ms": steady_ms,
-    }
-
-
-def _build_summary(
-    measurement_path: str,
-    tasks: List[str],
-    acquisition: str,
-    beta: float,
-    xi: float,
-    results: Dict[str, Any],
-) -> Dict[str, Any]:
-    if results:
-        surrogate_first = [task["surrogate"]["compile_plus_first_call_ms"] for task in results.values()]
-        surrogate_steady = [task["surrogate"]["steady_ms"] for task in results.values()]
-        autotune_first = [task["triton_autotune"]["autotune_first_call_ms"] for task in results.values()]
-        autotune_steady = [task["triton_autotune"]["autotune_steady_ms"] for task in results.values()]
-        surrogate_speedup = [
-            task["triton_autotune"]["autotune_first_call_ms"] / max(task["surrogate"]["compile_plus_first_call_ms"], 1e-9)
-            for task in results.values()
-        ]
-        summary = {
-            "mean_surrogate_compile_plus_first_call_ms": float(np.mean(surrogate_first)),
-            "mean_surrogate_steady_ms": float(np.mean(surrogate_steady)),
-            "mean_autotune_first_call_ms": float(np.mean(autotune_first)),
-            "mean_autotune_steady_ms": float(np.mean(autotune_steady)),
-            "mean_search_time_speedup_surrogate_vs_autotune": float(np.mean(surrogate_speedup)),
-            "completed_task_count": len(results),
-        }
-    else:
-        summary = {
-            "mean_surrogate_compile_plus_first_call_ms": None,
-            "mean_surrogate_steady_ms": None,
-            "mean_autotune_first_call_ms": None,
-            "mean_autotune_steady_ms": None,
-            "mean_search_time_speedup_surrogate_vs_autotune": None,
-            "completed_task_count": 0,
-        }
-    return {
-        "measurement_path": measurement_path,
-        "tasks": tasks,
-        "acquisition": acquisition,
-        "beta": beta,
-        "xi": xi,
-        "results": results,
-        "summary": summary,
-    }
-
-
-def run_benchmark(
-    measurement_path: str,
-    tasks: List[str],
-    repeats: int,
-    warmup: int,
-    seed: int,
-    acquisition: str,
-    beta: float,
-    xi: float,
-    output_path: Path,
-    resume: bool,
-) -> Dict[str, Any]:
-    env_probe = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=1, seed=seed)
-    available_tasks = [task for task in env_probe.available_tasks() if task.startswith("softmax_m4096_n")]
-    train_task_ids = [task for task in available_tasks if task not in tasks]
-
-    results: Dict[str, Any] = {}
-    if resume and output_path.exists():
-        try:
-            existing = json.loads(output_path.read_text(encoding="utf-8"))
-            results = dict(existing.get("results", {}))
-        except Exception:
-            results = {}
-
-    for index, task in enumerate(tasks):
-        if task in results:
-            print(f"[progress] skipping completed task {index + 1}/{len(tasks)}: {task}")
-            continue
-        task_seed = seed + index
-        print(f"[progress] starting task {index + 1}/{len(tasks)}: {task}")
-        env = SoftmaxSurrogateEnvironment(
-            measurement_path=measurement_path,
-            budget=6,
-            seed=task_seed,
-            train_task_ids=train_task_ids,
-        )
-        reset_out = env.reset(task=task, seed=task_seed)
-        decision_start = time.perf_counter()
-        surrogate_config_id = _choose_surrogate_config_id(env, acquisition=acquisition, beta=beta, xi=xi)
-        decision_ms = float((time.perf_counter() - decision_start) * 1000.0)
-        surrogate_config = env.config_info(surrogate_config_id)
-        sample = torch.randn((env._task_rows[0].m, env._task_rows[0].n), device="cuda", dtype=torch.float16)
-
-        surrogate_first_call_ms = _compile_plus_first_call_fixed(sample, surrogate_config)
-        surrogate_steady_ms = _steady_state_fixed(sample, surrogate_config, repeats=repeats, warmup=warmup)
-        autotune_metrics = _benchmark_triton_autotune(sample, repeats=repeats, warmup=warmup)
-        oracle_best = env.oracle_best()
-
-        results[task] = {
-            "seeded_config_ids": reset_out["observation"]["tried_config_ids"],
-            "train_task_count": len(train_task_ids),
-            "oracle_best_ms": oracle_best["median_ms"],
-            "surrogate": {
-                "config": surrogate_config,
-                "decision_ms": decision_ms,
-                "compile_plus_first_call_ms": surrogate_first_call_ms,
-                "steady_ms": surrogate_steady_ms,
-                "regret_vs_oracle": float(surrogate_steady_ms / oracle_best["median_ms"] - 1.0),
-            },
-            "triton_autotune": {
-                **autotune_metrics,
-                "regret_vs_oracle": float(autotune_metrics["autotune_steady_ms"] / oracle_best["median_ms"] - 1.0),
-            },
-        }
-        print(
-            "[progress] finished"
-            f" task={task}"
-            f" surrogate_first_ms={results[task]['surrogate']['compile_plus_first_call_ms']:.3f}"
-            f" surrogate_steady_ms={results[task]['surrogate']['steady_ms']:.3f}"
-            f" autotune_first_ms={results[task]['triton_autotune']['autotune_first_call_ms']:.3f}"
-            f" autotune_steady_ms={results[task]['triton_autotune']['autotune_steady_ms']:.3f}"
-        )
-        snapshot = _build_summary(
-            measurement_path=measurement_path,
-            tasks=tasks,
-            acquisition=acquisition,
-            beta=beta,
-            xi=xi,
-            results=results,
-        )
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        output_path.write_text(json.dumps(snapshot, indent=2), encoding="utf-8")
-
-    return _build_summary(
-        measurement_path=measurement_path,
-        tasks=tasks,
-        acquisition=acquisition,
-        beta=beta,
-        xi=xi,
-        results=results,
-    )
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Quick benchmark: surrogate-guided softmax config choice vs Triton autotune.")
-    parser.add_argument("--measurement-path", default="data/autotune_measurements.csv")
-    parser.add_argument("--tasks", nargs="+", default=DEFAULT_TASKS)
-    parser.add_argument("--repeats", type=int, default=50)
-    parser.add_argument("--warmup", type=int, default=10)
-    parser.add_argument("--seed", type=int, default=7)
-    parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
-    parser.add_argument("--beta", type=float, default=2.0)
-    parser.add_argument("--xi", type=float, default=0.0)
-    parser.add_argument("--output", type=Path, default=Path("outputs/surrogate_vs_triton_autotune.json"))
-    parser.add_argument("--no-resume", action="store_true")
-    return parser.parse_args()
-
-
-def main() -> None:
-    args = parse_args()
-    summary = run_benchmark(
-        measurement_path=args.measurement_path,
-        tasks=args.tasks,
-        repeats=args.repeats,
-        warmup=args.warmup,
-        seed=args.seed,
-        acquisition=args.acquisition,
-        beta=args.beta,
-        xi=args.xi,
-        output_path=args.output,
-        resume=not args.no_resume,
-    )
-    args.output.parent.mkdir(parents=True, exist_ok=True)
-    with args.output.open("w", encoding="utf-8") as handle:
-        json.dump(summary, handle, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
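The deleted `_choose_surrogate_config_id` is a greedy argmax over acquisition scores that skips already-tried configs while untried ones remain. A self-contained sketch of the same selection loop, with a hypothetical mean/std table standing in for the environment's surrogate predictions and a UCB acquisition (the `stats` values are made up for illustration):

```python
import math

def ucb_score(mean_reward, std, beta=2.0):
    # Upper confidence bound: favor configs that are either predicted good
    # (high mean) or poorly explored (high std).
    return mean_reward + beta * std

def choose_config(stats, seen, beta=2.0):
    # stats: config_id -> (predicted mean reward, predictive std)
    best_id, best_score = None, -math.inf
    for config_id, (mean_reward, std) in stats.items():
        # Skip already-tried configs while untried ones remain, mirroring
        # the deleted loop's seen-set check.
        if config_id in seen and len(seen) < len(stats):
            continue
        score = ucb_score(mean_reward, std, beta)
        if score > best_score:
            best_id, best_score = config_id, score
    return best_id

stats = {0: (1.0, 0.1), 1: (0.8, 0.5), 2: (0.9, 0.05)}
pick = choose_config(stats, seen={0})  # config 1: 0.8 + 2*0.5 = 1.8 beats 0.9 + 2*0.05 = 1.0
```

The environment's `acquisition_score` also supported `mean` and `ei` strategies; UCB with `beta=2.0` matches the script's defaults.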
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/build_benchmark_splits.py DELETED
@@ -1,81 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def _load_tasks(path: Path) -> Dict[str, List[dict]]:
    grouped: Dict[str, List[dict]] = defaultdict(list)
    with path.open("r", newline="", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            grouped[row["task_id"]].append(row)
    return grouped


def _task_n(task_id: str) -> int:
    return int(task_id.split("_n")[-1])


def build_splits(grouped: Dict[str, List[dict]], heldout_family: str | None) -> Dict[str, object]:
    tasks_by_family: Dict[str, List[str]] = defaultdict(list)
    for task_id, rows in grouped.items():
        tasks_by_family[rows[0]["family"]].append(task_id)

    for family in tasks_by_family:
        tasks_by_family[family].sort(key=_task_n)

    families = sorted(tasks_by_family.keys())
    if not families:
        raise RuntimeError("No tasks found in measurement file.")

    shape_train: List[str] = []
    shape_test: List[str] = []
    for family, tasks in tasks_by_family.items():
        holdout_count = 2 if len(tasks) >= 4 else 1
        split_idx = max(1, len(tasks) - holdout_count)
        shape_train.extend(tasks[:split_idx])
        shape_test.extend(tasks[split_idx:])

    if heldout_family is None:
        heldout_family = families[-1]
    if heldout_family not in tasks_by_family:
        raise ValueError(f"Held-out family {heldout_family} is not present.")

    family_train = [task_id for family, tasks in tasks_by_family.items() if family != heldout_family for task_id in tasks]
    family_test = list(tasks_by_family[heldout_family])

    return {
        "families_present": families,
        "shape_generalization": {
            "train_tasks": sorted(shape_train),
            "test_tasks": sorted(shape_test),
        },
        "family_holdout": {
            "heldout_family": heldout_family,
            "train_tasks": sorted(family_train),
            "test_tasks": sorted(family_test),
        },
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Build train/test split manifests for the multi-family benchmark.")
    parser.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
    parser.add_argument("--output", type=Path, default=Path("data/benchmark_splits.json"))
    parser.add_argument("--heldout-family", type=str, default=None)
    args = parser.parse_args()

    splits = build_splits(_load_tasks(args.measurement_path), args.heldout_family)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as handle:
        json.dump(splits, handle, indent=2)
    print(json.dumps(splits, indent=2))


if __name__ == "__main__":
    main()
```
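The shape-generalization rule in `build_splits` (hold out the largest one or two shapes per family, always keeping at least one training task) can be exercised in isolation; a minimal sketch of just that rule:

```python
def shape_split(tasks: list[str]) -> tuple[list[str], list[str]]:
    # Mirror of the per-family rule: hold out the two largest shapes when a
    # family has four or more tasks, otherwise one, and never leave a family
    # without at least one training task.
    holdout_count = 2 if len(tasks) >= 4 else 1
    split_idx = max(1, len(tasks) - holdout_count)
    return tasks[:split_idx], tasks[split_idx:]


# Tasks must already be sorted by inner dimension, as build_splits does via _task_n.
train, test = shape_split(["softmax_n256", "softmax_n512", "softmax_n1024", "softmax_n2048"])
```

Because the tasks are sorted by `n` first, the held-out tasks are always the largest shapes, which is what makes this a shape-extrapolation test rather than a random split.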
scripts/build_qwen_05b_splits.py DELETED
@@ -1,56 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path

from scripts.qwen_05b_spec import DECODE_CTX_LENS, PREFILL_SEQ_LENS, qwen_05b_tasks


def build_splits() -> dict:
    tasks = qwen_05b_tasks()
    long_prefill = max(PREFILL_SEQ_LENS)
    long_decode = max(DECODE_CTX_LENS)

    shape_train = []
    shape_test = []
    for task in tasks:
        if task.mode == "prefill" and task.seq_len == long_prefill:
            shape_test.append(task.task_id)
        elif task.mode == "decode" and task.ctx_len == long_decode:
            shape_test.append(task.task_id)
        else:
            shape_train.append(task.task_id)

    family_holdout_train = [task.task_id for task in tasks if task.family != "gemm"]
    family_holdout_test = [task.task_id for task in tasks if task.family == "gemm"]

    return {
        "model_id": "Qwen/Qwen2.5-0.5B",
        "shape_generalization": {
            "train_tasks": sorted(shape_train),
            "test_tasks": sorted(shape_test),
        },
        "family_holdout": {
            "heldout_family": "gemm",
            "train_tasks": sorted(family_holdout_train),
            "test_tasks": sorted(family_holdout_test),
        },
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Build train/test split manifests for Qwen2.5-0.5B kernel tuning.")
    parser.add_argument("--output", type=Path, default=Path("data/qwen_05b_splits.json"))
    args = parser.parse_args()

    splits = build_splits()
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as handle:
        json.dump(splits, handle, indent=2)
    print(json.dumps(splits, indent=2))


if __name__ == "__main__":
    main()
```
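The split rule routes the longest prefill and longest decode shapes to the test set. A self-contained sketch with a stand-in `Task` record (the field names mirror the attributes used in `build_splits`; the real records come from `scripts.qwen_05b_spec`, and the example task ids are hypothetical):

```python
from collections import namedtuple

# Hypothetical stand-in for the real task records.
Task = namedtuple("Task", "task_id family mode seq_len ctx_len")


def split_by_longest(tasks, long_prefill, long_decode):
    # Longest prefill and decode shapes go to the test split; everything else trains.
    test = [t.task_id for t in tasks
            if (t.mode == "prefill" and t.seq_len == long_prefill)
            or (t.mode == "decode" and t.ctx_len == long_decode)]
    train = [t.task_id for t in tasks if t.task_id not in test]
    return sorted(train), sorted(test)


tasks = [
    Task("attn_prefill_s512", "attn", "prefill", 512, 0),
    Task("attn_prefill_s4096", "attn", "prefill", 4096, 0),
    Task("attn_decode_c2048", "attn", "decode", 1, 2048),
]
train, test = split_by_longest(tasks, long_prefill=4096, long_decode=2048)
```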
scripts/check_torch_cuda.py DELETED
@@ -1,18 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import sys

import torch


def main() -> None:
    print(f"python: {sys.executable}")
    print(f"torch: {torch.__version__}")
    print(f"cuda_available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"cuda_device_name: {torch.cuda.get_device_name(0)}")
        print(f"cuda_capability: {torch.cuda.get_device_capability(0)}")


if __name__ == "__main__":
    main()
```
scripts/collect_measurements.py DELETED
@@ -1,398 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import math
import statistics
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Dict, Optional

import numpy as np
import torch

try:
    import triton
    import triton.language as tl
    import triton.testing as ttesting
except Exception as err:  # pragma: no cover
    raise RuntimeError(
        "Triton is required for this measurement collector. Install Triton and rerun."
    ) from err


BLOCK_SIZES = (256, 512, 1024, 2048, 4096, 8192)
NUM_WARPS = (1, 2, 4, 8)
NUM_STAGES = (1, 2, 3, 4)
N_VALUES = (256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192)


@dataclass(frozen=True)
class MeasurementRow:
    task_id: str
    m: int
    n: int
    block_size: int
    num_warps: int
    num_stages: int
    median_ms: float
    effective_gbps: float
    score: float
    validation_error: float


@triton.jit
def fused_rowwise_softmax_kernel(
    X_ptr,
    Y_ptr,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    x_ptr = X_ptr + row_idx * stride_xm + col_offsets
    y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
    mask = col_offsets < n_cols

    x = tl.load(x_ptr, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)
    numerator = tl.exp(x)
    denominator = tl.sum(numerator, axis=0)
    y = numerator / denominator
    tl.store(y_ptr, y, mask=mask)


def _task_id(m: int, n: int) -> str:
    return f"softmax_m{m}_n{n}"


def _effective_gbps(sample: torch.Tensor, median_ms: float) -> float:
    bytes_processed = float(sample.numel() * sample.element_size() * 2)
    if median_ms <= 0:
        return 0.0
    return bytes_processed / (median_ms / 1000.0) / 1e9


def _benchmark_config(
    sample: torch.Tensor,
    block_size: int,
    num_warps: int,
    num_stages: int,
    repeats: int,
    warmup: int,
) -> float:
    output = torch.empty_like(sample)
    m, n = sample.shape
    grid = (m,)

    def launch() -> None:
        fused_rowwise_softmax_kernel[grid](
            sample,
            output,
            sample.stride(0),
            sample.stride(1),
            output.stride(0),
            output.stride(1),
            n,
            BLOCK_SIZE=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    # warmup to compile and stabilize caches / clocks.
    for _ in range(max(1, warmup)):
        launch()
    torch.cuda.synchronize()

    if ttesting is not None:
        result = ttesting.do_bench(
            launch,
            warmup=0,
            rep=repeats,
            quantiles=[0.5],
            return_mode="median",
        )
        if isinstance(result, dict):
            return float(result.get("median", 0.0))
        return float(result)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    durations_ms: List[float] = []
    for _ in range(max(1, repeats)):
        torch.cuda.synchronize()
        start.record()
        launch()
        end.record()
        end.synchronize()
        durations_ms.append(start.elapsed_time(end))
    return float(statistics.median(durations_ms))


def _validate_correctness(sample: torch.Tensor, block_size: int, num_warps: int, num_stages: int) -> float:
    ref = torch.softmax(sample.float(), dim=-1).to(sample.dtype)
    out = torch.empty_like(sample)

    fused_rowwise_softmax_kernel[(sample.shape[0],)](
        sample,
        out,
        sample.stride(0),
        sample.stride(1),
        out.stride(0),
        out.stride(1),
        sample.shape[1],
        BLOCK_SIZE=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    torch.cuda.synchronize()

    err = (out.to(torch.float32) - ref).abs().max().item()
    return float(err)


def benchmark_single_config(
    n: int,
    block_size: int,
    num_warps: int,
    num_stages: int,
    m: int = 4096,
    repeats: int = 200,
    warmup: int = 25,
    seed: int = 0,
) -> MeasurementRow:
    if block_size < n:
        raise ValueError(f"Invalid config: BLOCK_SIZE {block_size} < N {n}")

    torch.manual_seed(seed)
    sample = torch.randn((m, n), device="cuda", dtype=torch.float16)

    if num_warps not in NUM_WARPS:
        raise ValueError(f"Unsupported num_warps={num_warps}")
    if num_stages not in NUM_STAGES:
        raise ValueError(f"Unsupported num_stages={num_stages}")
    if block_size not in BLOCK_SIZES:
        raise ValueError(f"Unsupported BLOCK_SIZE={block_size}")

    median_ms = _benchmark_config(
        sample=sample,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
        repeats=repeats,
        warmup=warmup,
    )

    val_err = _validate_correctness(
        sample=sample,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    gbps = _effective_gbps(sample, median_ms)
    score = -math.log(max(median_ms, np.finfo(float).tiny))
    return MeasurementRow(
        task_id=_task_id(m, n),
        m=m,
        n=n,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
        median_ms=float(median_ms),
        effective_gbps=float(gbps),
        score=float(score),
        validation_error=float(val_err),
    )


def collect_measurements(
    output_path: Path,
    n_values: Iterable[int],
    repeats: int,
    warmup: int,
    seed: int,
    m: int = 4096,
    append: bool = True,
) -> List[MeasurementRow]:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run measurements on H100.")
    if not torch.cuda.get_device_name(0):
        raise RuntimeError("No CUDA device found.")

    output_path.parent.mkdir(parents=True, exist_ok=True)

    existing: Dict[str, MeasurementRow] = {}
    if output_path.exists():
        with output_path.open("r", newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                key = _row_key(
                    int(row["m"]),
                    int(row["n"]),
                    int(row["block_size"]),
                    int(row["num_warps"]),
                    int(row["num_stages"]),
                )
                existing[key] = MeasurementRow(
                    task_id=row["task_id"],
                    m=int(row["m"]),
                    n=int(row["n"]),
                    block_size=int(row["block_size"]),
                    num_warps=int(row["num_warps"]),
                    num_stages=int(row["num_stages"]),
                    median_ms=float(row["median_ms"]),
                    effective_gbps=float(row["effective_gbps"]),
                    score=float(row["score"]),
                    validation_error=float(row["validation_error"]),
                )

    results: List[MeasurementRow] = []
    for n in n_values:
        if n < 0:
            raise ValueError(f"Invalid n value: {n}")
        for block in BLOCK_SIZES:
            if block < n:
                continue
            for num_warps in NUM_WARPS:
                for num_stages in NUM_STAGES:
                    key = _row_key(m, n, block, num_warps, num_stages)
                    if append and key in existing:
                        continue

                    torch.cuda.synchronize()
                    row = benchmark_single_config(
                        n=n,
                        block_size=block,
                        num_warps=num_warps,
                        num_stages=num_stages,
                        m=m,
                        repeats=repeats,
                        warmup=warmup,
                        seed=seed,
                    )
                    results.append(row)

    if results:
        write_mode = "a" if output_path.exists() and append else "w"
        with output_path.open(write_mode, newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if write_mode == "w" or (output_path.stat().st_size == 0):
                writer.writerow(
                    [
                        "task_id",
                        "m",
                        "n",
                        "block_size",
                        "num_warps",
                        "num_stages",
                        "median_ms",
                        "effective_gbps",
                        "score",
                        "validation_error",
                    ]
                )
            for r in results:
                writer.writerow(
                    [
                        r.task_id,
                        r.m,
                        r.n,
                        r.block_size,
                        r.num_warps,
                        r.num_stages,
                        r.median_ms,
                        r.effective_gbps,
                        r.score,
                        r.validation_error,
                    ]
                )
    return results


def _row_key(m: int, n: int, block_size: int, num_warps: int, num_stages: int) -> str:
    return f"{m}|{n}|{block_size}|{num_warps}|{num_stages}"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Collect real H100 measurements for Triton row-wise fused softmax."
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/softmax_measurements.csv"),
        help="CSV output path.",
    )
    parser.add_argument(
        "--n-cols",
        type=int,
        nargs="+",
        default=N_VALUES,
        help="Softmax inner dimension N values to benchmark.",
    )
    parser.add_argument("--m", type=int, default=4096, help="Outer dimension M.")
    parser.add_argument("--repeats", type=int, default=200, help="Benchmark repeats.")
    parser.add_argument("--warmup", type=int, default=25, help="Benchmark warmup runs.")
    parser.add_argument("--seed", type=int, default=0, help="Torch/random seed.")
    parser.add_argument(
        "--single-run",
        action="store_true",
        help="Run one specific config and print JSON-like output.",
    )
    parser.add_argument("--block-size", type=int, default=1024)
    parser.add_argument("--num-warps", type=int, default=4)
    parser.add_argument("--num-stages", type=int, default=2)
    parser.add_argument(
        "--append",
        action="store_true",
        help="Append to existing CSV file (default). If false, overwrite.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing CSV data instead of appending.",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")
    if not torch.cuda.get_device_name(0):
        raise RuntimeError("No CUDA device was found.")

    append = args.append and not args.overwrite
    if args.single_run:
        row = benchmark_single_config(
            n=args.n_cols[0],
            block_size=args.block_size,
            num_warps=args.num_warps,
            num_stages=args.num_stages,
            m=args.m,
            repeats=args.repeats,
            warmup=args.warmup,
            seed=args.seed,
        )
        print(row.__dict__)
        return

    collect_measurements(
        output_path=args.output,
        n_values=args.n_cols,
        repeats=args.repeats,
        warmup=args.warmup,
        seed=args.seed,
        m=args.m,
        append=append,
    )


if __name__ == "__main__":
    main()
```
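The `effective_gbps` column in the CSV follows the arithmetic in `_effective_gbps`: one read plus one write of the tensor (hence the factor of 2), divided by the median kernel time. A standalone sketch of just that conversion:

```python
def effective_gbps(numel: int, element_size: int, median_ms: float) -> float:
    # Bytes moved: one read plus one write of the whole tensor.
    bytes_processed = float(numel * element_size * 2)
    if median_ms <= 0:
        return 0.0
    # Convert milliseconds to seconds, then bytes/second to GB/s.
    return bytes_processed / (median_ms / 1000.0) / 1e9


# A 4096 x 4096 fp16 tensor (2 bytes/element) moved in 0.1 ms.
gbps = effective_gbps(4096 * 4096, 2, 0.1)
```

This counts only the minimum traffic a row-wise softmax must perform, so it is an effective (achieved) bandwidth figure, not a hardware counter.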
scripts/collect_multifamily_measurements.py DELETED
@@ -1,659 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import torch

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

try:
    import triton
    import triton.language as tl
    import triton.testing as ttesting
except Exception as err:  # pragma: no cover
    raise RuntimeError(
        "Triton is required for the multi-family measurement collector."
    ) from err

from scripts.collect_measurements import (
    BLOCK_SIZES,
    NUM_STAGES,
    NUM_WARPS,
    N_VALUES,
    benchmark_single_config as benchmark_softmax_config,
)


EPS = 1e-5
GROUPED_GEMM_GROUP_COUNT = 4
GROUPED_GEMM_K = 512


@dataclass(frozen=True)
class MultiFamilyMeasurementRow:
    family_group: str
    family: str
    task_id: str
    m: int
    n: int
    config_id: int
    block_size: int
    num_warps: int
    num_stages: int
    shape_json: str
    config_json: str
    median_ms: float
    effective_gbps: float
    score: float
    validation_error: float


@triton.jit
def fused_rowwise_layernorm_kernel(
    X_ptr,
    Y_ptr,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    x_ptr = X_ptr + row_idx * stride_xm + col_offsets
    y_ptr = Y_ptr + row_idx * stride_ym + col_offsets
    mask = col_offsets < n_cols

    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
    mean = tl.sum(x, axis=0) / n_cols
    x_centered = x - mean
    var = tl.sum(x_centered * x_centered, axis=0) / n_cols
    inv_std = tl.rsqrt(var + eps)
    y = x_centered * inv_std
    tl.store(y_ptr, y.to(tl.float16), mask=mask)


@triton.jit
def matmul_kernel(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c = acc.to(tl.float16)
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def _task_id(family: str, m: int, n: int) -> str:
    if family == "grouped_gemm":
        return f"{family}_g{GROUPED_GEMM_GROUP_COUNT}_k{GROUPED_GEMM_K}_m{m}_n{n}"
    return f"{family}_m{m}_n{n}"


def _valid_configs(family: str, n: int) -> List[Tuple[int, int, int]]:
    configs: List[Tuple[int, int, int]] = []
    if family == "grouped_gemm":
        candidate_blocks = tuple(block for block in BLOCK_SIZES if block <= 1024)
    else:
        candidate_blocks = BLOCK_SIZES
    for block_size in candidate_blocks:
        if family != "grouped_gemm" and block_size < n:
            continue
        for num_warps in NUM_WARPS:
            for num_stages in NUM_STAGES:
                configs.append((block_size, num_warps, num_stages))
    return configs


def _effective_gbps(sample: torch.Tensor, median_ms: float) -> float:
    bytes_processed = float(sample.numel() * sample.element_size() * 2)
    if median_ms <= 0:
        return 0.0
    return bytes_processed / (median_ms / 1000.0) / 1e9


def _benchmark_layernorm_config(
    sample: torch.Tensor,
    block_size: int,
    num_warps: int,
    num_stages: int,
    repeats: int,
    warmup: int,
) -> float:
    output = torch.empty_like(sample)
    grid = (sample.shape[0],)

    def launch() -> None:
        fused_rowwise_layernorm_kernel[grid](
            sample,
            output,
            sample.stride(0),
            sample.stride(1),
            output.stride(0),
            output.stride(1),
            sample.shape[1],
            EPS,
            BLOCK_SIZE=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    for _ in range(max(1, warmup)):
        launch()
    torch.cuda.synchronize()

    if ttesting is not None:
        result = ttesting.do_bench(
            launch,
            warmup=0,
            rep=repeats,
            quantiles=[0.5],
            return_mode="median",
        )
        if isinstance(result, dict):
            return float(result.get("median", 0.0))
        return float(result)

    events: List[float] = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(max(1, repeats)):
        torch.cuda.synchronize()
        start.record()
        launch()
        end.record()
        end.synchronize()
        events.append(start.elapsed_time(end))
    return float(np.median(np.asarray(events, dtype=np.float32)))


def _validate_layernorm(sample: torch.Tensor, block_size: int, num_warps: int, num_stages: int) -> float:
    ref = torch.nn.functional.layer_norm(sample.float(), (sample.shape[1],), eps=EPS).to(sample.dtype)
    out = torch.empty_like(sample)
    fused_rowwise_layernorm_kernel[(sample.shape[0],)](
        sample,
        out,
        sample.stride(0),
        sample.stride(1),
        out.stride(0),
        out.stride(1),
        sample.shape[1],
        EPS,
        BLOCK_SIZE=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    torch.cuda.synchronize()
    return float((out.to(torch.float32) - ref).abs().max().item())


def benchmark_layernorm_config(
    n: int,
    block_size: int,
    num_warps: int,
    num_stages: int,
    m: int = 4096,
    repeats: int = 200,
    warmup: int = 25,
    seed: int = 0,
) -> MultiFamilyMeasurementRow:
    if block_size < n:
        raise ValueError(f"Invalid config: BLOCK_SIZE {block_size} < N {n}")

    torch.manual_seed(seed)
    sample = torch.randn((m, n), device="cuda", dtype=torch.float16)
    median_ms = _benchmark_layernorm_config(
        sample=sample,
        block_size=block_size,
```
|
| 256 |
-
num_stages=num_stages,
|
| 257 |
-
repeats=repeats,
|
| 258 |
-
warmup=warmup,
|
| 259 |
-
)
|
| 260 |
-
val_err = _validate_layernorm(
|
| 261 |
-
sample=sample,
|
| 262 |
-
block_size=block_size,
|
| 263 |
-
num_warps=num_warps,
|
| 264 |
-
num_stages=num_stages,
|
| 265 |
-
)
|
| 266 |
-
gbps = _effective_gbps(sample, median_ms)
|
| 267 |
-
return _make_row(
|
| 268 |
-
family="layernorm",
|
| 269 |
-
m=m,
|
| 270 |
-
n=n,
|
| 271 |
-
block_size=block_size,
|
| 272 |
-
num_warps=num_warps,
|
| 273 |
-
num_stages=num_stages,
|
| 274 |
-
median_ms=median_ms,
|
| 275 |
-
effective_gbps=gbps,
|
| 276 |
-
validation_error=val_err,
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
def _grouped_gemm_shapes(m: int, n: int) -> Tuple[int, int, int]:
|
| 281 |
-
group_m = max(64, m // GROUPED_GEMM_GROUP_COUNT)
|
| 282 |
-
return GROUPED_GEMM_GROUP_COUNT, group_m, GROUPED_GEMM_K
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
def _matmul_meta_from_block(block_size: int) -> Tuple[int, int, int]:
|
| 286 |
-
block_m = max(32, min(block_size // 4, 256))
|
| 287 |
-
block_n = max(32, min(block_size // 4, 256))
|
| 288 |
-
block_k = 32
|
| 289 |
-
return block_m, block_n, block_k
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
def _benchmark_grouped_gemm_config(
|
| 293 |
-
a_groups: Sequence[torch.Tensor],
|
| 294 |
-
b_groups: Sequence[torch.Tensor],
|
| 295 |
-
block_size: int,
|
| 296 |
-
num_warps: int,
|
| 297 |
-
num_stages: int,
|
| 298 |
-
repeats: int,
|
| 299 |
-
warmup: int,
|
| 300 |
-
) -> float:
|
| 301 |
-
c_groups = [torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=a.dtype) for a, b in zip(a_groups, b_groups)]
|
| 302 |
-
block_m, block_n, block_k = _matmul_meta_from_block(block_size)
|
| 303 |
-
|
| 304 |
-
def launch() -> None:
|
| 305 |
-
for a, b, c in zip(a_groups, b_groups, c_groups):
|
| 306 |
-
grid = (triton.cdiv(a.shape[0], block_m) * triton.cdiv(b.shape[1], block_n),)
|
| 307 |
-
matmul_kernel[grid](
|
| 308 |
-
a,
|
| 309 |
-
b,
|
| 310 |
-
c,
|
| 311 |
-
a.shape[0],
|
| 312 |
-
b.shape[1],
|
| 313 |
-
a.shape[1],
|
| 314 |
-
a.stride(0),
|
| 315 |
-
a.stride(1),
|
| 316 |
-
b.stride(0),
|
| 317 |
-
b.stride(1),
|
| 318 |
-
c.stride(0),
|
| 319 |
-
c.stride(1),
|
| 320 |
-
BLOCK_M=block_m,
|
| 321 |
-
BLOCK_N=block_n,
|
| 322 |
-
BLOCK_K=block_k,
|
| 323 |
-
num_warps=num_warps,
|
| 324 |
-
num_stages=num_stages,
|
| 325 |
-
)
|
| 326 |
-
|
| 327 |
-
for _ in range(max(1, warmup)):
|
| 328 |
-
launch()
|
| 329 |
-
torch.cuda.synchronize()
|
| 330 |
-
|
| 331 |
-
if ttesting is not None:
|
| 332 |
-
result = ttesting.do_bench(
|
| 333 |
-
launch,
|
| 334 |
-
warmup=0,
|
| 335 |
-
rep=repeats,
|
| 336 |
-
quantiles=[0.5],
|
| 337 |
-
return_mode="median",
|
| 338 |
-
)
|
| 339 |
-
if isinstance(result, dict):
|
| 340 |
-
return float(result.get("median", 0.0))
|
| 341 |
-
return float(result)
|
| 342 |
-
|
| 343 |
-
durations_ms: List[float] = []
|
| 344 |
-
start = torch.cuda.Event(enable_timing=True)
|
| 345 |
-
end = torch.cuda.Event(enable_timing=True)
|
| 346 |
-
for _ in range(max(1, repeats)):
|
| 347 |
-
torch.cuda.synchronize()
|
| 348 |
-
start.record()
|
| 349 |
-
launch()
|
| 350 |
-
end.record()
|
| 351 |
-
end.synchronize()
|
| 352 |
-
durations_ms.append(start.elapsed_time(end))
|
| 353 |
-
return float(np.median(np.asarray(durations_ms, dtype=np.float32)))
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def _validate_grouped_gemm(
|
| 357 |
-
a_groups: Sequence[torch.Tensor],
|
| 358 |
-
b_groups: Sequence[torch.Tensor],
|
| 359 |
-
block_size: int,
|
| 360 |
-
num_warps: int,
|
| 361 |
-
num_stages: int,
|
| 362 |
-
) -> float:
|
| 363 |
-
c_groups = [torch.empty((a.shape[0], b.shape[1]), device=a.device, dtype=a.dtype) for a, b in zip(a_groups, b_groups)]
|
| 364 |
-
block_m, block_n, block_k = _matmul_meta_from_block(block_size)
|
| 365 |
-
for a, b, c in zip(a_groups, b_groups, c_groups):
|
| 366 |
-
grid = (triton.cdiv(a.shape[0], block_m) * triton.cdiv(b.shape[1], block_n),)
|
| 367 |
-
matmul_kernel[grid](
|
| 368 |
-
a,
|
| 369 |
-
b,
|
| 370 |
-
c,
|
| 371 |
-
a.shape[0],
|
| 372 |
-
b.shape[1],
|
| 373 |
-
a.shape[1],
|
| 374 |
-
a.stride(0),
|
| 375 |
-
a.stride(1),
|
| 376 |
-
b.stride(0),
|
| 377 |
-
b.stride(1),
|
| 378 |
-
c.stride(0),
|
| 379 |
-
c.stride(1),
|
| 380 |
-
BLOCK_M=block_m,
|
| 381 |
-
BLOCK_N=block_n,
|
| 382 |
-
BLOCK_K=block_k,
|
| 383 |
-
num_warps=num_warps,
|
| 384 |
-
num_stages=num_stages,
|
| 385 |
-
)
|
| 386 |
-
torch.cuda.synchronize()
|
| 387 |
-
|
| 388 |
-
max_err = 0.0
|
| 389 |
-
for a, b, c in zip(a_groups, b_groups, c_groups):
|
| 390 |
-
ref = torch.matmul(a.float(), b.float()).to(torch.float16)
|
| 391 |
-
max_err = max(max_err, float((c.float() - ref.float()).abs().max().item()))
|
| 392 |
-
return max_err
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
def benchmark_grouped_gemm_config(
|
| 396 |
-
n: int,
|
| 397 |
-
block_size: int,
|
| 398 |
-
num_warps: int,
|
| 399 |
-
num_stages: int,
|
| 400 |
-
m: int = 4096,
|
| 401 |
-
repeats: int = 200,
|
| 402 |
-
warmup: int = 25,
|
| 403 |
-
seed: int = 0,
|
| 404 |
-
) -> MultiFamilyMeasurementRow:
|
| 405 |
-
group_count, group_m, k_dim = _grouped_gemm_shapes(m, n)
|
| 406 |
-
torch.manual_seed(seed)
|
| 407 |
-
a_groups = [torch.randn((group_m, k_dim), device="cuda", dtype=torch.float16) for _ in range(group_count)]
|
| 408 |
-
b_groups = [torch.randn((k_dim, n), device="cuda", dtype=torch.float16) for _ in range(group_count)]
|
| 409 |
-
|
| 410 |
-
median_ms = _benchmark_grouped_gemm_config(
|
| 411 |
-
a_groups=a_groups,
|
| 412 |
-
b_groups=b_groups,
|
| 413 |
-
block_size=block_size,
|
| 414 |
-
num_warps=num_warps,
|
| 415 |
-
num_stages=num_stages,
|
| 416 |
-
repeats=repeats,
|
| 417 |
-
warmup=warmup,
|
| 418 |
-
)
|
| 419 |
-
val_err = _validate_grouped_gemm(
|
| 420 |
-
a_groups=a_groups,
|
| 421 |
-
b_groups=b_groups,
|
| 422 |
-
block_size=block_size,
|
| 423 |
-
num_warps=num_warps,
|
| 424 |
-
num_stages=num_stages,
|
| 425 |
-
)
|
| 426 |
-
bytes_processed = group_count * (
|
| 427 |
-
a_groups[0].numel() * a_groups[0].element_size()
|
| 428 |
-
+ b_groups[0].numel() * b_groups[0].element_size()
|
| 429 |
-
+ group_m * n * a_groups[0].element_size()
|
| 430 |
-
)
|
| 431 |
-
effective_gbps = bytes_processed / (median_ms / 1000.0) / 1e9 if median_ms > 0 else 0.0
|
| 432 |
-
return _make_row(
|
| 433 |
-
family="grouped_gemm",
|
| 434 |
-
m=m,
|
| 435 |
-
n=n,
|
| 436 |
-
block_size=block_size,
|
| 437 |
-
num_warps=num_warps,
|
| 438 |
-
num_stages=num_stages,
|
| 439 |
-
median_ms=median_ms,
|
| 440 |
-
effective_gbps=effective_gbps,
|
| 441 |
-
validation_error=val_err,
|
| 442 |
-
)
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
def _make_row(
|
| 446 |
-
family: str,
|
| 447 |
-
m: int,
|
| 448 |
-
n: int,
|
| 449 |
-
block_size: int,
|
| 450 |
-
num_warps: int,
|
| 451 |
-
num_stages: int,
|
| 452 |
-
median_ms: float,
|
| 453 |
-
effective_gbps: float,
|
| 454 |
-
validation_error: float,
|
| 455 |
-
) -> MultiFamilyMeasurementRow:
|
| 456 |
-
configs = _valid_configs(family, n)
|
| 457 |
-
config_id = configs.index((block_size, num_warps, num_stages))
|
| 458 |
-
return MultiFamilyMeasurementRow(
|
| 459 |
-
family_group="A" if family in {"softmax", "layernorm"} else "B",
|
| 460 |
-
family=family,
|
| 461 |
-
task_id=_task_id(family, m, n),
|
| 462 |
-
m=m,
|
| 463 |
-
n=n,
|
| 464 |
-
config_id=config_id,
|
| 465 |
-
block_size=block_size,
|
| 466 |
-
num_warps=num_warps,
|
| 467 |
-
num_stages=num_stages,
|
| 468 |
-
shape_json=json.dumps(
|
| 469 |
-
{
|
| 470 |
-
"family": family,
|
| 471 |
-
"m": m,
|
| 472 |
-
"n": n,
|
| 473 |
-
"group_count": GROUPED_GEMM_GROUP_COUNT if family == "grouped_gemm" else None,
|
| 474 |
-
"k": GROUPED_GEMM_K if family == "grouped_gemm" else None,
|
| 475 |
-
},
|
| 476 |
-
sort_keys=True,
|
| 477 |
-
),
|
| 478 |
-
config_json=json.dumps(
|
| 479 |
-
{
|
| 480 |
-
"block_size": block_size,
|
| 481 |
-
"num_warps": num_warps,
|
| 482 |
-
"num_stages": num_stages,
|
| 483 |
-
},
|
| 484 |
-
sort_keys=True,
|
| 485 |
-
),
|
| 486 |
-
median_ms=float(median_ms),
|
| 487 |
-
effective_gbps=float(effective_gbps),
|
| 488 |
-
score=float(-math.log(max(median_ms, np.finfo(float).tiny))),
|
| 489 |
-
validation_error=float(validation_error),
|
| 490 |
-
)
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
def _softmax_row_to_multi(row: object) -> MultiFamilyMeasurementRow:
|
| 494 |
-
return _make_row(
|
| 495 |
-
family="softmax",
|
| 496 |
-
m=int(row.m),
|
| 497 |
-
n=int(row.n),
|
| 498 |
-
block_size=int(row.block_size),
|
| 499 |
-
num_warps=int(row.num_warps),
|
| 500 |
-
num_stages=int(row.num_stages),
|
| 501 |
-
median_ms=float(row.median_ms),
|
| 502 |
-
effective_gbps=float(row.effective_gbps),
|
| 503 |
-
validation_error=float(row.validation_error),
|
| 504 |
-
)
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
def collect_multifamily_measurements(
|
| 508 |
-
output_path: Path,
|
| 509 |
-
families: Sequence[str],
|
| 510 |
-
n_values: Iterable[int],
|
| 511 |
-
repeats: int,
|
| 512 |
-
warmup: int,
|
| 513 |
-
seed: int,
|
| 514 |
-
m: int = 4096,
|
| 515 |
-
append: bool = True,
|
| 516 |
-
) -> List[MultiFamilyMeasurementRow]:
|
| 517 |
-
if not torch.cuda.is_available():
|
| 518 |
-
raise RuntimeError("CUDA is required to run measurements on GPU.")
|
| 519 |
-
|
| 520 |
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 521 |
-
existing = set()
|
| 522 |
-
if output_path.exists():
|
| 523 |
-
with output_path.open("r", newline="", encoding="utf-8") as handle:
|
| 524 |
-
for row in csv.DictReader(handle):
|
| 525 |
-
existing.add(_row_key(row["family"], int(row["m"]), int(row["n"]), int(row["config_id"])))
|
| 526 |
-
|
| 527 |
-
results: List[MultiFamilyMeasurementRow] = []
|
| 528 |
-
for family in families:
|
| 529 |
-
if family not in {"softmax", "layernorm", "grouped_gemm"}:
|
| 530 |
-
raise ValueError(f"Unsupported family: {family}")
|
| 531 |
-
for n in n_values:
|
| 532 |
-
for config_id, (block_size, num_warps, num_stages) in enumerate(_valid_configs(family, n)):
|
| 533 |
-
key = _row_key(family, m, n, config_id)
|
| 534 |
-
if append and key in existing:
|
| 535 |
-
continue
|
| 536 |
-
if family == "softmax":
|
| 537 |
-
row = _softmax_row_to_multi(
|
| 538 |
-
benchmark_softmax_config(
|
| 539 |
-
n=n,
|
| 540 |
-
block_size=block_size,
|
| 541 |
-
num_warps=num_warps,
|
| 542 |
-
num_stages=num_stages,
|
| 543 |
-
m=m,
|
| 544 |
-
repeats=repeats,
|
| 545 |
-
warmup=warmup,
|
| 546 |
-
seed=seed,
|
| 547 |
-
)
|
| 548 |
-
)
|
| 549 |
-
elif family == "layernorm":
|
| 550 |
-
row = benchmark_layernorm_config(
|
| 551 |
-
n=n,
|
| 552 |
-
block_size=block_size,
|
| 553 |
-
num_warps=num_warps,
|
| 554 |
-
num_stages=num_stages,
|
| 555 |
-
m=m,
|
| 556 |
-
repeats=repeats,
|
| 557 |
-
warmup=warmup,
|
| 558 |
-
seed=seed,
|
| 559 |
-
)
|
| 560 |
-
else:
|
| 561 |
-
row = benchmark_grouped_gemm_config(
|
| 562 |
-
n=n,
|
| 563 |
-
block_size=block_size,
|
| 564 |
-
num_warps=num_warps,
|
| 565 |
-
num_stages=num_stages,
|
| 566 |
-
m=m,
|
| 567 |
-
repeats=repeats,
|
| 568 |
-
warmup=warmup,
|
| 569 |
-
seed=seed,
|
| 570 |
-
)
|
| 571 |
-
results.append(row)
|
| 572 |
-
|
| 573 |
-
if results:
|
| 574 |
-
write_mode = "a" if output_path.exists() and append else "w"
|
| 575 |
-
with output_path.open(write_mode, newline="", encoding="utf-8") as handle:
|
| 576 |
-
writer = csv.writer(handle)
|
| 577 |
-
if write_mode == "w" or output_path.stat().st_size == 0:
|
| 578 |
-
writer.writerow(
|
| 579 |
-
[
|
| 580 |
-
"family_group",
|
| 581 |
-
"family",
|
| 582 |
-
"task_id",
|
| 583 |
-
"m",
|
| 584 |
-
"n",
|
| 585 |
-
"config_id",
|
| 586 |
-
"block_size",
|
| 587 |
-
"num_warps",
|
| 588 |
-
"num_stages",
|
| 589 |
-
"shape_json",
|
| 590 |
-
"config_json",
|
| 591 |
-
"median_ms",
|
| 592 |
-
"effective_gbps",
|
| 593 |
-
"score",
|
| 594 |
-
"validation_error",
|
| 595 |
-
]
|
| 596 |
-
)
|
| 597 |
-
for row in results:
|
| 598 |
-
writer.writerow(
|
| 599 |
-
[
|
| 600 |
-
row.family_group,
|
| 601 |
-
row.family,
|
| 602 |
-
row.task_id,
|
| 603 |
-
row.m,
|
| 604 |
-
row.n,
|
| 605 |
-
row.config_id,
|
| 606 |
-
row.block_size,
|
| 607 |
-
row.num_warps,
|
| 608 |
-
row.num_stages,
|
| 609 |
-
row.shape_json,
|
| 610 |
-
row.config_json,
|
| 611 |
-
row.median_ms,
|
| 612 |
-
row.effective_gbps,
|
| 613 |
-
row.score,
|
| 614 |
-
row.validation_error,
|
| 615 |
-
]
|
| 616 |
-
)
|
| 617 |
-
return results
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
def _row_key(family: str, m: int, n: int, config_id: int) -> str:
|
| 621 |
-
return f"{family}|{m}|{n}|{config_id}"
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
def parse_args() -> argparse.Namespace:
|
| 625 |
-
parser = argparse.ArgumentParser(description="Collect multi-family GPU measurements.")
|
| 626 |
-
parser.add_argument("--output", type=Path, default=Path("data/autotune_measurements.csv"))
|
| 627 |
-
parser.add_argument(
|
| 628 |
-
"--families",
|
| 629 |
-
nargs="+",
|
| 630 |
-
default=("softmax", "layernorm", "grouped_gemm"),
|
| 631 |
-
choices=("softmax", "layernorm", "grouped_gemm"),
|
| 632 |
-
)
|
| 633 |
-
parser.add_argument("--n-cols", type=int, nargs="+", default=N_VALUES)
|
| 634 |
-
parser.add_argument("--m", type=int, default=4096)
|
| 635 |
-
parser.add_argument("--repeats", type=int, default=200)
|
| 636 |
-
parser.add_argument("--warmup", type=int, default=25)
|
| 637 |
-
parser.add_argument("--seed", type=int, default=0)
|
| 638 |
-
parser.add_argument("--append", action="store_true")
|
| 639 |
-
parser.add_argument("--overwrite", action="store_true")
|
| 640 |
-
return parser.parse_args()
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
def main() -> None:
|
| 644 |
-
args = parse_args()
|
| 645 |
-
append = args.append and not args.overwrite
|
| 646 |
-
collect_multifamily_measurements(
|
| 647 |
-
output_path=args.output,
|
| 648 |
-
families=args.families,
|
| 649 |
-
n_values=args.n_cols,
|
| 650 |
-
repeats=args.repeats,
|
| 651 |
-
warmup=args.warmup,
|
| 652 |
-
seed=args.seed,
|
| 653 |
-
m=args.m,
|
| 654 |
-
append=append,
|
| 655 |
-
)
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
if __name__ == "__main__":
|
| 659 |
-
main()
|
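Both deleted collectors share the same reward convention: the per-config score is the negative log of the median kernel latency (clamped away from zero before taking the log), and throughput is bytes moved per second scaled to GB/s. A minimal standalone sketch of that convention, independent of CUDA (the helper names here are illustrative, not from the repo; the original code uses `np.finfo(float).tiny` where this sketch uses the equivalent `sys.float_info.min`):

```python
import math
import sys


def effective_gbps(bytes_processed: int, median_ms: float) -> float:
    # Throughput in GB/s: bytes moved divided by elapsed seconds, scaled by 1e9.
    # Non-positive timings are reported as zero rather than dividing by zero.
    if median_ms <= 0:
        return 0.0
    return float(bytes_processed) / (median_ms / 1000.0) / 1e9


def score_from_median_ms(median_ms: float) -> float:
    # Higher is better: -log(latency), clamped to the smallest normal float
    # so that a (pathological) zero timing cannot produce -log(0).
    return -math.log(max(median_ms, sys.float_info.min))
```

Because `-log` is strictly decreasing, ranking configs by this score is the same as ranking them by median latency, but the log scale keeps rewards comparable across kernels whose latencies span orders of magnitude.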
scripts/collect_qwen_05b_measurements.py
DELETED
@@ -1,449 +0,0 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence

import numpy as np
import torch

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

try:
    import triton
    import triton.language as tl
    import triton.testing as ttesting
except Exception as err:  # pragma: no cover
    raise RuntimeError("Triton is required for Qwen kernel measurement.") from err

from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS, benchmark_single_config
from scripts.qwen_05b_spec import QwenKernelTask, qwen_05b_tasks


EPS = 1e-5


@dataclass(frozen=True)
class QwenMeasurementRow:
    family_group: str
    family: str
    task_id: str
    m: int
    n: int
    k: int
    config_id: int
    block_size: int
    num_warps: int
    num_stages: int
    shape_json: str
    config_json: str
    median_ms: float
    effective_gbps: float
    score: float
    validation_error: float


@triton.jit
def fused_rowwise_rmsnorm_kernel(
    X_ptr,
    Y_ptr,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x_ptr = X_ptr + row_idx * stride_xm + col_offsets
    y_ptr = Y_ptr + row_idx * stride_ym + col_offsets

    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
    mean_sq = tl.sum(x * x, axis=0) / n_cols
    inv_rms = tl.rsqrt(mean_sq + eps)
    y = x * inv_rms
    tl.store(y_ptr, y.to(tl.float16), mask=mask)


@triton.jit
def matmul_kernel(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c = acc.to(tl.float16)
    c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def _gemm_blocks(block_size: int) -> tuple[int, int, int]:
    block_m = max(16, min(block_size // 4, 128))
    block_n = max(16, min(block_size // 4, 128))
    block_k = 32
    return block_m, block_n, block_k


def _effective_gbps(bytes_processed: int, median_ms: float) -> float:
    if median_ms <= 0:
        return 0.0
    return float(bytes_processed) / (median_ms / 1000.0) / 1e9


def _score(ms: float) -> float:
    return float(-math.log(max(ms, np.finfo(float).tiny)))


def _config_json(block_size: int, num_warps: int, num_stages: int) -> str:
    return json.dumps(
        {"block_size": block_size, "num_warps": num_warps, "num_stages": num_stages},
        sort_keys=True,
    )


def _valid_configs(task: QwenKernelTask) -> List[tuple[int, int, int]]:
    configs: List[tuple[int, int, int]] = []
    for block_size in BLOCK_SIZES:
        if task.family in {"softmax", "rmsnorm"} and block_size < task.n:
            continue
        if task.family == "gemm" and block_size > 1024:
            continue
        for num_warps in NUM_WARPS:
            for num_stages in NUM_STAGES:
                configs.append((block_size, num_warps, num_stages))
    return configs


def _benchmark_rmsnorm(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
    torch.manual_seed(seed)
    sample = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)
    output = torch.empty_like(sample)
    grid = (sample.shape[0],)

    def launch() -> None:
        fused_rowwise_rmsnorm_kernel[grid](
            sample,
            output,
            sample.stride(0),
            sample.stride(1),
            output.stride(0),
            output.stride(1),
            sample.shape[1],
            EPS,
            BLOCK_SIZE=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    for _ in range(max(1, warmup)):
        launch()
    torch.cuda.synchronize()
    if ttesting is not None:
        result = ttesting.do_bench(launch, warmup=0, rep=repeats, quantiles=[0.5], return_mode="median")
        median_ms = float(result.get("median", 0.0) if isinstance(result, dict) else result)
    else:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        durations: List[float] = []
        for _ in range(max(1, repeats)):
            torch.cuda.synchronize()
            start.record()
            launch()
            end.record()
            end.synchronize()
            durations.append(start.elapsed_time(end))
        median_ms = float(np.median(np.asarray(durations, dtype=np.float32)))

    ref = sample.float() * torch.rsqrt(sample.float().pow(2).mean(dim=-1, keepdim=True) + EPS)
    fused_rowwise_rmsnorm_kernel[grid](
        sample,
        output,
        sample.stride(0),
        sample.stride(1),
        output.stride(0),
        output.stride(1),
        sample.shape[1],
        EPS,
        BLOCK_SIZE=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    torch.cuda.synchronize()
    max_err = float((output.float() - ref.float()).abs().max().item())
    config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
    return QwenMeasurementRow(
        family_group=task.family_group,
        family=task.family,
        task_id=task.task_id,
        m=task.m,
        n=task.n,
        k=task.k,
        config_id=config_id,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
        shape_json=json.dumps(task.shape_fields(), sort_keys=True),
        config_json=_config_json(block_size, num_warps, num_stages),
        median_ms=median_ms,
        effective_gbps=_effective_gbps(sample.numel() * sample.element_size() * 2, median_ms),
        score=_score(median_ms),
        validation_error=max_err,
    )


def _benchmark_gemm(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
    torch.manual_seed(seed)
    a = torch.randn((task.m, task.k), device="cuda", dtype=torch.float16)
    b = torch.randn((task.k, task.n), device="cuda", dtype=torch.float16)
    c = torch.empty((task.m, task.n), device="cuda", dtype=torch.float16)
    block_m, block_n, block_k = _gemm_blocks(block_size)

    def launch() -> None:
        grid = (triton.cdiv(task.m, block_m) * triton.cdiv(task.n, block_n),)
        matmul_kernel[grid](
            a,
            b,
            c,
            task.m,
            task.n,
            task.k,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    for _ in range(max(1, warmup)):
        launch()
    torch.cuda.synchronize()
    if ttesting is not None:
        result = ttesting.do_bench(launch, warmup=0, rep=repeats, quantiles=[0.5], return_mode="median")
        median_ms = float(result.get("median", 0.0) if isinstance(result, dict) else result)
    else:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        durations: List[float] = []
        for _ in range(max(1, repeats)):
            torch.cuda.synchronize()
            start.record()
            launch()
            end.record()
            end.synchronize()
            durations.append(start.elapsed_time(end))
        median_ms = float(np.median(np.asarray(durations, dtype=np.float32)))

    ref = torch.matmul(a.float(), b.float()).to(torch.float16)
    launch()
    torch.cuda.synchronize()
    max_err = float((c.float() - ref.float()).abs().max().item())
    config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
    bytes_processed = a.numel() * a.element_size() + b.numel() * b.element_size() + c.numel() * c.element_size()
    return QwenMeasurementRow(
        family_group=task.family_group,
        family=task.family,
        task_id=task.task_id,
        m=task.m,
        n=task.n,
        k=task.k,
        config_id=config_id,
        block_size=block_size,
        num_warps=num_warps,
        num_stages=num_stages,
        shape_json=json.dumps(task.shape_fields(), sort_keys=True),
        config_json=_config_json(block_size, num_warps, num_stages),
        median_ms=median_ms,
        effective_gbps=_effective_gbps(bytes_processed, median_ms),
        score=_score(median_ms),
        validation_error=max_err,
    )


def benchmark_qwen_task(task: QwenKernelTask, block_size: int, num_warps: int, num_stages: int, repeats: int, warmup: int, seed: int) -> QwenMeasurementRow:
    if task.family == "softmax":
        row = benchmark_single_config(
            n=task.n,
            block_size=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
            m=task.m,
            repeats=repeats,
            warmup=warmup,
            seed=seed,
        )
        config_id = _valid_configs(task).index((block_size, num_warps, num_stages))
        return QwenMeasurementRow(
            family_group=task.family_group,
            family=task.family,
            task_id=task.task_id,
            m=task.m,
            n=task.n,
            k=0,
            config_id=config_id,
            block_size=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
            shape_json=json.dumps(task.shape_fields(), sort_keys=True),
            config_json=_config_json(block_size, num_warps, num_stages),
            median_ms=float(row.median_ms),
            effective_gbps=float(row.effective_gbps),
            score=float(row.score),
            validation_error=float(row.validation_error),
        )
    if task.family == "rmsnorm":
        return _benchmark_rmsnorm(task, block_size, num_warps, num_stages, repeats, warmup, seed)
    if task.family == "gemm":
        return _benchmark_gemm(task, block_size, num_warps, num_stages, repeats, warmup, seed)
    raise ValueError(f"Unsupported family: {task.family}")


def collect_qwen_measurements(output_path: Path, repeats: int, warmup: int, seed: int, append: bool = True) -> List[QwenMeasurementRow]:
    output_path.parent.mkdir(parents=True, exist_ok=True)
    existing = set()
    if output_path.exists():
        with output_path.open("r", newline="", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                existing.add((row["task_id"], int(row["config_id"])))

    results: List[QwenMeasurementRow] = []
results: List[QwenMeasurementRow] = []
|
| 359 |
-
for task in qwen_05b_tasks():
|
| 360 |
-
for config_id, (block_size, num_warps, num_stages) in enumerate(_valid_configs(task)):
|
| 361 |
-
key = (task.task_id, config_id)
|
| 362 |
-
if append and key in existing:
|
| 363 |
-
continue
|
| 364 |
-
results.append(
|
| 365 |
-
benchmark_qwen_task(
|
| 366 |
-
task=task,
|
| 367 |
-
block_size=block_size,
|
| 368 |
-
num_warps=num_warps,
|
| 369 |
-
num_stages=num_stages,
|
| 370 |
-
repeats=repeats,
|
| 371 |
-
warmup=warmup,
|
| 372 |
-
seed=seed,
|
| 373 |
-
)
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
if results:
|
| 377 |
-
write_mode = "a" if output_path.exists() and append else "w"
|
| 378 |
-
with output_path.open(write_mode, newline="", encoding="utf-8") as handle:
|
| 379 |
-
writer = csv.writer(handle)
|
| 380 |
-
if write_mode == "w" or output_path.stat().st_size == 0:
|
| 381 |
-
writer.writerow(
|
| 382 |
-
[
|
| 383 |
-
"family_group",
|
| 384 |
-
"family",
|
| 385 |
-
"task_id",
|
| 386 |
-
"m",
|
| 387 |
-
"n",
|
| 388 |
-
"k",
|
| 389 |
-
"config_id",
|
| 390 |
-
"block_size",
|
| 391 |
-
"num_warps",
|
| 392 |
-
"num_stages",
|
| 393 |
-
"shape_json",
|
| 394 |
-
"config_json",
|
| 395 |
-
"median_ms",
|
| 396 |
-
"effective_gbps",
|
| 397 |
-
"score",
|
| 398 |
-
"validation_error",
|
| 399 |
-
]
|
| 400 |
-
)
|
| 401 |
-
for row in results:
|
| 402 |
-
writer.writerow(
|
| 403 |
-
[
|
| 404 |
-
row.family_group,
|
| 405 |
-
row.family,
|
| 406 |
-
row.task_id,
|
| 407 |
-
row.m,
|
| 408 |
-
row.n,
|
| 409 |
-
row.k,
|
| 410 |
-
row.config_id,
|
| 411 |
-
row.block_size,
|
| 412 |
-
row.num_warps,
|
| 413 |
-
row.num_stages,
|
| 414 |
-
row.shape_json,
|
| 415 |
-
row.config_json,
|
| 416 |
-
row.median_ms,
|
| 417 |
-
row.effective_gbps,
|
| 418 |
-
row.score,
|
| 419 |
-
row.validation_error,
|
| 420 |
-
]
|
| 421 |
-
)
|
| 422 |
-
return results
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
def parse_args() -> argparse.Namespace:
|
| 426 |
-
parser = argparse.ArgumentParser(description="Collect exact-kernel measurements for Qwen2.5-0.5B.")
|
| 427 |
-
parser.add_argument("--output", type=Path, default=Path("data/qwen_05b_measurements.csv"))
|
| 428 |
-
parser.add_argument("--repeats", type=int, default=200)
|
| 429 |
-
parser.add_argument("--warmup", type=int, default=25)
|
| 430 |
-
parser.add_argument("--seed", type=int, default=0)
|
| 431 |
-
parser.add_argument("--append", action="store_true")
|
| 432 |
-
parser.add_argument("--overwrite", action="store_true")
|
| 433 |
-
return parser.parse_args()
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
def main() -> None:
|
| 437 |
-
args = parse_args()
|
| 438 |
-
append = args.append and not args.overwrite
|
| 439 |
-
collect_qwen_measurements(
|
| 440 |
-
output_path=args.output,
|
| 441 |
-
repeats=args.repeats,
|
| 442 |
-
warmup=args.warmup,
|
| 443 |
-
seed=args.seed,
|
| 444 |
-
append=append,
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
if __name__ == "__main__":
|
| 449 |
-
main()
|
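The collector's resume logic — skip `(task_id, config_id)` pairs already present in the CSV, append only new rows, and write the header only for an empty file — can be exercised in isolation. The sketch below mirrors that pattern with a temporary file and a reduced two-column schema; the helper names `load_existing` and `append_rows` are illustrative, not from the deleted script:

```python
import csv
import tempfile
from pathlib import Path


def load_existing(path: Path) -> set:
    # Mirror of the dedup scan: collect (task_id, config_id) keys already on disk.
    seen = set()
    if path.exists():
        with path.open("r", newline="", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                seen.add((row["task_id"], int(row["config_id"])))
    return seen


def append_rows(path: Path, rows: list) -> int:
    # Append only rows whose key is not already present; header only when empty.
    seen = load_existing(path)
    fresh = [r for r in rows if (r["task_id"], r["config_id"]) not in seen]
    if fresh:
        write_header = not path.exists() or path.stat().st_size == 0
        with path.open("a", newline="", encoding="utf-8") as handle:
            writer = csv.DictWriter(handle, fieldnames=["task_id", "config_id"])
            if write_header:
                writer.writeheader()
            writer.writerows(fresh)
    return len(fresh)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "measurements.csv"
    rows = [{"task_id": "t0", "config_id": 0}, {"task_id": "t0", "config_id": 1}]
    print(append_rows(out, rows))  # 2: both rows are new
    print(append_rows(out, rows))  # 0: both already on disk, run is resume-safe
```

Re-running the same collection is therefore idempotent: an interrupted sweep picks up where it left off instead of re-benchmarking finished configs.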
scripts/demo_compare.py
DELETED
@@ -1,200 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
import sys
from pathlib import Path
from typing import Any, Dict, List

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
from scripts.collect_measurements import benchmark_single_config
from scripts.run_random_baseline import run_random_baseline
from scripts.run_surrogate_baseline import run_surrogate_baseline


@dataclass
class BaselineResult:
    method: str
    latency_ms: float
    config: Dict[str, int]
    regret: float


def _search_metric_key(summary: Dict[str, Any], oracle_best_ms: float) -> tuple[float, float, float]:
    metrics = summary.get("aggregate_metrics", {})
    mean_auc_regret = float(metrics.get("mean_auc_regret", float("inf")))
    oracle_hit_rate_final = float(metrics.get("oracle_hit_rate_final", 0.0))
    best_latency_ms = float(summary["best_overall"]["latency_ms"])
    latency_regret = best_latency_ms / oracle_best_ms - 1.0
    return (mean_auc_regret, -oracle_hit_rate_final, latency_regret)


def _heuristic_for_task(task_id: str, task_rows: List[Dict[str, Any]], env: SoftmaxSurrogateEnvironment) -> BaselineResult:
    n = int(task_id.split("_n")[-1])
    block = min(row["block_size"] for row in task_rows if row["block_size"] >= n)
    warp = 4 if 4 in {row["num_warps"] for row in task_rows} else 2
    stage = 2 if 2 in {row["num_stages"] for row in task_rows} else 1

    candidate = None
    for row in task_rows:
        if row["block_size"] == block and row["num_warps"] == warp and row["num_stages"] == stage:
            candidate = row
            break
    if candidate is None:
        candidate = min(
            task_rows,
            key=lambda row: abs(row["block_size"] - block) + 10 * abs(row["num_warps"] - warp),
        )

    latency_ms = env.measured_latency_ms(candidate["config_id"])
    oracle_best_ms = env.oracle_best()["median_ms"]
    return BaselineResult(
        method="heuristic",
        latency_ms=float(latency_ms),
        config=candidate,
        regret=float(latency_ms / oracle_best_ms - 1.0),
    )


def _pick_task(task_arg: str | None, measurement_path: str, budget: int) -> str:
    env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=0)
    if task_arg:
        env.reset(task=task_arg)
    else:
        env.reset()
    return env.state()["task_id"]


def _run_all(
    task: str,
    budget: int,
    episodes: int,
    seed: int,
    measurement_path: str,
    acquisition: str,
    beta: float,
    xi: float,
) -> Dict[str, Any]:
    env = SoftmaxSurrogateEnvironment(measurement_path=measurement_path, budget=budget, seed=seed)
    env.reset(task=task)
    task_rows = env.available_configs()
    oracle_best = env.oracle_best()

    heuristic = _heuristic_for_task(task, task_rows, env)
    random_summary = run_random_baseline(task, episodes=episodes, budget=budget, seed=seed, measurement_path=measurement_path)
    surrogate_summary = run_surrogate_baseline(
        task,
        episodes=episodes,
        budget=budget,
        seed=seed,
        measurement_path=measurement_path,
        acquisition=acquisition,
        beta=beta,
        xi=xi,
    )

    search_summaries = {
        "random": random_summary,
        "surrogate": surrogate_summary,
    }
    winner_method, winner_summary = min(
        search_summaries.items(),
        key=lambda item: _search_metric_key(item[1], oracle_best["median_ms"]),
    )
    winner_cfg = winner_summary["best_overall"]["config"]
    winner_regret = float(winner_summary["best_overall"]["latency_ms"] / oracle_best["median_ms"] - 1.0)
    n = int(task.split("_n")[-1])

    live = benchmark_single_config(
        n=n,
        block_size=winner_cfg["block_size"],
        num_warps=winner_cfg["num_warps"],
        num_stages=winner_cfg["num_stages"],
        repeats=max(200, budget * 20),
        warmup=25,
        seed=seed + 999,
    )

    return {
        "task": task,
        "seed": seed,
        "budget": budget,
        "episodes": episodes,
        "acquisition": acquisition,
        "beta": beta,
        "xi": xi,
        "oracle_best": oracle_best,
        "heuristic": heuristic.__dict__,
        "random": random_summary["best_overall"],
        "random_aggregate_metrics": random_summary.get("aggregate_metrics", {}),
        "surrogate": surrogate_summary["best_overall"],
        "surrogate_aggregate_metrics": surrogate_summary.get("aggregate_metrics", {}),
        "winner": {
            "method": winner_method,
            "selection_metric": "min(mean_auc_regret), tie-break max(oracle_hit_rate_final), then best latency",
            "latency_ms": winner_summary["best_overall"]["latency_ms"],
            "config": winner_cfg,
            "regret": winner_regret,
            "live_rerun": live.__dict__,
        },
    }


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Compare heuristic/random/surrogate baselines.")
    parser.add_argument(
        "--task",
        default="softmax_m4096_n2048",
        help="Task ID (e.g., softmax_m4096_n2048)",
    )
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--seed", type=int, default=2)
    parser.add_argument(
        "--acquisition",
        type=str,
        choices=("mean", "ucb", "ei"),
        default="ucb",
    )
    parser.add_argument("--beta", type=float, default=2.0)
    parser.add_argument("--xi", type=float, default=0.0)
    parser.add_argument(
        "--measurement-path",
        type=str,
        default="data/autotune_measurements.csv",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("outputs/demo_compare.json"),
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    task = _pick_task(args.task, args.measurement_path, args.budget)
    summary = _run_all(
        task=task,
        budget=args.budget,
        episodes=args.episodes,
        seed=args.seed,
        measurement_path=args.measurement_path,
        acquisition=args.acquisition,
        beta=args.beta,
        xi=args.xi,
    )
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
```
scripts/eval_generalization.py
DELETED
@@ -1,147 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from scripts.run_random_baseline import run_random_baseline
from scripts.run_surrogate_baseline import run_surrogate_baseline


def _average_metric_dict(records: List[Dict[str, float]]) -> Dict[str, float]:
    if not records:
        return {}
    keys = sorted({key for record in records for key in record.keys()}, key=lambda value: int(value))
    return {
        key: float(np.mean(np.asarray([record[key] for record in records if key in record], dtype=np.float32)))
        for key in keys
    }


def _summarize_runs(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
    mean_regret_records = [run["aggregate_metrics"].get("mean_regret_at", {}) for run in runs]
    median_regret_records = [run["aggregate_metrics"].get("median_regret_at", {}) for run in runs]
    auc_values = [run["aggregate_metrics"].get("mean_auc_regret") for run in runs]
    oracle_hit_values = [run["aggregate_metrics"].get("oracle_hit_rate_final") for run in runs]
    return {
        "mean_regret_at": _average_metric_dict(mean_regret_records),
        "median_regret_at": _average_metric_dict(median_regret_records),
        "mean_best_so_far_auc": float(np.mean(np.asarray(auc_values, dtype=np.float32))) if auc_values else None,
        "mean_oracle_hit_rate_final": float(np.mean(np.asarray(oracle_hit_values, dtype=np.float32))) if oracle_hit_values else None,
    }


def _evaluate_section(
    section_name: str,
    split: Dict[str, Any],
    measurement_path: str,
    episodes: int,
    budget: int,
    seed: int,
    acquisition: str,
    beta: float,
    xi: float,
) -> Dict[str, Any]:
    train_tasks = split["train_tasks"]
    test_tasks = split["test_tasks"]
    random_runs: List[Dict[str, Any]] = []
    surrogate_runs: List[Dict[str, Any]] = []

    for idx, task in enumerate(test_tasks):
        task_seed = seed + idx * 1000
        random_runs.append(
            run_random_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
            )
        )
        surrogate_runs.append(
            run_surrogate_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
                train_task_ids=train_tasks,
                acquisition=acquisition,
                beta=beta,
                xi=xi,
            )
        )

    return {
        "section": section_name,
        "train_tasks": train_tasks,
        "test_tasks": test_tasks,
        "random_summary": _summarize_runs(random_runs),
        "surrogate_summary": _summarize_runs(surrogate_runs),
        "task_runs": {
            "random": random_runs,
            "surrogate": surrogate_runs,
        },
    }


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate random vs surrogate on shape and family holdout splits.")
    parser.add_argument("--measurement-path", type=str, default="data/autotune_measurements.csv")
    parser.add_argument("--splits", type=Path, default=Path("data/benchmark_splits.json"))
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--seed", type=int, default=2)
    parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
    parser.add_argument("--beta", type=float, default=2.0)
    parser.add_argument("--xi", type=float, default=0.0)
    parser.add_argument("--output", type=Path, default=Path("outputs/generalization_eval.json"))
    args = parser.parse_args()

    splits = json.loads(args.splits.read_text(encoding="utf-8"))
    sections = {
        "shape_generalization": splits["shape_generalization"],
        "family_holdout": splits["family_holdout"],
    }
    results = {
        name: _evaluate_section(
            section_name=name,
            split=section,
            measurement_path=args.measurement_path,
            episodes=args.episodes,
            budget=args.budget,
            seed=args.seed,
            acquisition=args.acquisition,
            beta=args.beta,
            xi=args.xi,
        )
        for name, section in sections.items()
    }

    summary = {
        "measurement_path": args.measurement_path,
        "splits_path": str(args.splits),
        "episodes": args.episodes,
        "budget": args.budget,
        "acquisition": args.acquisition,
        "beta": args.beta,
        "xi": args.xi,
        "results": results,
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
```
scripts/qwen_05b_spec.py
DELETED
@@ -1,115 +0,0 @@

```python
#!/usr/bin/env python3
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List


MODEL_ID = "Qwen/Qwen2.5-0.5B"
HIDDEN_SIZE = 896
INTERMEDIATE_SIZE = 4864
NUM_HIDDEN_LAYERS = 24
NUM_ATTENTION_HEADS = 14
NUM_KEY_VALUE_HEADS = 2
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
KV_PROJ_SIZE = NUM_KEY_VALUE_HEADS * HEAD_DIM
DTYPE = "bfloat16"
BATCH_SIZE = 1

PREFILL_SEQ_LENS = (128, 512, 2048)
DECODE_CTX_LENS = (128, 512, 2048, 8192)

MODE_IDS = {
    "prefill": 1,
    "decode": 2,
}

ROLE_IDS = {
    "rmsnorm": 1,
    "attn_softmax": 2,
    "q_proj": 3,
    "k_proj": 4,
    "v_proj": 5,
    "o_proj": 6,
    "gate_proj": 7,
    "up_proj": 8,
    "down_proj": 9,
}


@dataclass(frozen=True)
class QwenKernelTask:
    family_group: str
    family: str
    task_id: str
    role: str
    mode: str
    m: int
    n: int
    k: int
    seq_len: int
    ctx_len: int

    def shape_fields(self) -> Dict[str, int | str]:
        return {
            "family_group": self.family_group,
            "family": self.family,
            "role": self.role,
            "mode": self.mode,
            "role_id": ROLE_IDS[self.role],
            "mode_id": MODE_IDS[self.mode],
            "m": self.m,
            "n": self.n,
            "k": self.k,
            "seq_len": self.seq_len,
            "ctx_len": self.ctx_len,
            "hidden_size": HIDDEN_SIZE,
            "intermediate_size": INTERMEDIATE_SIZE,
            "num_attention_heads": NUM_ATTENTION_HEADS,
            "num_key_value_heads": NUM_KEY_VALUE_HEADS,
            "head_dim": HEAD_DIM,
            "dtype": DTYPE,
            "model_id": MODEL_ID,
        }


def _task_id(role: str, mode: str, m: int, n: int, k: int, seq_len: int, ctx_len: int) -> str:
    extra = f"_k{k}" if k > 0 else ""
    ctx = f"_ctx{ctx_len}" if ctx_len > 0 else ""
    return f"qwen05b_{role}_{mode}_m{m}_n{n}{extra}_seq{seq_len}{ctx}"


def qwen_05b_tasks() -> List[QwenKernelTask]:
    tasks: List[QwenKernelTask] = []

    for seq_len in PREFILL_SEQ_LENS:
        tasks.extend(
            [
                QwenKernelTask("QWEN", "rmsnorm", _task_id("rmsnorm", "prefill", seq_len, HIDDEN_SIZE, 0, seq_len, seq_len), "rmsnorm", "prefill", seq_len, HIDDEN_SIZE, 0, seq_len, seq_len),
                QwenKernelTask("QWEN", "softmax", _task_id("attn_softmax", "prefill", NUM_ATTENTION_HEADS * seq_len, seq_len, 0, seq_len, seq_len), "attn_softmax", "prefill", NUM_ATTENTION_HEADS * seq_len, seq_len, 0, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("q_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len), "q_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("k_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len), "k_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("v_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len), "v_proj", "prefill", seq_len, KV_PROJ_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("o_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len), "o_proj", "prefill", seq_len, HIDDEN_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("gate_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len), "gate_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("up_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len), "up_proj", "prefill", seq_len, INTERMEDIATE_SIZE, HIDDEN_SIZE, seq_len, seq_len),
                QwenKernelTask("QWEN", "gemm", _task_id("down_proj", "prefill", seq_len, HIDDEN_SIZE, INTERMEDIATE_SIZE, seq_len, seq_len), "down_proj", "prefill", seq_len, HIDDEN_SIZE, INTERMEDIATE_SIZE, seq_len, seq_len),
            ]
        )

    for ctx_len in DECODE_CTX_LENS:
        tasks.extend(
            [
                QwenKernelTask("QWEN", "rmsnorm", _task_id("rmsnorm", "decode", 1, HIDDEN_SIZE, 0, 1, ctx_len), "rmsnorm", "decode", 1, HIDDEN_SIZE, 0, 1, ctx_len),
                QwenKernelTask("QWEN", "softmax", _task_id("attn_softmax", "decode", NUM_ATTENTION_HEADS, ctx_len, 0, 1, ctx_len), "attn_softmax", "decode", NUM_ATTENTION_HEADS, ctx_len, 0, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("q_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len), "q_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("k_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len), "k_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("v_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len), "v_proj", "decode", 1, KV_PROJ_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("o_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len), "o_proj", "decode", 1, HIDDEN_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("gate_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len), "gate_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("up_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len), "up_proj", "decode", 1, INTERMEDIATE_SIZE, HIDDEN_SIZE, 1, ctx_len),
                QwenKernelTask("QWEN", "gemm", _task_id("down_proj", "decode", 1, HIDDEN_SIZE, INTERMEDIATE_SIZE, 1, ctx_len), "down_proj", "decode", 1, HIDDEN_SIZE, INTERMEDIATE_SIZE, 1, ctx_len),
            ]
        )

    return tasks
```
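The spec's derived constants and task count follow directly from the model config. A standalone sanity check (replicating the constants from the spec rather than importing the deleted module) confirms the arithmetic:

```python
# Values copied from the qwen_05b_spec constants above.
HIDDEN_SIZE = 896
NUM_ATTENTION_HEADS = 14
NUM_KEY_VALUE_HEADS = 2

HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS   # 896 // 14 = 64
KV_PROJ_SIZE = NUM_KEY_VALUE_HEADS * HEAD_DIM   # 2 * 64 = 128 (GQA k/v projection width)

# qwen_05b_tasks() emits 9 kernels (rmsnorm, attn_softmax, 7 GEMM projections)
# per prefill seq_len and per decode ctx_len.
PREFILL_SEQ_LENS = (128, 512, 2048)
DECODE_CTX_LENS = (128, 512, 2048, 8192)
num_tasks = 9 * len(PREFILL_SEQ_LENS) + 9 * len(DECODE_CTX_LENS)

print(HEAD_DIM, KV_PROJ_SIZE, num_tasks)  # 64 128 63
```

So the sweep covers 63 distinct (role, mode, shape) tasks, and the narrow 128-wide k/v projections are why GEMM tasks span very different aspect ratios.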
scripts/report_task_hardness.py
DELETED
@@ -1,42 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import csv
-from collections import defaultdict
-from pathlib import Path
-from typing import Dict, List
-
-
-def _load_rows(path: Path) -> Dict[str, List[float]]:
-    grouped: Dict[str, List[float]] = defaultdict(list)
-    with path.open("r", newline="", encoding="utf-8") as f:
-        for row in csv.DictReader(f):
-            grouped[row["task_id"]].append(float(row["median_ms"]))
-    return grouped
-
-
-def main() -> None:
-    parser = argparse.ArgumentParser(description="Report task hardness from measured latency table.")
-    parser.add_argument("--measurement-path", type=Path, default=Path("data/autotune_measurements.csv"))
-    parser.add_argument("--budget", type=int, default=6)
-    args = parser.parse_args()
-
-    grouped = _load_rows(args.measurement_path)
-    for task_id, vals in sorted(grouped.items()):
-        vals = sorted(vals)
-        best = vals[0]
-        ncfg = len(vals)
-        within1 = sum(v <= best * 1.01 for v in vals)
-        within2 = sum(v <= best * 1.02 for v in vals)
-        within5 = sum(v <= best * 1.05 for v in vals)
-        hit_best = 1.0 - (1.0 - 1.0 / ncfg) ** args.budget
-        print(
-            f"{task_id} ncfg={ncfg} best_ms={best:.9f} "
-            f"within1={within1} within2={within2} within5={within5} "
-            f"random_hit_best@{args.budget}={hit_best:.4f}"
-        )
-
-
-if __name__ == "__main__":
-    main()
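The removed hardness report leans on a closed-form quantity: the probability that uniform random sampling with replacement hits the single fastest config at least once within the tuning budget. A standalone sketch of that formula (the config counts below are illustrative, not taken from the measurement table):

```python
def random_hit_best(ncfg: int, budget: int) -> float:
    # Probability that at least one of `budget` uniform draws (with
    # replacement) over `ncfg` configs lands on the single best one.
    return 1.0 - (1.0 - 1.0 / ncfg) ** budget

# A fixed budget weakens quickly as the config space grows.
for ncfg in (8, 48, 256):
    print(f"ncfg={ncfg} hit@6={random_hit_best(ncfg, 6):.4f}")
```

This is the same expression the script prints as `random_hit_best@{budget}`; a surrogate only earns its keep on tasks where this baseline probability is low.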
scripts/run_full_pipeline.sh
DELETED
@@ -1,84 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
-cd "${REPO_ROOT}"
-
-MEASUREMENTS="${REPO_ROOT}/data/autotune_measurements.csv"
-SPLITS="${REPO_ROOT}/data/benchmark_splits.json"
-OUTPUT_DIR="${REPO_ROOT}/outputs"
-mkdir -p "${OUTPUT_DIR}"
-EPISODES=20
-BUDGET=6
-BETA=2.0
-
-echo "[step] 0) checking workspace"
-python3 - <<'PY'
-import torch
-
-print(f"python={__import__('sys').executable}")
-print(f"torch={getattr(__import__('torch'), '__version__', 'missing')}")
-print(f"torch.cuda.is_available={torch.cuda.is_available()}")
-if not torch.cuda.is_available():
-    raise SystemExit("CUDA is not available. Run this pipeline on a GPU machine.")
-print("ready")
-PY
-
-echo "[step] 1) collect multi-family measurements"
-python3 scripts/collect_multifamily_measurements.py \
-    --output "${MEASUREMENTS}" \
-    --families softmax layernorm grouped_gemm \
-    --n-cols 256 512 1024 1536 2048 3072 4096 6144 8192 \
-    --m 4096 \
-    --repeats 200 \
-    --warmup 25 \
-    --seed 0 \
-    --append
-
-echo "[step] 2) build train/test splits"
-python3 scripts/build_benchmark_splits.py \
-    --measurement-path "${MEASUREMENTS}" \
-    --output "${SPLITS}" \
-    --heldout-family grouped_gemm
-
-echo "[step] 3) local smoke test"
-python3 scripts/smoke_test_client.py | tee "${OUTPUT_DIR}/smoke_test_client.json"
-
-echo "[step] 4) evaluate held-out shapes and held-out family"
-python3 scripts/eval_generalization.py \
-    --measurement-path "${MEASUREMENTS}" \
-    --splits "${SPLITS}" \
-    --episodes "${EPISODES}" \
-    --budget "${BUDGET}" \
-    --seed 2 \
-    --acquisition ucb \
-    --beta "${BETA}" \
-    --output "${OUTPUT_DIR}/generalization_eval.json"
-
-echo "[step] 5) benchmark eager vs torch.compile vs best Triton configs"
-python3 scripts/benchmark_runtime_references.py \
-    --generalization-results "${OUTPUT_DIR}/generalization_eval.json" \
-    --repeats 100 \
-    --warmup 10 \
-    --seed 123 \
-    --output "${OUTPUT_DIR}/runtime_references.json"
-
-python3 - <<'PY'
-import json
-from pathlib import Path
-
-summary = json.loads(Path("outputs/generalization_eval.json").read_text(encoding="utf-8"))
-for section_name, section in summary["results"].items():
-    print(section_name)
-    print("  random:", section["random_summary"])
-    print("  surrogate:", section["surrogate_summary"])
-
-runtime = json.loads(Path("outputs/runtime_references.json").read_text(encoding="utf-8"))
-for task_id, task in runtime["results"].items():
-    print(task_id)
-    print("  torch:", task["torch"])
-    print("  speedups:", task["speedups"])
-PY
-
-echo "[done] outputs in outputs/"
scripts/run_qwen_05b_pipeline.sh
DELETED
@@ -1,90 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
-cd "${REPO_ROOT}"
-
-MEASUREMENTS="${REPO_ROOT}/data/qwen_05b_measurements.csv"
-SPLITS="${REPO_ROOT}/data/qwen_05b_splits.json"
-OUTPUT_DIR="${REPO_ROOT}/outputs"
-mkdir -p "${OUTPUT_DIR}"
-EPISODES=20
-BUDGET=6
-BETA=2.0
-
-echo "[step] 0) checking workspace"
-python3 - <<'PY'
-import torch
-
-print(f"python={__import__('sys').executable}")
-print(f"torch={getattr(__import__('torch'), '__version__', 'missing')}")
-print(f"torch.cuda.is_available={torch.cuda.is_available()}")
-if not torch.cuda.is_available():
-    raise SystemExit("CUDA is not available. Run this pipeline on a GPU machine.")
-print("ready")
-PY
-
-echo "[step] 1) collect exact Qwen2.5-0.5B kernel measurements"
-python3 scripts/collect_qwen_05b_measurements.py \
-    --output "${MEASUREMENTS}" \
-    --repeats 200 \
-    --warmup 25 \
-    --seed 0 \
-    --append
-
-echo "[step] 2) build Qwen splits"
-python3 scripts/build_qwen_05b_splits.py --output "${SPLITS}"
-
-echo "[step] 3) local smoke test against Qwen measurement cache"
-python3 - <<'PY'
-import json
-from pathlib import Path
-from client import SoftmaxSurrogateEnvClient
-
-client = SoftmaxSurrogateEnvClient(measurement_path="data/qwen_05b_measurements.csv")
-reset_out = client.reset()
-first_config = reset_out["observation"]["tried_config_ids"][0]
-step_out = client.step({"config_id": first_config})
-summary = {"reset": reset_out, "step": step_out}
-Path("outputs/qwen_05b_smoke_test.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
-print(json.dumps(summary, indent=2))
-PY
-
-echo "[step] 4) evaluate held-out shapes and held-out family"
-python3 scripts/eval_generalization.py \
-    --measurement-path "${MEASUREMENTS}" \
-    --splits "${SPLITS}" \
-    --episodes "${EPISODES}" \
-    --budget "${BUDGET}" \
-    --seed 2 \
-    --acquisition ucb \
-    --beta "${BETA}" \
-    --output "${OUTPUT_DIR}/qwen_05b_generalization_eval.json"
-
-echo "[step] 5) benchmark eager vs torch.compile vs best Triton configs"
-python3 scripts/benchmark_qwen_05b_runtime.py \
-    --generalization-results "${OUTPUT_DIR}/qwen_05b_generalization_eval.json" \
-    --repeats 100 \
-    --warmup 10 \
-    --seed 123 \
-    --output "${OUTPUT_DIR}/qwen_05b_runtime_references.json"
-
-python3 - <<'PY'
-import json
-from pathlib import Path
-
-eval_summary = json.loads(Path("outputs/qwen_05b_generalization_eval.json").read_text(encoding="utf-8"))
-for section_name, section in eval_summary["results"].items():
-    print(section_name)
-    print("  random:", section["random_summary"])
-    print("  surrogate:", section["surrogate_summary"])
-
-runtime = json.loads(Path("outputs/qwen_05b_runtime_references.json").read_text(encoding="utf-8"))
-for task_id, task in runtime["results"].items():
-    print(task_id)
-    print("  torch:", task["torch"])
-    print("  speedups:", task["speedups"])
-PY
-
-echo "[done] Qwen outputs in outputs/"
scripts/run_random_baseline.py
DELETED
@@ -1,187 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-from dataclasses import dataclass
-import sys
-from pathlib import Path
-from typing import Any, Dict, List
-
-import numpy as np
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
-
-
-@dataclass
-class RunRecord:
-    task_id: str
-    episode: int
-    best_latency_ms: float
-    best_config: Dict[str, int]
-    final_validation_mse: float
-    final_state: Dict[str, Any]
-    final_regret: float
-    history: List[Dict[str, Any]]
-
-
-def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
-    ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
-    regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
-    auc_regrets: List[float] = []
-
-    for episode in episode_records:
-        regrets = [float(step["regret"]) for step in episode["history"]]
-        if regrets:
-            auc_regrets.append(float(sum(regrets) / len(regrets)))
-        for k in ks:
-            if len(regrets) >= k:
-                regrets_by_k[k].append(regrets[k - 1])
-
-    return {
-        "mean_regret_at": {
-            str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
-        },
-        "median_regret_at": {
-            str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals
-        },
-        "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
-        "oracle_hit_rate_final": float(
-            sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records)
-        ) if episode_records else None,
-    }
-
-
-def _pick_task_from_input(args: argparse.Namespace) -> str:
-    if args.task:
-        return args.task
-    env = SoftmaxSurrogateEnvironment(
-        measurement_path=args.measurement_path,
-        budget=args.budget,
-        seed=args.seed,
-    )
-    return env.reset()["observation"]["task_id"]
-
-
-def run_random_baseline(
-    task: str,
-    episodes: int,
-    budget: int,
-    seed: int,
-    measurement_path: str,
-) -> Dict[str, Any]:
-    rng = np.random.default_rng(seed)
-    best_overall: Dict[str, Any] = {"latency_ms": float("inf"), "config": None, "task_id": task}
-    episode_records: List[Dict[str, Any]] = []
-
-    env = SoftmaxSurrogateEnvironment(
-        measurement_path=measurement_path,
-        budget=budget,
-        seed=seed,
-    )
-
-    for episode in range(episodes):
-        env.reset(task=task, seed=seed + episode)
-        done = False
-        episode_best = float("inf")
-        episode_best_cfg: Dict[str, int] | None = None
-        history: List[Dict[str, Any]] = []
-
-        while not done:
-            unseen = [config_id for config_id in env.available_config_ids() if config_id not in env.seen_config_ids()]
-            choice_pool = unseen if unseen else env.available_config_ids()
-            config_id = int(rng.choice(choice_pool))
-            step_out = env.step({"config_id": config_id})
-            obs = step_out["observation"]
-            trial = obs["last_trial"]
-            history.append(
-                {
-                    "config_id": config_id,
-                    "latency_ms": trial["latency_ms"],
-                    "config": trial["config"],
-                    "reward": step_out["reward"],
-                    "regret": step_out["info"]["current_regret"],
-                    "validation_mse": step_out["info"]["validation_mse"],
-                }
-            )
-            if obs["best_so_far_ms"] < episode_best:
-                episode_best = obs["best_so_far_ms"]
-                best_id = env.seen_config_ids()[int(np.argmin([env.measured_latency_ms(cid) for cid in env.seen_config_ids()]))]
-                episode_best_cfg = env.config_info(best_id)
-            done = bool(step_out["done"])
-
-        if episode_best < best_overall["latency_ms"]:
-            best_overall = {
-                "latency_ms": float(episode_best),
-                "config": episode_best_cfg,
-                "task_id": task,
-            }
-
-        diagnostics = env.diagnostics()
-        episode_records.append(
-            RunRecord(
-                task_id=task,
-                episode=episode,
-                best_latency_ms=float(episode_best),
-                best_config=episode_best_cfg or {},
-                final_validation_mse=float(diagnostics["validation_mse"]),
-                final_state=env.state(),
-                final_regret=float(diagnostics["current_regret"]),
-                history=history,
-            ).__dict__
-        )
-
-    return {
-        "task": task,
-        "method": "random",
-        "episodes": episodes,
-        "budget": budget,
-        "seed": seed,
-        "oracle_best_ms": env.oracle_best()["median_ms"],
-        "best_overall": best_overall,
-        "aggregate_metrics": _aggregate_metrics(episode_records, budget),
-        "episodes_summary": episode_records,
-    }
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Random baseline for surrogate environment.")
-    parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
-    parser.add_argument("--episodes", type=int, default=20)
-    parser.add_argument("--budget", type=int, default=6)
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument(
-        "--measurement-path",
-        type=str,
-        default="data/autotune_measurements.csv",
-    )
-    parser.add_argument(
-        "--output",
-        type=Path,
-        default=Path("outputs/random_baseline.json"),
-    )
-    return parser.parse_args()
-
-
-def main() -> None:
-    args = parse_args()
-    task = _pick_task_from_input(args)
-    summary = run_random_baseline(
-        task=task,
-        episodes=args.episodes,
-        budget=args.budget,
-        seed=args.seed,
-        measurement_path=args.measurement_path,
-    )
-    args.output.parent.mkdir(parents=True, exist_ok=True)
-    with args.output.open("w", encoding="utf-8") as f:
-        json.dump(summary, f, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
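The `_aggregate_metrics` helper in the deleted baseline reduces each episode's regret trace to regret@k statistics (the regret reached after the k-th trial, averaged across episodes). A dependency-free sketch of the same reduction; the two toy regret traces are invented for illustration:

```python
def mean_regret_at(histories, ks):
    # histories: one regret trace per episode (regret after each trial).
    # For each k, average the regret reached after exactly k trials,
    # skipping episodes that ended before trial k.
    out = {}
    for k in ks:
        vals = [trace[k - 1] for trace in histories if len(trace) >= k]
        if vals:
            out[k] = sum(vals) / len(vals)
    return out

histories = [[0.4, 0.1, 0.0], [0.3, 0.3, 0.1]]
print(mean_regret_at(histories, ks=(1, 3)))
```

The removed script additionally tracks the median, the mean regret AUC over the whole trace, and the fraction of episodes whose final regret is exactly zero (its "oracle hit rate").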
scripts/run_surrogate_baseline.py
DELETED
@@ -1,217 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import sys
-from pathlib import Path
-from typing import Any, Dict, List
-
-import numpy as np
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
-
-
-def _choose_surrogate_action(
-    env: SoftmaxSurrogateEnvironment,
-    acquisition: str,
-    beta: float,
-    xi: float,
-) -> int:
-    seen = set(env.seen_config_ids())
-    best_config_id = -1
-    best_score = float("-inf")
-
-    for config_id in env.available_config_ids():
-        if config_id in seen and len(seen) < len(env.available_config_ids()):
-            continue
-        score = env.acquisition_score(config_id, strategy=acquisition, beta=beta, xi=xi)
-        if score > best_score:
-            best_score = score
-            best_config_id = config_id
-
-    if best_config_id < 0:
-        raise RuntimeError("Failed to choose a surrogate action.")
-    return best_config_id
-
-
-def _aggregate_metrics(episode_records: List[Dict[str, Any]], budget: int) -> Dict[str, Any]:
-    ks = sorted(set(k for k in (1, 3, 5, budget) if k <= budget))
-    regrets_by_k: Dict[int, List[float]] = {k: [] for k in ks}
-    auc_regrets: List[float] = []
-
-    for episode in episode_records:
-        regrets = [float(step["regret"]) for step in episode["history"]]
-        if regrets:
-            auc_regrets.append(float(sum(regrets) / len(regrets)))
-        for k in ks:
-            if len(regrets) >= k:
-                regrets_by_k[k].append(regrets[k - 1])
-
-    return {
-        "mean_regret_at": {
-            str(k): float(sum(vals) / len(vals)) for k, vals in regrets_by_k.items() if vals
-        },
-        "median_regret_at": {
-            str(k): float(np.median(np.asarray(vals, dtype=np.float32))) for k, vals in regrets_by_k.items() if vals
-        },
-        "mean_auc_regret": float(sum(auc_regrets) / len(auc_regrets)) if auc_regrets else None,
-        "oracle_hit_rate_final": float(
-            sum(1 for episode in episode_records if float(episode["final_regret"]) == 0.0) / len(episode_records)
-        ) if episode_records else None,
-    }
-
-
-def run_surrogate_baseline(
-    task: str,
-    episodes: int,
-    budget: int,
-    seed: int,
-    measurement_path: str,
-    train_task_ids: List[str] | None = None,
-    acquisition: str = "ucb",
-    beta: float = 1.5,
-    xi: float = 0.0,
-) -> Dict[str, Any]:
-    env = SoftmaxSurrogateEnvironment(
-        measurement_path=measurement_path,
-        budget=budget,
-        seed=seed,
-        train_task_ids=train_task_ids,
-    )
-
-    best_overall = {"latency_ms": float("inf"), "config": None, "task_id": task}
-    episode_records: List[Dict[str, Any]] = []
-
-    for episode in range(episodes):
-        env.reset(task=task, seed=seed + episode)
-        done = False
-        episode_best = float("inf")
-        episode_best_cfg: Dict[str, int] | None = None
-        history: List[Dict[str, Any]] = []
-
-        while not done:
-            config_id = _choose_surrogate_action(env, acquisition=acquisition, beta=beta, xi=xi)
-            out = env.step({"config_id": config_id})
-            obs = out["observation"]
-            trial = obs["last_trial"]
-            history.append(
-                {
-                    "config_id": config_id,
-                    "latency_ms": trial["latency_ms"],
-                    "config": trial["config"],
-                    "reward": out["reward"],
-                    "regret": out["info"]["current_regret"],
-                    "validation_mse": out["info"]["validation_mse"],
-                }
-            )
-            if obs["best_so_far_ms"] < episode_best:
-                episode_best = obs["best_so_far_ms"]
-                best_seen = min(env.seen_config_ids(), key=env.measured_latency_ms)
-                episode_best_cfg = env.config_info(best_seen)
-            done = bool(out["done"])
-
-        if episode_best < best_overall["latency_ms"]:
-            best_overall = {
-                "latency_ms": float(episode_best),
-                "config": episode_best_cfg,
-                "task_id": task,
-            }
-
-        diagnostics = env.diagnostics()
-        episode_records.append(
-            {
-                "task_id": task,
-                "episode": episode,
-                "best_latency_ms": episode_best,
-                "best_config": episode_best_cfg or {},
-                "final_validation_mse": diagnostics["validation_mse"],
-                "final_regret": diagnostics["current_regret"],
-                "history": history,
-            }
-        )
-
-    return {
-        "task": task,
-        "method": "surrogate",
-        "episodes": episodes,
-        "budget": budget,
-        "seed": seed,
-        "train_task_ids": list(train_task_ids or []),
-        "acquisition": acquisition,
-        "beta": beta,
-        "xi": xi,
-        "oracle_best_ms": env.oracle_best()["median_ms"],
-        "best_overall": best_overall,
-        "aggregate_metrics": _aggregate_metrics(episode_records, budget),
-        "episodes_summary": episode_records,
-    }
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Surrogate-guided baseline.")
-    parser.add_argument("--task", default=None, help="Task ID (e.g., softmax_m4096_n2048)")
-    parser.add_argument("--episodes", type=int, default=20)
-    parser.add_argument("--budget", type=int, default=6)
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument(
-        "--acquisition",
-        type=str,
-        choices=("mean", "ucb", "ei"),
-        default="ucb",
-        help="Candidate selection mode: mean, ucb, or ei.",
-    )
-    parser.add_argument("--beta", type=float, default=1.5, help="UCB exploration strength.")
-    parser.add_argument("--xi", type=float, default=0.0, help="Expected-improvement margin.")
-    parser.add_argument(
-        "--train-tasks-file",
-        type=Path,
-        default=None,
-        help="Optional JSON file containing a list of train task ids.",
-    )
-    parser.add_argument(
-        "--measurement-path",
-        type=str,
-        default="data/autotune_measurements.csv",
-    )
-    parser.add_argument(
-        "--output",
-        type=Path,
-        default=Path("outputs/surrogate_baseline.json"),
-    )
-    return parser.parse_args()
-
-
-def main() -> None:
-    args = parse_args()
-    if not args.task:
-        env = SoftmaxSurrogateEnvironment(measurement_path=args.measurement_path, budget=args.budget, seed=args.seed)
-        args.task = env.reset()["observation"]["task_id"]
-
-    train_task_ids = None
-    if args.train_tasks_file is not None:
-        train_task_ids = json.loads(args.train_tasks_file.read_text(encoding="utf-8"))
-
-    summary = run_surrogate_baseline(
-        task=args.task,
-        episodes=args.episodes,
-        budget=args.budget,
-        seed=args.seed,
-        measurement_path=args.measurement_path,
-        train_task_ids=train_task_ids,
-        acquisition=args.acquisition,
-        beta=args.beta,
-        xi=args.xi,
-    )
-    args.output.parent.mkdir(parents=True, exist_ok=True)
-    with args.output.open("w", encoding="utf-8") as f:
-        json.dump(summary, f, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
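The deleted `_choose_surrogate_action` is an argmax over acquisition scores that skips already-measured configs until the whole pool has been tried. A minimal self-contained version of the UCB case, with made-up posterior means and uncertainties standing in for the environment's surrogate:

```python
def choose_ucb(means, stds, seen, beta=1.5):
    # Prefer unseen configs; once everything has been measured,
    # fall back to scoring the full pool (as the removed helper does).
    pool = [i for i in range(len(means)) if i not in seen] or list(range(len(means)))
    # UCB: optimistic score = predicted reward + beta * uncertainty.
    return max(pool, key=lambda i: means[i] + beta * stds[i])

means = [0.2, 0.5, 0.4]  # hypothetical surrogate means per config
stds = [0.0, 0.0, 0.2]   # hypothetical predictive stds
print(choose_ucb(means, stds, seen={1}))  # config 2: 0.4 + 1.5*0.2 beats 0.2
```

A larger `beta` (the script's `--beta`, set to 2.0 in the pipelines) weights the uncertainty term more heavily, trading exploitation for exploration.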
scripts/smoke_test_client.py
DELETED
@@ -1,28 +0,0 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import json
-import sys
-from pathlib import Path
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from client import SoftmaxSurrogateEnvClient
-
-
-def main() -> None:
-    client = SoftmaxSurrogateEnvClient()
-    reset_out = client.reset()
-    step_out = client.step({"config_id": 0})
-    summary = {"reset": reset_out, "step": step_out}
-    out = Path("outputs/smoke_test_client.json")
-    out.parent.mkdir(parents=True, exist_ok=True)
-    with out.open("w", encoding="utf-8") as f:
-        json.dump(summary, f, indent=2)
-    print(json.dumps(summary, indent=2))
-
-
-if __name__ == "__main__":
-    main()
server/__init__.py
DELETED
@@ -1 +0,0 @@
-# Package marker for importing server modules directly from workspace scripts.
server/app.py
DELETED
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-import argparse
-import sys
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-ROOT = Path(__file__).resolve().parents[1]
-if str(ROOT) not in sys.path:
-    sys.path.append(str(ROOT))
-
-from fastapi import FastAPI
-from fastapi import HTTPException
-from pydantic import BaseModel
-
-from models import ResetResult, StepResult
-from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment
-
-
-app = FastAPI(title="Autotune Benchmark OpenEnv Server")
-env = SoftmaxSurrogateEnvironment()
-
-
-class ResetRequest(BaseModel):
-    task: Optional[str] = None
-    seed: Optional[int] = None
-
-
-class StepRequest(BaseModel):
-    config_id: Optional[int] = None
-    x: Optional[list[float]] = None
-
-
-@app.get("/health")
-def health() -> Dict[str, str]:
-    return {"ok": "true"}
-
-
-@app.post("/reset")
-def reset(payload: ResetRequest) -> Dict[str, Any]:
-    result = env.reset(task=payload.task, seed=payload.seed)
-    return result
-
-
-@app.post("/step")
-def step(payload: StepRequest) -> Dict[str, Any]:
-    if payload.config_id is not None:
-        result = env.step({"config_id": payload.config_id})
-        return result
-    if payload.x is not None:
-        result = env.step({"x": payload.x})
-        return result
-    raise HTTPException(status_code=400, detail="Missing config_id.")
-    return result
-
-
-@app.get("/state")
-def state() -> Dict[str, Any]:
-    return env.state()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run softmax surrogate environment server.")
-    parser.add_argument("--host", default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=8000)
-    args = parser.parse_args()
-    try:
-        import uvicorn
-
-        uvicorn.run("app:app", host=args.host, port=args.port, reload=False)
-    except Exception as err:  # pragma: no cover
-        raise RuntimeError("uvicorn not available") from err
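The deleted `/step` endpoint dispatches on whichever action field is present: a discrete `config_id` takes precedence, then a continuous vector `x`, and a payload with neither is rejected with a 400. That branching can be sketched as a plain function (the tuple return value is illustrative only; the real endpoint forwards the action to `env.step`):

```python
def dispatch_step(payload: dict):
    # Mirror the endpoint's precedence: config_id first, then x, else reject.
    if payload.get("config_id") is not None:
        return ("config_id", payload["config_id"])
    if payload.get("x") is not None:
        return ("x", payload["x"])
    raise ValueError("Missing config_id.")

print(dispatch_step({"config_id": 3}))
print(dispatch_step({"x": [0.1, -0.2]}))
```

Note that when both fields are supplied, `config_id` silently wins, matching the order of the `if` checks in the removed handler.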
server/softmax_surrogate_environment.py
DELETED
|
@@ -1,570 +0,0 @@
from __future__ import annotations

import csv
import json
import math
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from scripts.collect_measurements import BLOCK_SIZES, NUM_STAGES, NUM_WARPS


DEFAULT_MEASUREMENT_PATH = "data/autotune_measurements.csv"
DEFAULT_BUDGET = 6
INITIAL_DATASET_SIZE = 2
DUPLICATE_PENALTY = -1e-4
FAMILY_INDEX = {
    "softmax": 0,
    "layernorm": 1,
    "grouped_gemm": 2,
    "rmsnorm": 3,
    "gemm": 4,
}


@dataclass(frozen=True)
class Measurement:
    family_group: str
    family: str
    task_id: str
    m: int
    n: int
    k: int
    config_id: int
    block_size: int
    num_warps: int
    num_stages: int
    shape_json: str
    config_json: str
    median_ms: float
    effective_gbps: float
    score: float
    validation_error: float


def _normalize_discrete(values: Sequence[int], value: int) -> float:
    idx = list(values).index(int(value))
    if len(values) == 1:
        return 0.0
    return 2.0 * (idx / (len(values) - 1)) - 1.0


def _de_norm(value: float, values: Sequence[int]) -> int:
    # Assumed inverse of _normalize_discrete (the original module referenced
    # _de_norm without defining it): map a [-1, 1] coordinate back to the
    # nearest discrete choice.
    if len(values) == 1:
        return int(values[0])
    idx = round((value + 1.0) / 2.0 * (len(values) - 1))
    return int(values[max(0, min(len(values) - 1, idx))])


class SoftmaxSurrogateEnvironment:
    """
    Generic discrete-action autotuning environment backed by measured GPU data.

    The class name is kept for compatibility with the existing local server and
    baseline scripts, but the task space is now multi-family.
    """

    def __init__(
        self,
        measurement_path: str = DEFAULT_MEASUREMENT_PATH,
        budget: int = DEFAULT_BUDGET,
        seed: int = 0,
        initial_samples: int = INITIAL_DATASET_SIZE,
        train_task_ids: Optional[Sequence[str]] = None,
    ) -> None:
        self.measurement_path = Path(measurement_path)
        self.budget = int(budget)
        self.seed = int(seed)
        self.initial_samples = max(1, int(initial_samples))
        self.train_task_ids = set(train_task_ids or [])

        self._measurements = self._load_measurements()
        self._task_ids = sorted(self._measurements.keys())
        if not self._task_ids:
            raise RuntimeError(
                "No measurement data found. Run the measurement collectors first."
            )

        self._rng = random.Random(self.seed)
        self._episode_counter = 0

        self._task_id: Optional[str] = None
        self._family: Optional[str] = None
        self._episode_id: Optional[str] = None
        self._task_rows: List[Measurement] = []
        self._prior_rows: List[Measurement] = []
        self._config_by_id: Dict[int, Measurement] = {}
        self._observed_ids: List[int] = []
        self._observed_id_set = set()
        self._observed_rows: List[Measurement] = []
        self._observed_latencies: List[float] = []
        self._steps_taken = 0
        self._steps_remaining = 0
        self._best_latency_ms = float("inf")
        self._best_config_id: Optional[int] = None
        self._validation_mse = float("inf")

        self._surrogate_version = 0
        self._surrogate_fitted_version = -1
        self._surrogate_x: Optional[np.ndarray] = None
        self._surrogate_y: Optional[np.ndarray] = None
        self._surrogate_alpha: Optional[np.ndarray] = None
        self._surrogate_k: Optional[np.ndarray] = None
        self._surrogate_length_scale: float = 0.5

    def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> Dict[str, Any]:
        if seed is not None:
            self._rng = random.Random(int(seed))

        if task is None:
            task = self._rng.choice(self._task_ids)
        if task not in self._measurements:
            raise ValueError(f"Unknown task: {task}")

        rows = self._measurements[task]
        self._task_id = task
        self._family = rows[0].family
        self._task_rows = rows
        self._config_by_id = {row.config_id: row for row in rows}
        self._prior_rows = self._build_prior_rows(task)
        self._observed_ids = []
        self._observed_id_set = set()
        self._observed_rows = []
        self._observed_latencies = []
        self._steps_taken = 0
        self._steps_remaining = self.budget
        self._best_latency_ms = float("inf")
        self._best_config_id = None
        self._episode_counter += 1
        self._episode_id = f"{task}:{self.seed}:{self._episode_counter}"

        sample_count = min(self.initial_samples, len(rows))
        for config_id in self._rng.sample(list(self._config_by_id.keys()), k=sample_count):
            self._observe_config(config_id)

        self._invalidate_surrogate()
        self._validation_mse = self._compute_validation_mse()

        return self._format_step_output(
            observation=self._observation_payload(kind="reset"),
            reward=0.0,
            done=False,
            info=self.diagnostics(),
        )

|
| 158 |
-
if self._task_id is None:
|
| 159 |
-
raise RuntimeError("Call reset() before step().")
|
| 160 |
-
if self._steps_remaining <= 0:
|
| 161 |
-
return self._format_step_output(
|
| 162 |
-
observation=self._observation_payload(kind="done"),
|
| 163 |
-
reward=0.0,
|
| 164 |
-
done=True,
|
| 165 |
-
info=self.diagnostics(),
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
config_id = self._extract_config_id(action)
|
| 169 |
-
row = self._row_for_id(config_id)
|
| 170 |
-
prev_best = self._best_latency_ms
|
| 171 |
-
duplicate = config_id in self._observed_id_set
|
| 172 |
-
|
| 173 |
-
if not duplicate:
|
| 174 |
-
self._observe_config(config_id)
|
| 175 |
-
self._surrogate_version += 1
|
| 176 |
-
|
| 177 |
-
self._steps_taken += 1
|
| 178 |
-
self._steps_remaining -= 1
|
| 179 |
-
self._validation_mse = self._compute_validation_mse()
|
| 180 |
-
|
| 181 |
-
reward = DUPLICATE_PENALTY if duplicate else max(0.0, math.log(prev_best) - math.log(self._best_latency_ms))
|
| 182 |
-
observation = self._observation_payload(
|
| 183 |
-
kind="step",
|
| 184 |
-
last_trial={
|
| 185 |
-
"config_id": config_id,
|
| 186 |
-
"config": self.config_info(config_id),
|
| 187 |
-
"latency_ms": row.median_ms,
|
| 188 |
-
"score": row.score,
|
| 189 |
-
"duplicate": duplicate,
|
| 190 |
-
},
|
| 191 |
-
)
|
| 192 |
-
return self._format_step_output(
|
| 193 |
-
observation=observation,
|
| 194 |
-
reward=reward,
|
| 195 |
-
done=self._steps_remaining <= 0,
|
| 196 |
-
info=self.diagnostics(),
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def state(self) -> Dict[str, Any]:
|
| 200 |
-
if self._task_id is None:
|
| 201 |
-
return {"status": "uninitialized"}
|
| 202 |
-
return {
|
| 203 |
-
"episode_id": self._episode_id,
|
| 204 |
-
"step_count": self._steps_taken,
|
| 205 |
-
"task_id": self._task_id,
|
| 206 |
-
"family": self._family,
|
| 207 |
-
"tried_config_ids": list(self._observed_ids),
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
def diagnostics(self) -> Dict[str, Any]:
|
| 211 |
-
if self._task_id is None:
|
| 212 |
-
return {"status": "uninitialized"}
|
| 213 |
-
oracle_best_ms = self.oracle_best()["median_ms"]
|
| 214 |
-
regret = self._best_latency_ms / oracle_best_ms - 1.0
|
| 215 |
-
return {
|
| 216 |
-
"validation_mse": self._validation_mse,
|
| 217 |
-
"best_so_far_ms": self._best_latency_ms,
|
| 218 |
-
"oracle_best_ms": oracle_best_ms,
|
| 219 |
-
"current_regret": regret,
|
| 220 |
-
"observed_count": len(self._observed_ids),
|
| 221 |
-
"prior_count": len(self._prior_rows),
|
| 222 |
-
}
|
| 223 |
-
|
| 224 |
-
def available_tasks(self) -> List[str]:
|
| 225 |
-
return list(self._task_ids)
|
| 226 |
-
|
| 227 |
-
def available_config_ids(self) -> List[int]:
|
| 228 |
-
if self._task_id is None:
|
| 229 |
-
raise RuntimeError("Call reset() before accessing config ids.")
|
| 230 |
-
return sorted(self._config_by_id.keys())
|
| 231 |
-
|
| 232 |
-
def available_configs(self) -> List[Dict[str, Any]]:
|
| 233 |
-
return [self.config_info(config_id) for config_id in self.available_config_ids()]
|
| 234 |
-
|
| 235 |
-
def config_info(self, config_id: int) -> Dict[str, Any]:
|
| 236 |
-
row = self._row_for_id(config_id)
|
| 237 |
-
return {
|
| 238 |
-
"config_id": int(config_id),
|
| 239 |
-
"family": row.family,
|
| 240 |
-
"task_id": row.task_id,
|
| 241 |
-
"block_size": row.block_size,
|
| 242 |
-
"num_warps": row.num_warps,
|
| 243 |
-
"num_stages": row.num_stages,
|
| 244 |
-
}
|
| 245 |
-
|
| 246 |
-
def measured_latency_ms(self, config_id: int) -> float:
|
| 247 |
-
return self._row_for_id(config_id).median_ms
|
| 248 |
-
|
| 249 |
-
def oracle_best(self) -> Dict[str, Any]:
|
| 250 |
-
if self._task_id is None:
|
| 251 |
-
raise RuntimeError("Call reset() before querying oracle_best().")
|
| 252 |
-
best = min(self._task_rows, key=lambda row: row.median_ms)
|
| 253 |
-
return {
|
| 254 |
-
"config_id": best.config_id,
|
| 255 |
-
"family": best.family,
|
| 256 |
-
"task_id": best.task_id,
|
| 257 |
-
"block_size": best.block_size,
|
| 258 |
-
"num_warps": best.num_warps,
|
| 259 |
-
"num_stages": best.num_stages,
|
| 260 |
-
"median_ms": best.median_ms,
|
| 261 |
-
"score": best.score,
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
def predict_score(self, config_id: int) -> float:
|
| 265 |
-
return float(self._predict_with_uncertainty(config_id)[0])
|
| 266 |
-
|
| 267 |
-
def acquisition_score(
|
| 268 |
-
self,
|
| 269 |
-
config_id: int,
|
| 270 |
-
strategy: str = "ucb",
|
| 271 |
-
beta: float = 1.0,
|
| 272 |
-
xi: float = 0.0,
|
| 273 |
-
) -> float:
|
| 274 |
-
mean, sigma = self._predict_with_uncertainty(config_id)
|
| 275 |
-
if strategy == "mean":
|
| 276 |
-
return float(mean)
|
| 277 |
-
if strategy == "ucb":
|
| 278 |
-
return float(mean + float(beta) * sigma)
|
| 279 |
-
if strategy == "ei":
|
| 280 |
-
best_observed = max(row.score for row in self._observed_rows) if self._observed_rows else mean
|
| 281 |
-
delta = mean - best_observed - float(xi)
|
| 282 |
-
if sigma <= 0.0:
|
| 283 |
-
return float(max(delta, 0.0))
|
| 284 |
-
z = delta / sigma
|
| 285 |
-
return float(max(delta * _normal_cdf(z) + sigma * _normal_pdf(z), 0.0))
|
| 286 |
-
raise ValueError(f"Unknown acquisition strategy: {strategy}")
|
| 287 |
-
|
| 288 |
-
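The `ei` branch of `acquisition_score` can be read in isolation. A minimal standalone sketch of the same expected-improvement formula (the helper names here drop the module's leading underscores, but the math matches `_normal_pdf`/`_normal_cdf` and the `ei` branch above):

```python
import math


def normal_pdf(z: float) -> float:
    # Standard normal density.
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z: float) -> float:
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def expected_improvement(mean: float, sigma: float, best_observed: float, xi: float = 0.0) -> float:
    # EI = max(delta * Phi(z) + sigma * phi(z), 0) with delta = mean - best - xi.
    delta = mean - best_observed - xi
    if sigma <= 0.0:
        # No uncertainty: improvement is just the clipped predicted gain.
        return max(delta, 0.0)
    z = delta / sigma
    return max(delta * normal_cdf(z) + sigma * normal_pdf(z), 0.0)
```

At `mean == best_observed` the formula reduces to `sigma * phi(0)`, so EI grows with surrogate uncertainty even when no improvement is predicted, which is what makes it an exploration bonus.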
    def seen_config_ids(self) -> List[int]:
        return list(self._observed_ids)

    def _build_prior_rows(self, current_task: str) -> List[Measurement]:
        if not self.train_task_ids:
            return []
        prior_rows: List[Measurement] = []
        for task_id in sorted(self.train_task_ids):
            if task_id == current_task or task_id not in self._measurements:
                continue
            prior_rows.extend(self._measurements[task_id])
        return prior_rows

    def _predict_with_uncertainty(self, config_id: int) -> Tuple[float, float]:
        if not self._observed_rows and not self._prior_rows:
            raise RuntimeError("No surrogate data available.")
        self._fit_surrogate()
        if self._surrogate_x is None or self._surrogate_y is None:
            raise RuntimeError("Surrogate model unavailable.")
        if self._surrogate_x.shape[0] == 1:
            return float(self._surrogate_y[0]), 0.0

        cfg = _config_to_vector(self._row_for_id(config_id)).reshape(1, -1)
        if self._surrogate_k is None or self._surrogate_alpha is None:
            raise RuntimeError("Surrogate model unavailable.")
        k = _rbf_kernel(self._surrogate_x, cfg, self._surrogate_length_scale).reshape(-1)
        pred = float(k @ self._surrogate_alpha)
        solve = np.linalg.solve(self._surrogate_k, k)
        var = max(0.0, float(1.0 - k @ solve))
        return pred, float(math.sqrt(max(var, 1e-12)))

    def _fit_surrogate(self) -> None:
        if self._surrogate_fitted_version == self._surrogate_version:
            return
        rows = self._prior_rows + self._observed_rows
        if not rows:
            self._surrogate_x = None
            self._surrogate_y = None
            self._surrogate_alpha = None
            self._surrogate_k = None
            self._surrogate_fitted_version = self._surrogate_version
            return

        self._surrogate_x = np.array([_config_to_vector(row) for row in rows], dtype=np.float32)
        self._surrogate_y = np.array([row.score for row in rows], dtype=np.float32)
        if self._surrogate_x.shape[0] == 1:
            self._surrogate_alpha = self._surrogate_y.copy()
            self._surrogate_k = None
            self._surrogate_fitted_version = self._surrogate_version
            return

        pairwise = _pairwise_sq_dists(self._surrogate_x)
        triu = pairwise[np.triu_indices(self._surrogate_x.shape[0], k=1)]
        med_dist = float(np.median(np.sqrt(triu))) if triu.size else 0.5
        self._surrogate_length_scale = max(0.15, med_dist)
        k = _rbf_kernel(self._surrogate_x, self._surrogate_x, self._surrogate_length_scale)
        k[np.diag_indices_from(k)] += 1e-3
        self._surrogate_k = k
        self._surrogate_alpha = np.linalg.solve(k, self._surrogate_y)
        self._surrogate_fitted_version = self._surrogate_version

    def _compute_validation_mse(self) -> float:
        if not self._task_rows:
            return float("inf")
        preds = np.array(
            [self._predict_with_uncertainty(config_id)[0] for config_id in self.available_config_ids()],
            dtype=np.float32,
        )
        target = np.array([self._row_for_id(config_id).score for config_id in self.available_config_ids()], dtype=np.float32)
        return float(np.mean((preds - target) ** 2))

    def _observe_config(self, config_id: int) -> None:
        row = self._row_for_id(config_id)
        self._observed_ids.append(config_id)
        self._observed_id_set.add(config_id)
        self._observed_rows.append(row)
        self._observed_latencies.append(row.median_ms)
        if row.median_ms < self._best_latency_ms:
            self._best_latency_ms = row.median_ms
            self._best_config_id = config_id

    def _observation_payload(
        self,
        kind: str,
        last_trial: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        payload = {
            "type": kind,
            "task_id": self._task_id,
            "family": self._family,
            "M": self._task_rows[0].m if self._task_rows else None,
            "N": self._task_rows[0].n if self._task_rows else None,
            "dtype": "fp16",
            "tried_config_ids": list(self._observed_ids),
            "tried_latencies_ms": list(self._observed_latencies),
            "best_so_far_ms": self._best_latency_ms,
            "steps_remaining": self._steps_remaining,
        }
        if last_trial is not None:
            payload["last_trial"] = last_trial
        return payload

    def _extract_config_id(self, action: Any) -> int:
        if isinstance(action, (str, bytes)):
            action = json.loads(action)
        if isinstance(action, dict):
            if "config_id" in action:
                return int(action["config_id"])
            if "x" in action:
                normalized = self._extract_legacy_action(action["x"])
                return self._map_legacy_action_to_config(normalized)
        if isinstance(action, (int, np.integer)):
            return int(action)
        raise TypeError("Action must be an int config_id or dict with config_id.")

    def _extract_legacy_action(self, action: Any) -> List[float]:
        arr = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
        if arr.shape != (3,):
            raise ValueError("Legacy action vector must have 3 values.")
        return arr.tolist()

    def _map_legacy_action_to_config(self, action: Sequence[float]) -> int:
        base = (
            _de_norm(float(action[0]), BLOCK_SIZES),
            _de_norm(float(action[1]), NUM_WARPS),
            _de_norm(float(action[2]), NUM_STAGES),
        )
        best_id = min(
            self.available_config_ids(),
            key=lambda config_id: (self._row_for_id(config_id).block_size - base[0]) ** 2
            + (self._row_for_id(config_id).num_warps - base[1]) ** 2
            + (self._row_for_id(config_id).num_stages - base[2]) ** 2,
        )
        return int(best_id)

    def _row_for_id(self, config_id: int) -> Measurement:
        if config_id not in self._config_by_id:
            raise ValueError(f"Unknown config_id={config_id}")
        return self._config_by_id[int(config_id)]

    def _invalidate_surrogate(self) -> None:
        self._surrogate_version += 1
        self._surrogate_fitted_version = -1
        self._surrogate_x = None
        self._surrogate_y = None
        self._surrogate_alpha = None
        self._surrogate_k = None

    def _format_step_output(
        self,
        observation: Dict[str, Any],
        reward: float,
        done: bool,
        info: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        return {
            "observation": observation,
            "reward": float(reward),
            "done": bool(done),
            "state": self.state(),
            "info": info or {},
        }

    def _load_measurements(self) -> Dict[str, List[Measurement]]:
        if not self.measurement_path.exists():
            raise FileNotFoundError(
                f"Missing measurement file at {self.measurement_path}. "
                "Run the measurement collectors first."
            )

        grouped: Dict[str, List[Measurement]] = {}
        with self.measurement_path.open("r", newline="", encoding="utf-8") as handle:
            reader = csv.DictReader(handle)
            fallback_config_ids: Dict[str, int] = {}
            for row in reader:
                family = row.get("family", "softmax")
                family_group = row.get("family_group", "A" if family in {"softmax", "layernorm"} else "B")
                task_id = row["task_id"]
                block_size = int(row["block_size"])
                num_warps = int(row["num_warps"])
                num_stages = int(row["num_stages"])
                config_id_raw = row.get("config_id")
                if config_id_raw in (None, ""):
                    key = f"{task_id}|{block_size}|{num_warps}|{num_stages}"
                    if key not in fallback_config_ids:
                        fallback_config_ids[key] = len([k for k in fallback_config_ids if k.startswith(f"{task_id}|")])
                    config_id = fallback_config_ids[key]
                else:
                    config_id = int(config_id_raw)

                measurement = Measurement(
                    family_group=family_group,
                    family=family,
                    task_id=task_id,
                    m=int(row["m"]),
                    n=int(row["n"]),
                    k=int(row.get("k", 0) or 0),
                    config_id=config_id,
                    block_size=block_size,
                    num_warps=num_warps,
                    num_stages=num_stages,
                    shape_json=row.get("shape_json", "{}"),
                    config_json=row.get("config_json", "{}"),
                    median_ms=float(row["median_ms"]),
                    effective_gbps=float(row["effective_gbps"]),
                    score=float(row["score"]),
                    validation_error=float(row["validation_error"]),
                )
                grouped.setdefault(task_id, []).append(measurement)

        for task_id in grouped:
            grouped[task_id].sort(key=lambda row: row.config_id)
        return grouped


def _config_to_vector(row: Measurement) -> np.ndarray:
    family_vec = np.zeros(len(FAMILY_INDEX), dtype=np.float32)
    if row.family in FAMILY_INDEX:
        family_vec[FAMILY_INDEX[row.family]] = 1.0
    shape_fields = {}
    try:
        shape_fields = json.loads(row.shape_json) if row.shape_json else {}
    except Exception:
        shape_fields = {}
    mode_val = _shape_scalar(shape_fields.get("mode_id"))
    role_val = _shape_scalar(shape_fields.get("role_id"))
    seq_val = _shape_scalar(shape_fields.get("seq_len"))
    ctx_val = _shape_scalar(shape_fields.get("ctx_len"))
    group_val = _shape_scalar(shape_fields.get("group_count"))
    numeric = np.array(
        [
            math.log2(max(row.m, 1)) / 16.0,
            math.log2(max(row.n, 1)) / 16.0,
            math.log2(max(row.k, 1)) / 16.0 if row.k > 0 else 0.0,
            math.log2(max(seq_val, 1.0)) / 16.0 if seq_val > 0 else 0.0,
            math.log2(max(ctx_val, 1.0)) / 16.0 if ctx_val > 0 else 0.0,
            math.log2(max(group_val, 1.0)) / 8.0 if group_val > 0 else 0.0,
            mode_val / 8.0,
            role_val / 16.0,
            _normalize_discrete(BLOCK_SIZES, row.block_size),
            _normalize_discrete(NUM_WARPS, row.num_warps),
            _normalize_discrete(NUM_STAGES, row.num_stages),
        ],
        dtype=np.float32,
    )
    return np.concatenate([family_vec, numeric], axis=0)


def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
    diff = X[:, None, :] - X[None, :, :]
    return np.sum(diff * diff, axis=2)


def _rbf_kernel(X: np.ndarray, Y: np.ndarray, length_scale: float) -> np.ndarray:
    sigma2 = float(length_scale * length_scale)
    if sigma2 <= 0:
        sigma2 = 1e-6
    xy = X @ Y.T
    x2 = np.sum(X * X, axis=1)[:, None]
    y2 = np.sum(Y * Y, axis=1)[None, :]
    d2 = np.maximum(x2 - 2.0 * xy + y2, 0.0)
    return np.exp(-0.5 * d2 / sigma2).astype(np.float32)


def _normal_pdf(z: float) -> float:
    inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi)
    return float(inv_sqrt_2pi * math.exp(-0.5 * z * z))


def _normal_cdf(z: float) -> float:
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def _shape_scalar(value: Any) -> float:
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0
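A minimal, self-contained sketch of the kernel interpolation that `_fit_surrogate` and `_predict_with_uncertainty` implement above: fit `alpha = K⁻¹y` on observed points (with the same 1e-3 diagonal jitter), predict with `k(x*, X) @ alpha`, and use `1 - k K⁻¹ k` as a noise-free variance proxy. The 1-D inputs and targets here are toy values, not the environment's real feature vectors.

```python
import math

import numpy as np


def rbf(X, Y, length_scale):
    # Squared-exponential kernel on 1-D inputs.
    d2 = (X[:, None] - Y[None, :]) ** 2
    return np.exp(-0.5 * d2 / (length_scale ** 2))


# Toy observed (input, score) pairs.
X = np.array([0.0, 0.5, 1.0])
y = np.array([1.0, 0.2, 0.9])

K = rbf(X, X, 0.5) + 1e-3 * np.eye(len(X))  # jitter, as in _fit_surrogate
alpha = np.linalg.solve(K, y)


def predict(x_star):
    # Posterior-mean prediction plus a simple uncertainty proxy.
    k = rbf(X, np.array([x_star]), 0.5).reshape(-1)
    mean = float(k @ alpha)
    var = max(0.0, 1.0 - float(k @ np.linalg.solve(K, k)))
    return mean, math.sqrt(max(var, 1e-12))
```

Near an observed point the predicted mean reproduces its target and the uncertainty collapses toward zero; far from all data the mean decays toward zero and the uncertainty approaches one, which is what drives the UCB/EI acquisition strategies toward unexplored configs.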