"""Profile the DPLR direct path with and without the finite-tail correction.

This diagnostic does not change model behavior. It answers whether the exact
finite convolution term C @ response - z^L (C @ A^L) @ response is a promising
speed target or a mathematically important part we should keep.
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any

import torch

# Make the repository root (the parent of this script's directory) importable
# so the model module resolves when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from gamma_space_model import S4TernaryDPLRSSM  # noqa: E402  (needs REPO_ROOT on sys.path)

# CLI spelling -> torch dtype used for the benchmark input tensor.
DTYPES = {
    "fp32": torch.float32,
    "float32": torch.float32,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp16": torch.float16,
    "float16": torch.float16,
}


def synchronize(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` has finished; no-op otherwise."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def summarize(latencies: list[float], tokens: int) -> dict[str, float]:
    """Reduce per-call wall-clock latencies to a small timing summary.

    Args:
        latencies: One duration per timed call, in seconds.
        tokens: Number of tokens processed by a single call.

    Returns:
        ``mean_ms`` and ``min_ms`` in milliseconds, plus ``tokens_per_s``
        derived from the mean latency.

    Raises:
        ValueError: If ``latencies`` is empty.
    """
    if not latencies:
        raise ValueError("summarize() needs at least one latency sample")
    mean_s = sum(latencies) / len(latencies)
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        # Clamp guards against a zero-duration measurement on coarse clocks.
        "tokens_per_s": tokens / max(mean_s, 1e-12),
    }


def dplr_direct(
    model: S4TernaryDPLRSSM,
    x: torch.Tensor,
    *,
    finite_tail: bool,
) -> torch.Tensor:
    """Run the DPLR SSM as a frequency-domain convolution.

    At every rfft bin the diagonal-plus-low-rank resolvent is applied to the
    input with the Woodbury identity, read out through ``C``, combined with
    the ``D`` skip term, and transformed back with an inverse FFT.

    Args:
        model: The DPLR SSM whose (private) discretization helpers are used.
        x: Input of shape ``(batch, seq_len, channels)``.
        finite_tail: When True, subtract the ``z^L (C @ A^L) @ response``
            term that makes the kernel exactly ``seq_len`` long. When False,
            that term is skipped, including the dense ``A`` and ``A^L`` it needs.

    Returns:
        Tensor of shape ``(batch, seq_len, C.shape[0])`` in ``x``'s dtype.
    """
    # Unpacking (rather than indexing) also enforces a 3-D input.
    _, seq_len, _ = x.shape
    original_dtype = x.dtype
    # Half-precision inputs are promoted so the FFTs and the complex solve
    # run in fp32 / complex64.
    target_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype
    # A power of two strictly greater than 2 * seq_len - 1, so the circular
    # FFT product matches the linear convolution on the first seq_len outputs.
    fft_len = 1 << max(1, (2 * seq_len - 1).bit_length())
    device = x.device

    with torch.autocast(device_type=device.type, enabled=False):
        # (batch, seq, ch) -> (batch, ch, seq) -> rfft over time -> (batch, ch, freq)
        u_channels = x.transpose(1, 2).to(dtype=target_dtype)
        u_f = torch.fft.rfft(u_channels, n=fft_len)
        diag, U, V, B_disc = model._discrete_params(dtype=target_dtype, device=device)
        C = model.C.to(device=device, dtype=target_dtype)
        D = model.D.to(device=device, dtype=target_dtype)
        # The dense A and its seq_len-th power are only consumed by the
        # finite-tail term, so the ablated variant does not pay for them.
        if finite_tail:
            A_dense = model._dense_discrete_A_from_params(diag, U, V)
            A_power = torch.linalg.matrix_power(A_dense, seq_len)
        else:
            A_power = None

        complex_dtype = torch.complex64 if target_dtype != torch.float64 else torch.complex128
        freq_count = fft_len // 2 + 1
        # Per-bin roots; presumably z and z^L on the rfft grid (see
        # model._frequency_roots, not visible here).
        roots, roots_power = model._frequency_roots(seq_len, fft_len, target_dtype, device)

        diag_complex = diag.to(dtype=complex_dtype)
        U_complex = U.to(dtype=complex_dtype)
        V_complex = V.to(dtype=complex_dtype)
        B_complex = B_disc.to(dtype=complex_dtype)
        C_complex = C.to(dtype=complex_dtype)
        # Index letters: f = frequency bin, b = batch, d = input channel,
        # n = state, o = output, r / s = low-rank factor indices.
        u_freq = u_f.permute(2, 0, 1).to(dtype=complex_dtype)  # (f, b, d)

        # Diagonal part of the resolvent, inverted elementwise: (f, n).
        denom = 1.0 - roots[:, None] * diag_complex[None, :]
        inv_diag = denom.reciprocal()
        input_term = torch.einsum("nd,fbd->fbn", B_complex, u_freq)
        inv_input = inv_diag[:, None, :] * input_term  # (f, b, n)
        omega_u = roots[:, None, None] * U_complex[None, :, :]
        inv_u = inv_diag[:, :, None] * omega_u  # (f, n, s)
        vt_inv_u = torch.einsum("nr,fns->frs", V_complex, inv_u)  # (f, r, s)
        vt_inv_input = torch.einsum("nr,fbn->fbr", V_complex, inv_input)  # (f, b, r)

        # Woodbury: response = inv_input - inv_u @ (I + V^T inv_u)^-1 @ vt_inv_input
        if model.rank == 1:
            middle = (1.0 + vt_inv_u[:, 0, 0]).reciprocal()
            correction = (
                inv_u[:, None, :, 0]
                * middle.view(freq_count, 1, 1)
                * vt_inv_input[:, :, 0].unsqueeze(-1)
            )
        else:
            rank_eye = torch.eye(model.rank, device=device, dtype=complex_dtype).expand(
                freq_count, -1, -1
            )
            # K = I + V^T inv_u is indexed [r, s], so K^-1 is indexed [s, r]:
            # its row pairs with inv_u's column s, its column with
            # vt_inv_input's r. Labelling it "frs" would silently use K^-T.
            middle = torch.linalg.inv(rank_eye + vt_inv_u)
            correction = torch.einsum("fns,fsr,fbr->fbn", inv_u, middle, vt_inv_input)
        response = inv_input - correction

        y_freq = torch.einsum("on,fbn->fbo", C_complex, response)
        if A_power is not None:
            powered_readout = torch.matmul(C_complex, A_power.to(dtype=complex_dtype))
            y_freq = y_freq - (
                roots_power.view(freq_count, 1, 1)
                * torch.einsum("on,fbn->fbo", powered_readout, response)
            )
        # Skip connection applied per input channel, then back to the time domain.
        y_freq = y_freq + u_freq * D.to(dtype=complex_dtype).view(1, 1, -1)
        y = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)[..., :seq_len]
        return y.transpose(1, 2).to(dtype=original_dtype)


def time_variant(
    fn: Callable[[], object],
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
    tokens: int,
) -> dict[str, float]:
    """Time ``fn`` after ``warmup`` untimed calls and summarize the latencies.

    Each timed call is bracketed by device synchronization, so asynchronous
    CUDA work is included in its own measurement.

    Args:
        fn: Zero-argument callable to benchmark; its result is discarded.
        device: Device to synchronize around each call.
        warmup: Number of untimed calls before measuring.
        repeats: Number of timed calls; must be >= 1.
        tokens: Tokens processed per call, used for throughput.

    Returns:
        The :func:`summarize` dictionary for the timed calls.
    """
    for _ in range(warmup):
        fn()
    synchronize(device)
    latencies: list[float] = []
    for _ in range(repeats):
        synchronize(device)
        start = time.perf_counter()
        fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    return summarize(latencies, tokens)


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected an integer >= 1, got {value}")
    return value


def _non_negative_int(text: str) -> int:
    """argparse ``type=`` converter that accepts integers >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected an integer >= 0, got {value}")
    return value


def main() -> int:
    """Compare exact vs. finite-tail-ablated DPLR paths and print a JSON report.

    The report contains forward and forward+backward timings for both
    variants, output differences between them, and the difference between the
    exact path and the model's own convolutional forward.

    Returns:
        Process exit status (0); argparse exits on its own for bad arguments.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--batch-size", type=_positive_int, default=32)
    parser.add_argument("--seq-len", type=_positive_int, default=512)
    parser.add_argument("--d-model", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--warmup", type=_non_negative_int, default=3)
    # At least one timed repeat is required to compute a mean latency.
    parser.add_argument("--repeats", type=_positive_int, default=10)
    parser.add_argument("--output", type=Path, default=None)
    args = parser.parse_args()

    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]

    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()
    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)
    tokens = args.batch_size * args.seq_len

    def exact_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=True)

    def ablated_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=False)

    def exact_backward() -> None:
        model.zero_grad(set_to_none=True)
        y = exact_forward()
        y.square().mean().backward()

    def ablated_backward() -> None:
        model.zero_grad(set_to_none=True)
        y = ablated_forward()
        y.square().mean().backward()

    # Accuracy comparison without autograd bookkeeping.
    with torch.no_grad():
        y_exact = exact_forward()
        y_ablated = ablated_forward()
        y_reference, _ = model._forward_convolutional(x, return_state=False)
        diff = (y_exact.float() - y_ablated.float()).abs()
        reference_diff = (y_exact.float() - y_reference.float()).abs()
        exact_norm = y_exact.float().norm().item()
        diff_norm = diff.norm().item()

    timing: dict[str, Any] = {
        "device": device,
        "warmup": args.warmup,
        "repeats": args.repeats,
        "tokens": tokens,
    }
    report: dict[str, Any] = {
        "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")},
        "forward": {
            "exact": time_variant(exact_forward, **timing),
            "finite_tail_ablated": time_variant(ablated_forward, **timing),
        },
        "forward_backward": {
            "exact": time_variant(exact_backward, **timing),
            "finite_tail_ablated": time_variant(ablated_backward, **timing),
        },
        "difference": {
            "max_abs": diff.max().item(),
            "mean_abs": diff.mean().item(),
            "exact_norm": exact_norm,
            "diff_norm": diff_norm,
            "relative_l2": diff_norm / max(exact_norm, 1e-12),
            "exact_vs_production_max_abs": reference_diff.max().item(),
        },
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }
    # default=str serializes the Path in the --output config entry.
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())