"""Lightweight script benchmarks for Gamma SSM variants. This replaces the notebook-only timing loop for quick local/remote feedback. It focuses on the model kernels themselves: full-sequence forward, optional forward+backward, and optional recurrent decode. """ from __future__ import annotations import argparse from contextlib import nullcontext import csv import json import os import platform import subprocess import sys import time from pathlib import Path from typing import Any, Iterable import torch import torch.nn as nn REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from gamma_space_model import S4TernaryDPLRSSM, SSMGamma, SSMGammaS4 DTYPES = { "float32": torch.float32, "fp32": torch.float32, "float16": torch.float16, "fp16": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, } def parse_int_list(value: str) -> list[int]: return [int(item.strip()) for item in value.split(",") if item.strip()] def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def cuda_memory(device: torch.device) -> dict[str, float | None]: if device.type != "cuda": return { "peak_allocated_mb": None, "peak_reserved_mb": None, } return { "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2), "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2), } def reset_cuda_memory(device: torch.device) -> None: if device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) def nvidia_smi_snapshot() -> str | None: try: completed = subprocess.run( [ "nvidia-smi", "--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu", "--format=csv,noheader,nounits", ], check=False, capture_output=True, text=True, timeout=5, ) except (OSError, subprocess.TimeoutExpired): return None if completed.returncode != 0: return None return completed.stdout.strip() def make_model(name: str, d_model: int, hidden_dim: int, rank: int, kernel_mode: str) -> nn.Module: if name == "baseline": return SSMGamma(state_dim=d_model, hidden_dim=hidden_dim) if name == "gamma_s4": return SSMGammaS4( state_dim=d_model, hidden_dim=hidden_dim, discretization="bilinear", kernel_mode=kernel_mode, kernel_threshold=1, ) if name == "dplr": return S4TernaryDPLRSSM( state_dim=d_model, hidden_dim=hidden_dim, rank=rank, kernel_mode=kernel_mode, kernel_threshold=1, ) raise ValueError(f"Unknown model '{name}'.") def model_forward(model: nn.Module, x: torch.Tensor, return_state: bool = False) -> torch.Tensor: if isinstance(model, SSMGamma): y, _ = model(x) else: y, _ = model(x, return_state=return_state) return y def init_state(model: nn.Module, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: return model.init_state(batch_size=batch_size, device=device, dtype=dtype) def allocate_cache( model: nn.Module, batch_size: int, seq_len: int, device: torch.device, dtype: torch.dtype, ) -> dict[str, torch.Tensor] | None: if hasattr(model, "allocate_inference_cache"): return model.allocate_inference_cache( batch_size=batch_size, seq_len=seq_len, device=device, dtype=dtype, ) return None def recurrent_forward(model: nn.Module, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, _ = x.shape state = init_state(model, batch_size=batch_size, device=x.device, dtype=x.dtype) cache = allocate_cache(model, batch_size=batch_size, seq_len=seq_len, device=x.device, dtype=x.dtype) outputs = [] for step in range(seq_len): token = x[:, step, :] if cache is None: y, state = model.step(token, state) else: y, state = model.step(token, state, cache=cache) outputs.append(y) return torch.stack(outputs, dim=1) def timed_repeats( fn, *, device: torch.device, warmup: int, repeats: int, ) -> tuple[float, float]: for _ in range(warmup): fn() synchronize(device) latencies = [] for _ in range(repeats): reset_cuda_memory(device) synchronize(device) start = time.perf_counter() fn() synchronize(device) latencies.append(time.perf_counter() - start) mean_s = sum(latencies) / len(latencies) min_s = min(latencies) return mean_s, min_s def benchmark_case( *, model_name: str, batch_size: int, seq_len: int, d_model: int, hidden_dim: int, rank: int, kernel_mode: str, dtype: torch.dtype, device: torch.device, warmup: int, repeats: int, run_backward: bool, run_recurrent: bool, ) -> list[dict[str, Any]]: model = make_model(model_name, d_model=d_model, hidden_dim=hidden_dim, rank=rank, kernel_mode=kernel_mode) model = model.to(device=device) model.train() x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) tokens = batch_size * seq_len autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} autocast_dtype = dtype if dtype in {torch.float16, torch.bfloat16} else torch.float32 rows: list[dict[str, Any]] = [] def autocast_context(): if not autocast_enabled: return nullcontext() return torch.autocast(device_type=device.type, dtype=autocast_dtype, enabled=True) def add_row(mode: str, mean_s: float, min_s: float) -> None: mem = cuda_memory(device) rows.append( { "model": model_name, "mode": mode, "batch_size": batch_size, "seq_len": seq_len, "d_model": d_model, "hidden_dim": hidden_dim, "rank": rank if model_name == "dplr" else None, "kernel_mode": kernel_mode, "dtype": str(dtype).replace("torch.", ""), "device": str(device), "mean_ms": mean_s * 1000.0, "min_ms": min_s * 1000.0, "tokens_per_s_mean": tokens / max(mean_s, 1e-12), "tokens_per_s_best": tokens / max(min_s, 1e-12), **mem, } ) def forward_only() -> None: with torch.no_grad(): with autocast_context(): y = model_forward(model, x, return_state=False) _ = y.sum() mean_s, min_s = timed_repeats(forward_only, device=device, warmup=warmup, repeats=repeats) add_row("forward", mean_s, min_s) if run_backward: def forward_backward() -> None: model.zero_grad(set_to_none=True) with autocast_context(): y = model_forward(model, x, return_state=False) loss = y.square().mean() loss.backward() mean_s, min_s = timed_repeats(forward_backward, device=device, warmup=warmup, repeats=repeats) add_row("forward_backward", mean_s, min_s) if run_recurrent: model.eval() def recurrent() -> None: with torch.no_grad(): y = recurrent_forward(model, x) _ = y.sum() mean_s, min_s = timed_repeats(recurrent, device=device, warmup=max(1, warmup // 2), repeats=repeats) add_row("recurrent", mean_s, min_s) return rows def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None: output_dir.mkdir(parents=True, exist_ok=True) json_path = output_dir / "ssm_variant_benchmark.json" csv_path = output_dir / "ssm_variant_benchmark.csv" json_path.write_text( json.dumps( { "metadata": metadata, "results": rows, }, indent=2, ), encoding="utf-8", ) fieldnames = list(rows[0].keys()) if rows else [] with csv_path.open("w", newline="", encoding="utf-8") as handle: writer = csv.DictWriter(handle, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) print(f"Wrote {json_path}") print(f"Wrote {csv_path}") def print_table(rows: Iterable[dict[str, Any]]) -> None: columns = [ "model", "mode", "batch_size", "seq_len", "mean_ms", "tokens_per_s_mean", "peak_allocated_mb", ] print("\t".join(columns)) for row in rows: values = [] for column in columns: value = row[column] if isinstance(value, float): values.append(f"{value:.3f}") else: values.append(str(value)) print("\t".join(values)) def main() -> None: parser = argparse.ArgumentParser(description="Benchmark Gamma SSM variants.") parser.add_argument("--models", default="dplr,gamma_s4,baseline") parser.add_argument("--batch-sizes", default="1,4") parser.add_argument("--seq-lens", default="128,512") parser.add_argument("--d-model", type=int, default=128) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--rank", type=int, default=1) parser.add_argument("--kernel-mode", choices=["auto", "conv", "recurrent"], default="conv") parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16") parser.add_argument("--device", default="auto") parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--repeats", type=int, default=5) parser.add_argument("--backward", action="store_true") parser.add_argument("--recurrent", action="store_true") parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "output/benchmarks")) args = parser.parse_args() if args.device == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) dtype = DTYPES[args.dtype] if device.type != "cuda" and dtype == torch.float16: raise ValueError("float16 benchmark requires CUDA.") if device.type == "cuda": torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True models = [item.strip() for item in args.models.split(",") if item.strip()] batch_sizes = parse_int_list(args.batch_sizes) seq_lens = parse_int_list(args.seq_lens) metadata = { "python": platform.python_version(), "platform": platform.platform(), "torch": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_device": torch.cuda.get_device_name(device) if device.type == "cuda" else None, "nvidia_smi_before": nvidia_smi_snapshot(), "args": vars(args), } rows: list[dict[str, Any]] = [] for model_name in models: for batch_size in batch_sizes: for seq_len in seq_lens: print(f"Benchmarking model={model_name} batch={batch_size} seq={seq_len}") rows.extend( benchmark_case( model_name=model_name, batch_size=batch_size, seq_len=seq_len, d_model=args.d_model, hidden_dim=args.hidden_dim, rank=args.rank, kernel_mode=args.kernel_mode, dtype=dtype, device=device, warmup=args.warmup, repeats=args.repeats, run_backward=args.backward, run_recurrent=args.recurrent, ) ) metadata["nvidia_smi_after"] = nvidia_smi_snapshot() print_table(rows) write_outputs(rows, Path(args.output_dir), metadata) if __name__ == "__main__": main()