Spaces:
Configuration error
Configuration error
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, Tuple | |
| import numpy as np | |
| import torch | |
# Make the repository root importable so the `scripts.*` modules below resolve
# when this file is run directly (not as an installed package). This must run
# before the project imports that follow.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from scripts.collect_qwen_05b_measurements import EPS, benchmark_qwen_task
from scripts.qwen_05b_spec import QwenKernelTask, qwen_05b_tasks
# Index every known kernel task by its id for O(1) lookup during benchmarking.
TASK_BY_ID = {task.task_id: task for task in qwen_05b_tasks()}
def _bench_callable(fn, args: Tuple[Any, ...], repeats: int, warmup: int) -> float:
    """Return the median latency, in milliseconds, of one ``fn(*args)`` call.

    Runs at least one untimed warmup invocation, then times at least one
    invocation with CUDA events (device-side timing), synchronizing before
    each sample so only the measured call is in flight.
    """
    warmup_iters = max(1, warmup)
    timed_iters = max(1, repeats)
    for _ in range(warmup_iters):
        fn(*args)
    torch.cuda.synchronize()
    tick = torch.cuda.Event(enable_timing=True)
    tock = torch.cuda.Event(enable_timing=True)
    samples = []
    for _ in range(timed_iters):
        # Drain outstanding GPU work so each sample covers only this call.
        torch.cuda.synchronize()
        tick.record()
        fn(*args)
        tock.record()
        tock.synchronize()
        samples.append(tick.elapsed_time(tock))
    return float(np.median(np.asarray(samples, dtype=np.float32)))
def _build_qwen_callable(task: QwenKernelTask, seed: int):
    """Build an eager-PyTorch reference callable plus CUDA inputs for *task*.

    Returns ``(fn, args)`` such that ``fn(*args)`` performs the kernel
    family's work (softmax / rmsnorm / gemm) on fp16 CUDA tensors whose
    shapes come from the task spec. Raises ValueError for an unknown family.
    """
    torch.manual_seed(seed)
    family = task.family
    if family == "softmax":
        activations = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)

        def run_softmax(inp: torch.Tensor):
            return torch.softmax(inp, dim=-1)

        return run_softmax, (activations,)
    if family == "rmsnorm":
        activations = torch.randn((task.m, task.n), device="cuda", dtype=torch.float16)

        def run_rmsnorm(inp: torch.Tensor):
            # fp32 promotion for the reduction; EPS matches the measurement script.
            return inp.float() * torch.rsqrt(inp.float().pow(2).mean(dim=-1, keepdim=True) + EPS)

        return run_rmsnorm, (activations,)
    if family == "gemm":
        lhs = torch.randn((task.m, task.k), device="cuda", dtype=torch.float16)
        rhs = torch.randn((task.k, task.n), device="cuda", dtype=torch.float16)

        def run_gemm(a: torch.Tensor, b: torch.Tensor):
            return torch.matmul(a, b)

        return run_gemm, (lhs, rhs)
    raise ValueError(f"Unsupported family: {family}")
def _benchmark_torch(task: QwenKernelTask, seed: int, repeats: int, warmup: int) -> Dict[str, float]:
    """Benchmark the eager and torch.compile variants of *task*.

    Returns a dict with ``eager_latency_ms``, ``compile_plus_first_call_ms``
    (wall-clock compile + first invocation), and ``compiled_latency_ms``.
    """
    fn, inputs = _build_qwen_callable(task, seed)
    metrics: Dict[str, float] = {}
    metrics["eager_latency_ms"] = _bench_callable(fn, inputs, repeats=repeats, warmup=warmup)
    compiled = torch.compile(fn)
    # Wall-clock timing here deliberately includes Inductor compilation cost.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    compiled(*inputs)
    torch.cuda.synchronize()
    metrics["compile_plus_first_call_ms"] = float((time.perf_counter() - t0) * 1000.0)
    metrics["compiled_latency_ms"] = _bench_callable(compiled, inputs, repeats=repeats, warmup=warmup)
    return metrics
| def _task_best_configs(eval_results: Dict[str, Any]) -> Dict[str, Dict[str, Dict[str, Any]]]: | |
| task_map: Dict[str, Dict[str, Dict[str, Any]]] = {} | |
| for section in eval_results["results"].values(): | |
| for method in ("random", "surrogate"): | |
| for run in section["task_runs"][method]: | |
| task_map.setdefault(run["task"], {})[method] = run["best_overall"]["config"] | |
| return task_map | |
def main() -> None:
    """CLI entry point.

    Loads the generalization-eval JSON, benchmarks eager/torch.compile and
    the best Triton configs for every task found there, and writes (and
    prints) a JSON summary with per-task latencies and speedups.
    """
    parser = argparse.ArgumentParser(description="Benchmark eager/torch.compile and best Triton configs for Qwen2.5-0.5B exact kernels.")
    parser.add_argument("--generalization-results", type=Path, default=Path("outputs/qwen_05b_generalization_eval.json"))
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--output", type=Path, default=Path("outputs/qwen_05b_runtime_references.json"))
    cli = parser.parse_args()

    eval_payload = json.loads(cli.generalization_results.read_text(encoding="utf-8"))
    best_configs = _task_best_configs(eval_payload)

    per_task = {}
    for offset, task_id in enumerate(sorted(best_configs)):
        task = TASK_BY_ID[task_id]
        # Distinct deterministic seed per task so inputs differ but reruns match.
        task_seed = cli.seed + offset
        torch_metrics = _benchmark_torch(task, seed=task_seed, repeats=cli.repeats, warmup=cli.warmup)

        triton_metrics = {}
        for method, config in best_configs[task_id].items():
            measurement = benchmark_qwen_task(
                task=task,
                block_size=int(config["block_size"]),
                num_warps=int(config["num_warps"]),
                num_stages=int(config["num_stages"]),
                repeats=cli.repeats,
                warmup=cli.warmup,
                seed=task_seed,
            )
            triton_metrics[method] = measurement.__dict__

        speedups = {}
        for method, row in triton_metrics.items():
            speedups[method] = {
                "vs_eager": float(torch_metrics["eager_latency_ms"] / row["median_ms"]),
                "vs_compiled": float(torch_metrics["compiled_latency_ms"] / row["median_ms"]),
            }

        per_task[task_id] = {
            "family": task.family,
            "role": task.role,
            "mode": task.mode,
            "torch": torch_metrics,
            "triton": triton_metrics,
            "speedups": speedups,
        }

    summary = {
        "generalization_results": str(cli.generalization_results),
        "repeats": cli.repeats,
        "warmup": cli.warmup,
        "seed": cli.seed,
        "task_count": len(per_task),
        "results": per_task,
    }
    cli.output.parent.mkdir(parents=True, exist_ok=True)
    with cli.output.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))
| if __name__ == "__main__": | |
| main() | |