"""Profile the DPLR convolutional frequency path. This is a small remote-friendly profiler for choosing TileLang/Triton kernel targets. It focuses on S4TernaryDPLRSSM rather than the older Gamma fallback because this is the SSM core used by the TaoNet comparison work. """ from __future__ import annotations import argparse import json import sys import time from contextlib import nullcontext from pathlib import Path from typing import Any import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from gamma_space_model import S4TernaryDPLRSSM DTYPES = { "fp32": torch.float32, "float32": torch.float32, "bf16": torch.bfloat16, "bfloat16": torch.bfloat16, "fp16": torch.float16, "float16": torch.float16, } def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def memory_stats(device: torch.device) -> dict[str, float | None]: if device.type != "cuda": return {"peak_allocated_mb": None, "peak_reserved_mb": None} return { "peak_allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2), "peak_reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2), } def run_timed(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]: for _ in range(warmup): fn() synchronize(device) latencies = [] for _ in range(repeats): if device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) synchronize(device) start = time.perf_counter() fn() synchronize(device) latencies.append(time.perf_counter() - start) return { "mean_ms": sum(latencies) / len(latencies) * 1000.0, "min_ms": min(latencies) * 1000.0, } def profiler_table(prof: torch.profiler.profile, row_limit: int) -> list[dict[str, Any]]: rows = [] for event in prof.key_averages().table( sort_by="cuda_time_total", row_limit=row_limit, ).splitlines(): rows.append({"row": event}) return rows def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16") parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--d-model", type=int, default=64) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--rank", type=int, default=1) parser.add_argument("--warmup", type=int, default=2) parser.add_argument("--repeats", type=int, default=5) parser.add_argument("--profile", action="store_true") parser.add_argument("--row-limit", type=int, default=20) parser.add_argument("--method", choices=["forward", "direct", "transfer"], default="forward") parser.add_argument("--output", type=Path, default=None) args = parser.parse_args() device = torch.device(args.device) dtype = DTYPES[args.dtype] model = S4TernaryDPLRSSM( state_dim=args.d_model, hidden_dim=args.hidden_dim, rank=args.rank, kernel_mode="conv", kernel_threshold=1, ).to(device=device) model.train() x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype) autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} def autocast_context(): if not autocast_enabled: return nullcontext() return torch.autocast(device_type=device.type, dtype=dtype, enabled=True) def apply_model() -> torch.Tensor: if args.method == "forward": y, _ = model(x, return_state=False) return y fft_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype fft_len = 1 << max(1, (2 * args.seq_len - 1).bit_length()) with torch.autocast(device_type=device.type, enabled=False): u_channels = x.transpose(1, 2).to(dtype=fft_dtype) u_f = torch.fft.rfft(u_channels, n=fft_len) if args.method == "direct": y_f = model._apply_frequency_response( u_f=u_f, seq_len=args.seq_len, fft_len=fft_len, dtype=fft_dtype, device=device, ) else: transfer = model._compute_frequency_response( seq_len=args.seq_len, fft_len=fft_len, dtype=fft_dtype, device=device, use_cache=False, ) y_f = torch.einsum("foi,bif->bof", transfer, u_f) y = torch.fft.irfft(y_f, n=fft_len)[..., : args.seq_len] return y.transpose(1, 2).to(dtype=x.dtype) def forward_only() -> None: with torch.no_grad(): with autocast_context(): y = apply_model() y.sum().item() def forward_backward() -> None: model.zero_grad(set_to_none=True) with autocast_context(): y = apply_model() loss = y.square().mean() loss.backward() forward_stats = run_timed( forward_only, device=device, warmup=args.warmup, repeats=args.repeats, ) forward_backward_stats = run_timed( forward_backward, device=device, warmup=args.warmup, repeats=args.repeats, ) tokens = args.batch_size * args.seq_len report: dict[str, Any] = { "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")}, "forward": { **forward_stats, "tokens_per_s": tokens / max(forward_stats["mean_ms"] / 1000.0, 1e-12), }, "forward_backward": { **forward_backward_stats, "tokens_per_s": tokens / max(forward_backward_stats["mean_ms"] / 1000.0, 1e-12), **memory_stats(device), }, "frequency_grid_cache_entries": len(model._frequency_grid_cache), } if args.profile: activities = [torch.profiler.ProfilerActivity.CPU] if device.type == "cuda": activities.append(torch.profiler.ProfilerActivity.CUDA) with torch.profiler.profile(activities=activities, record_shapes=True) as prof: forward_backward() report["profiler_table"] = profiler_table(prof, args.row_limit) text = json.dumps(report, indent=2, sort_keys=True, default=str) print(text) if args.output is not None: args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(text, encoding="utf-8") return 0 if __name__ == "__main__": raise SystemExit(main())