# TaoNet-mini-T2 / code / Taotern_SSM / scripts / benchmark_ssm_variants.py
# StarMist0012 - "Add files using upload-large-folder tool" (388fd6e, verified)
"""Lightweight script benchmarks for Gamma SSM variants.
This replaces the notebook-only timing loop for quick local/remote feedback.
It focuses on the model kernels themselves: full-sequence forward, optional
forward+backward, and optional recurrent decode.
"""
from __future__ import annotations
import argparse
from contextlib import nullcontext
import csv
import json
import os
import platform
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Iterable
import torch
import torch.nn as nn
# Two levels up from this file (the parent of the directory holding the script).
# It is put at the front of sys.path, unless already listed, so the project
# import that follows can be resolved when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from gamma_space_model import S4TernaryDPLRSSM, SSMGamma, SSMGammaS4
# Accepted spellings for the --dtype option: each torch dtype is reachable by
# its long name and by a short alias.
DTYPES = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
def parse_int_list(value: str) -> list[int]:
    """Turn a comma-separated string such as ``"1, 4,16"`` into a list of ints.

    Surrounding whitespace is ignored and empty pieces (for example from a
    trailing comma) are skipped. A non-numeric piece makes ``int`` raise
    ``ValueError``.
    """
    pieces = (piece.strip() for piece in value.split(","))
    return [int(piece) for piece in pieces if piece]
def synchronize(device: torch.device) -> None:
    """Call ``torch.cuda.synchronize`` for a CUDA device; do nothing otherwise."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def cuda_memory(device: torch.device) -> dict[str, float | None]:
if device.type != "cuda":
return {
"peak_allocated_mb": None,
"peak_reserved_mb": None,
}
return {
"peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
"peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
}
def reset_cuda_memory(device: torch.device) -> None:
    """Reset the CUDA peak-memory statistics of ``device``; no-op off CUDA."""
    if device.type != "cuda":
        return
    torch.cuda.reset_peak_memory_stats(device)
def nvidia_smi_snapshot() -> str | None:
try:
completed = subprocess.run(
[
"nvidia-smi",
"--query-gpu=name,memory.used,memory.total,utilization.gpu,utilization.memory,power.draw,temperature.gpu",
"--format=csv,noheader,nounits",
],
check=False,
capture_output=True,
text=True,
timeout=5,
)
except (OSError, subprocess.TimeoutExpired):
return None
if completed.returncode != 0:
return None
return completed.stdout.strip()
def make_model(name: str, d_model: int, hidden_dim: int, rank: int, kernel_mode: str) -> nn.Module:
    """Instantiate the model variant selected by ``name``.

    Args:
        name: One of ``"baseline"``, ``"gamma_s4"`` or ``"dplr"``.
        d_model: Passed to every constructor as ``state_dim``.
        hidden_dim: Passed to every constructor as ``hidden_dim``.
        rank: Only forwarded to the ``"dplr"`` variant.
        kernel_mode: Forwarded to the two non-baseline variants.

    Raises:
        ValueError: If ``name`` is not one of the known variants.
    """
    if name == "baseline":
        return SSMGamma(state_dim=d_model, hidden_dim=hidden_dim)

    # Constructor arguments shared by the two non-baseline variants.
    shared_kwargs = {
        "state_dim": d_model,
        "hidden_dim": hidden_dim,
        "kernel_mode": kernel_mode,
        "kernel_threshold": 1,
    }
    if name == "gamma_s4":
        return SSMGammaS4(discretization="bilinear", **shared_kwargs)
    if name == "dplr":
        return S4TernaryDPLRSSM(rank=rank, **shared_kwargs)
    raise ValueError(f"Unknown model '{name}'.")
def model_forward(model: nn.Module, x: torch.Tensor, return_state: bool = False) -> torch.Tensor:
    """Run ``model`` on ``x`` and return only the first element of its result.

    ``SSMGamma`` instances are called with ``x`` alone; every other model also
    receives ``return_state``. The call is expected to yield exactly two
    values, of which the second is discarded.
    """
    if isinstance(model, SSMGamma):
        result = model(x)
    else:
        result = model(x, return_state=return_state)
    y, _ = result
    return y
def init_state(model: nn.Module, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Delegate to ``model.init_state`` with the given batch size, device and dtype."""
    state_kwargs = {"batch_size": batch_size, "device": device, "dtype": dtype}
    return model.init_state(**state_kwargs)
def allocate_cache(
model: nn.Module,
batch_size: int,
seq_len: int,
device: torch.device,
dtype: torch.dtype,
) -> dict[str, torch.Tensor] | None:
if hasattr(model, "allocate_inference_cache"):
return model.allocate_inference_cache(
batch_size=batch_size,
seq_len=seq_len,
device=device,
dtype=dtype,
)
return None
def recurrent_forward(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Feed ``x`` through ``model.step`` one position at a time.

    ``x`` is unpacked as ``(batch, seq_len, features)``. The state comes from
    ``init_state`` and is threaded through every step; a cache from
    ``allocate_cache`` is passed as the ``cache`` keyword only when one was
    allocated. The per-step outputs are stacked along dimension 1.
    """
    batch_size, seq_len, _ = x.shape
    device, dtype = x.device, x.dtype
    state = init_state(model, batch_size=batch_size, device=device, dtype=dtype)
    cache = allocate_cache(model, batch_size=batch_size, seq_len=seq_len, device=device, dtype=dtype)
    # Models without a cache are stepped without the ``cache`` keyword at all.
    step_kwargs = {} if cache is None else {"cache": cache}

    outputs = []
    for position in range(seq_len):
        y, state = model.step(x[:, position, :], state, **step_kwargs)
        outputs.append(y)
    return torch.stack(outputs, dim=1)
def timed_repeats(
    fn,
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
) -> tuple[float, float]:
    """Time ``fn`` over several runs and return ``(mean_s, min_s)`` in seconds.

    Args:
        fn: Zero-argument callable to measure.
        device: Device that is synchronized around every timed call.
        warmup: Number of untimed calls made first; values <= 0 skip warm-up.
        repeats: Number of timed calls; must be at least 1.

    Returns:
        The mean and the minimum wall-clock latency of the timed calls.

    Raises:
        ValueError: If ``repeats`` is smaller than 1. Without this check the
            mean over an empty list would fail with ``ZeroDivisionError``
            only after the warm-up had already run.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be at least 1, got {repeats}.")

    for _ in range(warmup):
        fn()
    synchronize(device)

    latencies = []
    for _ in range(repeats):
        # Peak-memory statistics are reset before each run, so what is read
        # afterwards reflects the final timed run only.
        reset_cuda_memory(device)
        synchronize(device)
        start = time.perf_counter()
        fn()
        # Synchronize before stopping the clock.
        synchronize(device)
        latencies.append(time.perf_counter() - start)

    mean_s = sum(latencies) / len(latencies)
    min_s = min(latencies)
    return mean_s, min_s
def benchmark_case(
    *,
    model_name: str,
    batch_size: int,
    seq_len: int,
    d_model: int,
    hidden_dim: int,
    rank: int,
    kernel_mode: str,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    repeats: int,
    run_backward: bool,
    run_recurrent: bool,
) -> list[dict[str, Any]]:
    """Benchmark one model variant at one ``(batch_size, seq_len)`` setting.

    A fresh model is built with ``make_model`` and moved to ``device``, and a
    random input of shape ``(batch_size, seq_len, d_model)`` is created in
    ``dtype``. The ``"forward"`` mode is always timed; ``"forward_backward"``
    and ``"recurrent"`` are timed when ``run_backward`` / ``run_recurrent``
    are set.

    Returns:
        One row per timed mode holding the configuration, mean/min latency in
        milliseconds, tokens per second, and the peak CUDA memory figures
        from ``cuda_memory``.
    """
    model = make_model(model_name, d_model=d_model, hidden_dim=hidden_dim, rank=rank, kernel_mode=kernel_mode)
    model = model.to(device=device)
    model.train()
    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
    tokens = batch_size * seq_len

    # Autocast is used only on CUDA and only for the half-precision dtypes.
    reduced_precision = dtype in {torch.float16, torch.bfloat16}
    autocast_enabled = reduced_precision and device.type == "cuda"
    autocast_dtype = dtype if reduced_precision else torch.float32
    rows: list[dict[str, Any]] = []

    def autocast_context():
        if autocast_enabled:
            return torch.autocast(device_type=device.type, dtype=autocast_dtype, enabled=True)
        return nullcontext()

    def measure(mode: str, fn, warmup_iters: int) -> None:
        """Time ``fn`` and append the result row for ``mode``."""
        mean_s, min_s = timed_repeats(fn, device=device, warmup=warmup_iters, repeats=repeats)
        peak_memory = cuda_memory(device)
        rows.append(
            {
                "model": model_name,
                "mode": mode,
                "batch_size": batch_size,
                "seq_len": seq_len,
                "d_model": d_model,
                "hidden_dim": hidden_dim,
                # The rank is only reported for the variant that receives it.
                "rank": rank if model_name == "dplr" else None,
                "kernel_mode": kernel_mode,
                "dtype": str(dtype).replace("torch.", ""),
                "device": str(device),
                "mean_ms": mean_s * 1000.0,
                "min_ms": min_s * 1000.0,
                # The floor guards against division by a zero latency.
                "tokens_per_s_mean": tokens / max(mean_s, 1e-12),
                "tokens_per_s_best": tokens / max(min_s, 1e-12),
                **peak_memory,
            }
        )

    def forward_only() -> None:
        with torch.no_grad(), autocast_context():
            y = model_forward(model, x, return_state=False)
            _ = y.sum()

    measure("forward", forward_only, warmup)

    if run_backward:

        def forward_backward() -> None:
            model.zero_grad(set_to_none=True)
            with autocast_context():
                y = model_forward(model, x, return_state=False)
                loss = y.square().mean()
            loss.backward()

        measure("forward_backward", forward_backward, warmup)

    if run_recurrent:
        model.eval()

        # NOTE(review): unlike the two modes above, this path does not enter
        # autocast_context() - confirm that this is intended.
        def recurrent() -> None:
            with torch.no_grad():
                y = recurrent_forward(model, x)
                _ = y.sum()

        # Half the warm-up count, but never fewer than one call.
        measure("recurrent", recurrent, max(1, warmup // 2))

    return rows
def write_outputs(rows: list[dict[str, Any]], output_dir: Path, metadata: dict[str, Any]) -> None:
    """Write the results to ``output_dir`` as JSON and CSV, creating it if needed.

    The JSON file holds ``{"metadata": ..., "results": rows}``. The CSV file
    holds only the rows, with columns taken from the keys of the first row
    (an empty field list when there are no rows). Both paths are printed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "ssm_variant_benchmark.json"
    csv_path = output_dir / "ssm_variant_benchmark.csv"

    payload = {"metadata": metadata, "results": rows}
    json_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    header = list(rows[0]) if rows else []
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        writer.writerows(rows)

    for written_path in (json_path, csv_path):
        print(f"Wrote {written_path}")
def print_table(rows: Iterable[dict[str, Any]]) -> None:
    """Print a tab-separated summary of ``rows`` with a header line.

    Only a fixed subset of columns is shown. Float values are rendered with
    three decimals; everything else goes through ``str``. A row missing one
    of the columns raises ``KeyError``.
    """
    columns = (
        "model",
        "mode",
        "batch_size",
        "seq_len",
        "mean_ms",
        "tokens_per_s_mean",
        "peak_allocated_mb",
    )

    def render(value: Any) -> str:
        return f"{value:.3f}" if isinstance(value, float) else str(value)

    print("\t".join(columns))
    for row in rows:
        print("\t".join(render(row[column]) for column in columns))
def main() -> None:
    """Command-line entry point: run every requested case and save the results.

    Parses the options, resolves device and dtype, runs ``benchmark_case``
    for each combination of model, batch size and sequence length, then
    prints a summary table and writes the JSON/CSV files.

    Raises:
        ValueError: If float16 is requested on a non-CUDA device.
    """
    parser = argparse.ArgumentParser(description="Benchmark Gamma SSM variants.")
    parser.add_argument("--models", default="dplr,gamma_s4,baseline")
    parser.add_argument("--batch-sizes", default="1,4")
    parser.add_argument("--seq-lens", default="128,512")
    parser.add_argument("--d-model", type=int, default=128)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--kernel-mode", choices=["auto", "conv", "recurrent"], default="conv")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--device", default="auto")
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--backward", action="store_true")
    parser.add_argument("--recurrent", action="store_true")
    parser.add_argument("--output-dir", default=os.environ.get("REPOBRIDGE_OUTPUT_DIR", "output/benchmarks"))
    args = parser.parse_args()

    # "auto" picks CUDA when available and falls back to the CPU.
    device_name = args.device
    if device_name == "auto":
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    on_cuda = device.type == "cuda"

    dtype = DTYPES[args.dtype]
    if dtype == torch.float16 and not on_cuda:
        raise ValueError("float16 benchmark requires CUDA.")
    if on_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    models = [name.strip() for name in args.models.split(",") if name.strip()]
    batch_sizes = parse_int_list(args.batch_sizes)
    seq_lens = parse_int_list(args.seq_lens)

    metadata = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device": torch.cuda.get_device_name(device) if on_cuda else None,
        "nvidia_smi_before": nvidia_smi_snapshot(),
        "args": vars(args),
    }

    # Settings that stay the same for every benchmarked case.
    fixed_case_kwargs = {
        "d_model": args.d_model,
        "hidden_dim": args.hidden_dim,
        "rank": args.rank,
        "kernel_mode": args.kernel_mode,
        "dtype": dtype,
        "device": device,
        "warmup": args.warmup,
        "repeats": args.repeats,
        "run_backward": args.backward,
        "run_recurrent": args.recurrent,
    }

    rows: list[dict[str, Any]] = []
    for model_name in models:
        for batch_size in batch_sizes:
            for seq_len in seq_lens:
                print(f"Benchmarking model={model_name} batch={batch_size} seq={seq_len}")
                case_rows = benchmark_case(
                    model_name=model_name,
                    batch_size=batch_size,
                    seq_len=seq_len,
                    **fixed_case_kwargs,
                )
                rows.extend(case_rows)

    metadata["nvidia_smi_after"] = nvidia_smi_snapshot()
    print_table(rows)
    write_outputs(rows, Path(args.output_dir), metadata)


if __name__ == "__main__":
    main()