"""Diagnose TileLang/Triton acceleration availability for Gamma SSM. The current csrc.tilelang package includes PyTorch fallback code. This script separates "module import works" from "real accelerated backend is active" so remote benchmark logs do not accidentally treat fallback execution as TileLang hardware acceleration. """ from __future__ import annotations import argparse import importlib.util import json import sys import time from pathlib import Path from typing import Any import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def package_available(name: str) -> bool: return importlib.util.find_spec(name) is not None def time_gamma_forward( *, batch_size: int, seq_len: int, d_model: int, hidden_dim: int, dtype: torch.dtype, device: torch.device, repeats: int, warmup: int, ) -> dict[str, Any]: from gamma_space_model import SSMGamma model = SSMGamma(state_dim=d_model, hidden_dim=hidden_dim).to(device=device) x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) def run() -> None: y, _ = model(x) y.sum().item() for _ in range(warmup): run() synchronize(device) latencies = [] for _ in range(repeats): synchronize(device) start = time.perf_counter() run() synchronize(device) latencies.append(time.perf_counter() - start) mean_s = sum(latencies) / len(latencies) tokens = batch_size * seq_len return { "mean_ms": mean_s * 1000.0, "tokens_per_s": tokens / max(mean_s, 1e-12), } def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--dtype", choices=["fp32", "bf16", "fp16"], default="bf16") parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--d-model", type=int, default=128) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--warmup", type=int, default=1) parser.add_argument("--repeats", type=int, default=3) args = parser.parse_args() dtype_map = { "fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16, } device = torch.device(args.device) dtype = dtype_map[args.dtype] import gamma_space_model from gamma_space_model import HAS_TILELANG_OPS, TILELANG_BACKEND try: import csrc.tilelang as csrc_tilelang csrc_flags = { "module_imported": True, "has_triton_import": bool(getattr(csrc_tilelang, "HAS_TRITON", False)), "has_tilelang_import": bool(getattr(csrc_tilelang, "HAS_TILELANG", False)), "has_tilelang_acceleration": bool( getattr(csrc_tilelang, "HAS_TILELANG_ACCELERATION", False) ), "backend": getattr(csrc_tilelang, "TILELANG_BACKEND", "unknown"), } except ImportError as exc: csrc_flags = { "module_imported": False, "import_error": str(exc), } report: dict[str, Any] = { "torch": { "version": torch.__version__, "cuda_available": torch.cuda.is_available(), "device": str(device), "cuda_device_name": torch.cuda.get_device_name(device) if device.type == "cuda" else None, }, "packages": { "triton_available": package_available("triton"), "tilelang_available": package_available("tilelang"), }, "gamma_space_model": { "version": getattr(gamma_space_model, "__version__", None), "has_tilelang_ops": bool(HAS_TILELANG_OPS), "tilelang_backend": TILELANG_BACKEND, }, "csrc_tilelang": csrc_flags, } if device.type == "cuda" and not torch.cuda.is_available(): report["benchmark_error"] = "CUDA requested but torch.cuda.is_available() is false." else: report["gamma_forward_benchmark"] = time_gamma_forward( batch_size=args.batch_size, seq_len=args.seq_len, d_model=args.d_model, hidden_dim=args.hidden_dim, dtype=dtype, device=device, repeats=args.repeats, warmup=args.warmup, ) print(json.dumps(report, indent=2, sort_keys=True)) return 0 if __name__ == "__main__": raise SystemExit(main())