"""Break down the DPLR direct frequency path into timed forward stages. The whole-path profiler tells us whether the direct convolution path is fast, but not which internal tensor operation should become the next TileLang/Triton target. This script mirrors ``S4TernaryDPLRSSM._apply_frequency_response`` and records per-stage timings without changing model behavior. """ from __future__ import annotations import argparse import json import math import statistics import sys import time from pathlib import Path from typing import Any, Callable import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from gamma_space_model import S4TernaryDPLRSSM DTYPES = { "fp32": torch.float32, "float32": torch.float32, "bf16": torch.bfloat16, "bfloat16": torch.bfloat16, "fp16": torch.float16, "float16": torch.float16, } def synchronize(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def summarize(values: list[float]) -> dict[str, float]: return { "mean_ms": statistics.fmean(values), "min_ms": min(values), "max_ms": max(values), "stdev_ms": statistics.pstdev(values) if len(values) > 1 else 0.0, } class StageRecorder: def __init__(self, device: torch.device) -> None: self.device = device self.cuda = device.type == "cuda" self.events: list[tuple[str, torch.cuda.Event, torch.cuda.Event]] = [] self.cpu_times: list[tuple[str, float]] = [] def measure(self, name: str, fn: Callable[[], Any]) -> Any: if self.cuda: start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() value = fn() end.record() self.events.append((name, start, end)) return value start_time = time.perf_counter() value = fn() self.cpu_times.append((name, (time.perf_counter() - start_time) * 1000.0)) return value def results(self) -> dict[str, float]: if self.cuda: torch.cuda.synchronize(self.device) return {name: start.elapsed_time(end) for name, start, end in self.events} return dict(self.cpu_times) def run_profiled_direct( model: S4TernaryDPLRSSM, x: torch.Tensor, *, seq_len: int, fft_len: int, target_dtype: torch.dtype, device: torch.device, ) -> tuple[torch.Tensor, dict[str, float]]: recorder = StageRecorder(device) def input_fft() -> tuple[torch.Tensor, torch.Tensor]: u_channels = x.transpose(1, 2).to(dtype=target_dtype) return u_channels, torch.fft.rfft(u_channels, n=fft_len) u_channels, u_f = recorder.measure("input_fft", input_fft) diag, U, V, B_disc = recorder.measure( "discrete_params", lambda: model._discrete_params(dtype=target_dtype, device=device), ) def matrix_power() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A_dense = model._dense_discrete_A_from_params(diag, U, V) A_power = torch.linalg.matrix_power(A_dense, seq_len) C = model.C.to(device=device, dtype=target_dtype) D = model.D.to(device=device, dtype=target_dtype) return A_power, C, D A_power, C, D = recorder.measure("dense_A_power_C_D", matrix_power) complex_dtype = torch.complex64 if target_dtype != torch.float64 else torch.complex128 freq_count = fft_len // 2 + 1 def roots_and_casts() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: roots, roots_power = model._frequency_roots(seq_len, fft_len, target_dtype, device) return ( roots, roots_power, diag.to(dtype=complex_dtype), U.to(dtype=complex_dtype), V.to(dtype=complex_dtype), B_disc.to(dtype=complex_dtype), C.to(dtype=complex_dtype), ) ( roots, roots_power, diag_complex, U_complex, V_complex, B_complex, C_complex, ) = recorder.measure("roots_and_complex_casts", roots_and_casts) def diagonal_input_solve() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: u_freq = u_f.permute(2, 0, 1).to(dtype=complex_dtype) denom = 1.0 - roots[:, None] * diag_complex[None, :] inv_diag = denom.reciprocal() input_term = torch.einsum("nd,fbd->fbn", B_complex, u_freq) inv_input = inv_diag[:, None, :] * input_term return u_freq, inv_diag, inv_input u_freq, inv_diag, inv_input = recorder.measure("diagonal_input_solve", diagonal_input_solve) def low_rank_solve() -> torch.Tensor: omega_u = roots[:, None, None] * U_complex[None, :, :] inv_u = inv_diag[:, :, None] * omega_u vt_inv_u = torch.einsum("nr,fns->frs", V_complex, inv_u) vt_inv_input = torch.einsum("nr,fbn->fbr", V_complex, inv_input) if model.rank == 1: middle = (1.0 + vt_inv_u[:, 0, 0]).reciprocal() correction = ( inv_u[:, None, :, 0] * middle.view(freq_count, 1, 1) * vt_inv_input[:, :, 0].unsqueeze(-1) ) else: rank_eye = torch.eye(model.rank, device=device, dtype=complex_dtype).expand(freq_count, -1, -1) middle = torch.linalg.inv(rank_eye + vt_inv_u) correction = torch.einsum("fns,frs,fbr->fbn", inv_u, middle, vt_inv_input) return inv_input - correction response = recorder.measure("low_rank_solve", low_rank_solve) def powered_readout() -> torch.Tensor: A_power_complex = A_power.to(dtype=complex_dtype) return torch.matmul(C_complex, A_power_complex) C_power = recorder.measure("powered_readout", powered_readout) def output_projection() -> torch.Tensor: y_freq = torch.einsum("on,fbn->fbo", C_complex, response) y_freq = y_freq - ( roots_power.view(freq_count, 1, 1) * torch.einsum("on,fbn->fbo", C_power, response) ) return y_freq + u_freq * D.to(dtype=complex_dtype).view(1, 1, -1) y_freq = recorder.measure("output_projection_and_skip", output_projection) def inverse_fft() -> torch.Tensor: y = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)[..., :seq_len] return y.transpose(1, 2).to(dtype=x.dtype) y = recorder.measure("inverse_fft", inverse_fft) del u_channels return y, recorder.results() def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16") parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--d-model", type=int, default=64) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--rank", type=int, default=1) parser.add_argument("--warmup", type=int, default=3) parser.add_argument("--repeats", type=int, default=10) parser.add_argument("--output", type=Path, default=None) args = parser.parse_args() device = torch.device(args.device) dtype = DTYPES[args.dtype] model = S4TernaryDPLRSSM( state_dim=args.d_model, hidden_dim=args.hidden_dim, rank=args.rank, kernel_mode="conv", kernel_threshold=1, ).to(device=device) model.train() x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype) target_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype fft_len = 1 << max(1, (2 * args.seq_len - 1).bit_length()) with torch.no_grad(), torch.autocast(device_type=device.type, enabled=False): for _ in range(args.warmup): run_profiled_direct( model, x, seq_len=args.seq_len, fft_len=fft_len, target_dtype=target_dtype, device=device, ) synchronize(device) stage_runs: dict[str, list[float]] = {} total_ms: list[float] = [] profiled_y: torch.Tensor | None = None for _ in range(args.repeats): synchronize(device) start = time.perf_counter() profiled_y, stages = run_profiled_direct( model, x, seq_len=args.seq_len, fft_len=fft_len, target_dtype=target_dtype, device=device, ) synchronize(device) total_ms.append((time.perf_counter() - start) * 1000.0) for name, value in stages.items(): stage_runs.setdefault(name, []).append(value) reference_y, _ = model._forward_convolutional(x, return_state=False) max_abs_diff = (profiled_y - reference_y).abs().max().item() if profiled_y is not None else math.nan stage_summary = {name: summarize(values) for name, values in stage_runs.items()} stage_total_mean = sum(item["mean_ms"] for item in stage_summary.values()) report: dict[str, Any] = { "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")}, "fft_len": fft_len, "target_dtype": str(target_dtype).replace("torch.", ""), "total_wall": summarize(total_ms), "stage_total_mean_ms": stage_total_mean, "stages": stage_summary, "validation": {"max_abs_diff_vs_forward_convolutional": max_abs_diff}, "frequency_grid_cache_entries": len(model._frequency_grid_cache), } text = json.dumps(report, indent=2, sort_keys=True, default=str) print(text) if args.output is not None: args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(text, encoding="utf-8") return 0 if __name__ == "__main__": raise SystemExit(main())