TaoNet-mini-T2 / code /Taotern_SSM /scripts /diagnose_tilelang_acceleration.py
StarMist0012's picture
Add files using upload-large-folder tool
388fd6e verified
"""Diagnose TileLang/Triton acceleration availability for Gamma SSM.
The current csrc.tilelang package includes PyTorch fallback code. This script
separates "module import works" from "real accelerated backend is active" so
remote benchmark logs do not accidentally treat fallback execution as TileLang
hardware acceleration.
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import sys
import time
from pathlib import Path
from typing import Any
import torch
REPO_ROOT = Path(__file__).resolve().parents[1]

# Prepend the repository root (two levels above this file) to the module
# search path, once, so the project packages imported inside the functions
# below -- ``gamma_space_model`` and ``csrc.tilelang`` -- resolve when this
# file is executed directly as a script. Entries earlier in sys.path win,
# so the in-repo copies take precedence over any installed ones.
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
def synchronize(device: torch.device) -> None:
    """Wait for all queued work on *device* to finish.

    Only CUDA devices get an explicit barrier (``torch.cuda.synchronize``);
    for every other device type this returns immediately.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def package_available(name: str) -> bool:
    """Return True if the import system can locate a module called *name*.

    For a top-level name nothing is imported. For a dotted name,
    ``importlib.util.find_spec`` imports the parent package first. If that
    parent is missing or fails to import, an ``ImportError`` (including
    ``ModuleNotFoundError``) is raised. This function reports that case as
    "not available" rather than aborting the diagnostic.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False
def time_gamma_forward(
    *,
    batch_size: int,
    seq_len: int,
    d_model: int,
    hidden_dim: int,
    dtype: torch.dtype,
    device: torch.device,
    repeats: int,
    warmup: int,
) -> dict[str, Any]:
    """Benchmark the forward pass of ``SSMGamma`` on random input.

    Builds ``SSMGamma(state_dim=d_model, hidden_dim=hidden_dim)`` on *device*.
    Its floating-point parameters and buffers are cast to *dtype*, so they
    match the random input of shape ``(batch_size, seq_len, d_model)``. The
    function then runs ``warmup`` untimed forward passes followed by
    ``repeats`` timed ones.

    Timing happens under ``torch.no_grad()`` with the model in eval mode, so
    autograd graph construction is not counted as forward time.

    Args:
        batch_size: Batch dimension of the random input.
        seq_len: Sequence dimension of the random input.
        d_model: Feature dimension of the input; also ``SSMGamma``'s ``state_dim``.
        hidden_dim: ``hidden_dim`` passed to ``SSMGamma``.
        dtype: dtype of both the input tensor and the model's floating-point
            parameters.
        device: Device the model and input are created on.
        repeats: Number of timed forward passes; must be >= 1.
        warmup: Number of untimed forward passes run before timing.

    Returns:
        A dict with:

        - ``mean_ms``, ``min_ms``, ``max_ms``: per-forward latency in
          milliseconds.
        - ``tokens_per_s``: ``batch_size * seq_len`` divided by the mean
          latency.

    Raises:
        ValueError: If ``repeats`` < 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    from gamma_space_model import SSMGamma

    # Cast parameters to the benchmark dtype as well. Otherwise bf16/fp16
    # input would meet default-precision weights.
    model = SSMGamma(state_dim=d_model, hidden_dim=hidden_dim).to(
        device=device, dtype=dtype
    )
    model.eval()
    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)

    def run() -> None:
        # The model call is unpacked as a 2-tuple; only the first element is
        # used. It is assumed to be the output sequence -- TODO confirm.
        y, _ = model(x)
        # .item() copies a scalar to the host, forcing the forward to complete.
        y.sum().item()

    latencies: list[float] = []
    with torch.no_grad():
        for _ in range(warmup):
            run()
        synchronize(device)
        for _ in range(repeats):
            synchronize(device)
            start = time.perf_counter()
            run()
            synchronize(device)
            latencies.append(time.perf_counter() - start)

    mean_s = sum(latencies) / len(latencies)
    tokens = batch_size * seq_len
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        "max_ms": max(latencies) * 1000.0,
        "tokens_per_s": tokens / max(mean_s, 1e-12),
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser for device, dtype and benchmark-shape options."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=["fp32", "bf16", "fp16"], default="bf16")
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--d-model", type=int, default=128)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--warmup", type=int, default=1)
    parser.add_argument("--repeats", type=int, default=3)
    return parser


def _probe_csrc_tilelang() -> dict[str, Any]:
    """Import ``csrc.tilelang`` and report the backend flags it exposes.

    ``module_imported`` only says that the import succeeded. The package
    contains PyTorch fallback code, so the acceleration and backend flags
    are what distinguish a real accelerated backend from a fallback.

    Any exception raised by the import is recorded in the returned dict
    instead of being re-raised, so a broken backend still yields a report.
    """
    try:
        import csrc.tilelang as csrc_tilelang
    except Exception as exc:  # reported in the JSON output, not swallowed
        return {
            "module_imported": False,
            "import_error": str(exc),
            "import_error_type": type(exc).__name__,
        }
    return {
        "module_imported": True,
        "has_triton_import": bool(getattr(csrc_tilelang, "HAS_TRITON", False)),
        "has_tilelang_import": bool(getattr(csrc_tilelang, "HAS_TILELANG", False)),
        "has_tilelang_acceleration": bool(
            getattr(csrc_tilelang, "HAS_TILELANG_ACCELERATION", False)
        ),
        "backend": getattr(csrc_tilelang, "TILELANG_BACKEND", "unknown"),
    }


def main() -> int:
    """Print a JSON report of TileLang/Triton availability plus a Gamma benchmark.

    Returns:
        0 when the report was produced. This includes the "CUDA requested but
        unavailable" case, which is recorded under ``benchmark_error``.
        1 when the benchmark itself raised. In that case the error is
        recorded under ``benchmark_error`` and the traceback goes to stderr.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.repeats < 1:
        # The benchmark averages over `repeats` samples.
        parser.error("--repeats must be >= 1")

    dtype_map = {
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    device = torch.device(args.device)
    dtype = dtype_map[args.dtype]

    # Imported after argument parsing, so --help and argument errors do not
    # require the project modules.
    import gamma_space_model
    from gamma_space_model import HAS_TILELANG_OPS, TILELANG_BACKEND

    csrc_flags = _probe_csrc_tilelang()

    cuda_available = torch.cuda.is_available()
    cuda_requested = device.type == "cuda"
    # Only ask for the device name when CUDA is usable. get_device_name
    # raises otherwise, which would abort before benchmark_error is recorded.
    cuda_device_name = (
        torch.cuda.get_device_name(device)
        if cuda_requested and cuda_available
        else None
    )

    report: dict[str, Any] = {
        "torch": {
            "version": torch.__version__,
            "cuda_available": cuda_available,
            "device": str(device),
            "cuda_device_name": cuda_device_name,
        },
        "packages": {
            "triton_available": package_available("triton"),
            "tilelang_available": package_available("tilelang"),
        },
        "gamma_space_model": {
            "version": getattr(gamma_space_model, "__version__", None),
            "has_tilelang_ops": bool(HAS_TILELANG_OPS),
            "tilelang_backend": TILELANG_BACKEND,
        },
        "csrc_tilelang": csrc_flags,
    }

    exit_code = 0
    if cuda_requested and not cuda_available:
        report["benchmark_error"] = "CUDA requested but torch.cuda.is_available() is false."
    else:
        try:
            report["gamma_forward_benchmark"] = time_gamma_forward(
                batch_size=args.batch_size,
                seq_len=args.seq_len,
                d_model=args.d_model,
                hidden_dim=args.hidden_dim,
                dtype=dtype,
                device=device,
                repeats=args.repeats,
                warmup=args.warmup,
            )
        except Exception as exc:
            # Keep the availability report even when the benchmark fails,
            # but surface the failure (traceback + non-zero exit status).
            import traceback

            traceback.print_exc()
            report["benchmark_error"] = f"{type(exc).__name__}: {exc}"
            exit_code = 1

    print(json.dumps(report, indent=2, sort_keys=True))
    return exit_code
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())