# Origin: TaoNet-mini-T2 / code / Taotern_SSM / scripts / profile_dplr_direct_steps.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit 388fd6e (verified).
"""Break down the DPLR direct frequency path into timed forward stages.
The whole-path profiler tells us whether the direct convolution path is fast,
but not which internal tensor operation should become the next TileLang/Triton
target. This script mirrors ``S4TernaryDPLRSSM._apply_frequency_response`` and
records per-stage timings without changing model behavior.
"""
from __future__ import annotations
import argparse
import json
import math
import statistics
import sys
import time
from pathlib import Path
from typing import Any, Callable
import torch
# Put the directory two levels above this file (the parent of ``scripts/``) at
# the front of the import path so ``gamma_space_model`` can be imported when
# this file is executed directly. Skipped if the entry is already present.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]
from gamma_space_model import S4TernaryDPLRSSM
# Accepted ``--dtype`` spellings (short and long aliases) -> torch dtype.
DTYPES = dict(
    fp32=torch.float32,
    float32=torch.float32,
    bf16=torch.bfloat16,
    bfloat16=torch.bfloat16,
    fp16=torch.float16,
    float16=torch.float16,
)
def synchronize(device: torch.device) -> None:
    """Wait for all queued CUDA work on ``device``; do nothing for other device types."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def summarize(values: list[float]) -> dict[str, float]:
    """Summarize millisecond timings as mean / min / max / population stdev.

    An empty list yields NaN for every statistic instead of raising
    (``statistics.fmean``, ``min`` and ``max`` all reject empty input). This
    keeps a zero-sample run from aborting report generation.

    Args:
        values: Timings in milliseconds.

    Returns:
        Dict with keys ``mean_ms``, ``min_ms``, ``max_ms`` and ``stdev_ms``.
    """
    if not values:
        return {
            "mean_ms": math.nan,
            "min_ms": math.nan,
            "max_ms": math.nan,
            "stdev_ms": math.nan,
        }
    return {
        "mean_ms": statistics.fmean(values),
        "min_ms": min(values),
        "max_ms": max(values),
        # A single sample has no spread; report 0.0 explicitly.
        "stdev_ms": statistics.pstdev(values) if len(values) > 1 else 0.0,
    }
class StageRecorder:
    """Time a sequence of named stages.

    On CUDA devices each stage is bracketed by a pair of timing events recorded
    on the profiled device's current stream. The durations are resolved in
    :meth:`results`. On any other device the stage is timed on the host with
    ``time.perf_counter``.
    """

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self.cuda = device.type == "cuda"
        # (stage name, start event, end event); read back in results().
        self.events: list[tuple[str, torch.cuda.Event, torch.cuda.Event]] = []
        # (stage name, elapsed milliseconds) for non-CUDA devices.
        self.cpu_times: list[tuple[str, float]] = []

    def measure(self, name: str, fn: Callable[[], Any]) -> Any:
        """Run ``fn()`` as stage ``name``, record its duration and return its result."""
        if self.cuda:
            # Record on the profiled device's stream rather than whatever device
            # is current; otherwise e.g. ``--device cuda:1`` (while cuda:0 is
            # current) would time an idle stream.
            stream = torch.cuda.current_stream(self.device)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record(stream)
            value = fn()
            end.record(stream)
            self.events.append((name, start, end))
            return value
        start_time = time.perf_counter()
        value = fn()
        self.cpu_times.append((name, (time.perf_counter() - start_time) * 1000.0))
        return value

    def results(self) -> dict[str, float]:
        """Return ``{stage name: milliseconds}`` (a repeated name keeps its last value).

        On CUDA the device is synchronized first so every recorded event has
        completed before ``elapsed_time`` is queried.
        """
        if self.cuda:
            torch.cuda.synchronize(self.device)
            return {name: start.elapsed_time(end) for name, start, end in self.events}
        return dict(self.cpu_times)
def run_profiled_direct(
    model: S4TernaryDPLRSSM,
    x: torch.Tensor,
    *,
    seq_len: int,
    fft_len: int,
    target_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Run the DPLR direct frequency path once, timing each internal stage.

    The stages are intended to mirror
    ``S4TernaryDPLRSSM._apply_frequency_response``. That method is not visible
    here, so ``main`` checks the output against ``model._forward_convolutional``.

    Args:
        model: SSM whose parameters and private helpers are used.
        x: Input, transposed ``(1, 2)`` before an FFT along the last axis.
            ``main`` passes ``(batch, seq_len, d_model)``.
        seq_len: Time steps kept after the inverse FFT; also the exponent of the
            dense ``A`` matrix power.
        fft_len: FFT length used for both ``rfft`` and ``irfft``.
        target_dtype: Real compute dtype. Complex stages use complex128 when it
            is float64 and complex64 otherwise.
        device: Device for the model parameters and for stage timing.

    Returns:
        ``(y, stage_ms)``. ``y`` is cast back to ``x.dtype`` with layout
        ``(batch, seq_len, out)``. ``stage_ms`` maps stage name to milliseconds.
    """
    recorder = StageRecorder(device)
    complex_dtype = torch.complex128 if target_dtype == torch.float64 else torch.complex64
    freq_count = fft_len // 2 + 1  # number of rfft output bins

    def input_fft() -> torch.Tensor:
        # Move the sequence axis last so rfft runs along time.
        u_channels = x.transpose(1, 2).to(dtype=target_dtype)
        return torch.fft.rfft(u_channels, n=fft_len)

    u_f = recorder.measure("input_fft", input_fft)

    diag, U, V, B_disc = recorder.measure(
        "discrete_params",
        lambda: model._discrete_params(dtype=target_dtype, device=device),
    )

    def matrix_power() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        A_dense = model._dense_discrete_A_from_params(diag, U, V)
        A_power = torch.linalg.matrix_power(A_dense, seq_len)
        C = model.C.to(device=device, dtype=target_dtype)
        D = model.D.to(device=device, dtype=target_dtype)
        return A_power, C, D

    A_power, C, D = recorder.measure("dense_A_power_C_D", matrix_power)

    def roots_and_casts() -> tuple[torch.Tensor, ...]:
        roots, roots_power = model._frequency_roots(seq_len, fft_len, target_dtype, device)
        return (
            roots,
            roots_power,
            diag.to(dtype=complex_dtype),
            U.to(dtype=complex_dtype),
            V.to(dtype=complex_dtype),
            B_disc.to(dtype=complex_dtype),
            C.to(dtype=complex_dtype),
        )

    (
        roots,
        roots_power,
        diag_complex,
        U_complex,
        V_complex,
        B_complex,
        C_complex,
    ) = recorder.measure("roots_and_complex_casts", roots_and_casts)

    def diagonal_input_solve() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # (batch, channels, freq) -> (freq, batch, channels)
        u_freq = u_f.permute(2, 0, 1).to(dtype=complex_dtype)
        denom = 1.0 - roots[:, None] * diag_complex[None, :]
        inv_diag = denom.reciprocal()
        input_term = torch.einsum("nd,fbd->fbn", B_complex, u_freq)
        inv_input = inv_diag[:, None, :] * input_term
        return u_freq, inv_diag, inv_input

    u_freq, inv_diag, inv_input = recorder.measure("diagonal_input_solve", diagonal_input_solve)

    def low_rank_solve() -> torch.Tensor:
        # Woodbury correction: response = inv_input - inv_u @ M^{-1} @ (V^T inv_input),
        # with M = I + V^T inv_u (one small rank x rank system per frequency).
        omega_u = roots[:, None, None] * U_complex[None, :, :]
        inv_u = inv_diag[:, :, None] * omega_u
        vt_inv_u = torch.einsum("nr,fns->frs", V_complex, inv_u)
        vt_inv_input = torch.einsum("nr,fbn->fbr", V_complex, inv_input)
        if model.rank == 1:
            # Scalar middle term; avoids a batched matrix inverse.
            middle = (1.0 + vt_inv_u[:, 0, 0]).reciprocal()
            correction = (
                inv_u[:, None, :, 0]
                * middle.view(freq_count, 1, 1)
                * vt_inv_input[:, :, 0].unsqueeze(-1)
            )
        else:
            rank_eye = torch.eye(model.rank, device=device, dtype=complex_dtype).expand(freq_count, -1, -1)
            middle = torch.linalg.inv(rank_eye + vt_inv_u)
            # correction[n] = sum_{s,r} inv_u[n, s] * middle[s, r] * vt_inv_input[r].
            # Labelling ``middle`` as "frs" would apply middle^T, which differs
            # whenever rank > 1 and M is not symmetric.
            # NOTE(review): if the model's _apply_frequency_response uses the
            # transposed labelling, validation will now flag rank > 1 runs.
            correction = torch.einsum("fns,fsr,fbr->fbn", inv_u, middle, vt_inv_input)
        return inv_input - correction

    response = recorder.measure("low_rank_solve", low_rank_solve)

    def powered_readout() -> torch.Tensor:
        A_power_complex = A_power.to(dtype=complex_dtype)
        return torch.matmul(C_complex, A_power_complex)

    C_power = recorder.measure("powered_readout", powered_readout)

    def output_projection() -> torch.Tensor:
        y_freq = torch.einsum("on,fbn->fbo", C_complex, response)
        y_freq = y_freq - (
            roots_power.view(freq_count, 1, 1)
            * torch.einsum("on,fbn->fbo", C_power, response)
        )
        # Skip connection: D scales the input spectrum per channel.
        return y_freq + u_freq * D.to(dtype=complex_dtype).view(1, 1, -1)

    y_freq = recorder.measure("output_projection_and_skip", output_projection)

    def inverse_fft() -> torch.Tensor:
        # (freq, batch, out) -> (batch, out, freq), invert, keep the first seq_len steps.
        y = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)[..., :seq_len]
        return y.transpose(1, 2).to(dtype=x.dtype)

    y = recorder.measure("inverse_fft", inverse_fft)
    return y, recorder.results()
def main() -> int:
    """Profile the DPLR direct frequency path stage by stage and emit a JSON report.

    Builds an ``S4TernaryDPLRSSM`` with ``kernel_mode="conv"``. It calls
    ``run_profiled_direct`` ``--warmup`` times untimed and ``--repeats`` times
    timed, then compares the last profiled output with
    ``model._forward_convolutional``. The report is printed and, when
    ``--output`` is given, also written to that file.

    Returns:
        Exit status 0. Invalid ``--repeats`` / ``--warmup`` / ``--seq-len``
        values exit through ``argparse`` (status 2).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--d-model", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeats", type=int, default=10)
    parser.add_argument("--output", type=Path, default=None)
    args = parser.parse_args()
    # Without at least one timed repeat there is nothing to summarize or validate.
    if args.repeats < 1:
        parser.error("--repeats must be >= 1")
    if args.warmup < 0:
        parser.error("--warmup must be >= 0")
    if args.seq_len < 1:
        parser.error("--seq-len must be >= 1")

    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]
    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()
    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)
    # Half-precision inputs are processed in float32 by the profiled stages.
    target_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype
    # Power of two that is at least 2 * seq_len - 1 (and at least 2); presumably
    # sized so the FFT product avoids circular wrap-around.
    fft_len = 1 << max(1, (2 * args.seq_len - 1).bit_length())

    stage_runs: dict[str, list[float]] = {}
    total_ms: list[float] = []
    profiled_y: torch.Tensor | None = None
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=False):
        for _ in range(args.warmup):
            run_profiled_direct(
                model,
                x,
                seq_len=args.seq_len,
                fft_len=fft_len,
                target_dtype=target_dtype,
                device=device,
            )
        synchronize(device)

        for _ in range(args.repeats):
            synchronize(device)
            start = time.perf_counter()
            profiled_y, stages = run_profiled_direct(
                model,
                x,
                seq_len=args.seq_len,
                fft_len=fft_len,
                target_dtype=target_dtype,
                device=device,
            )
            synchronize(device)
            total_ms.append((time.perf_counter() - start) * 1000.0)
            for name, value in stages.items():
                stage_runs.setdefault(name, []).append(value)

        # Reference under the same no-grad / autocast-disabled context, so it
        # neither builds an autograd graph nor runs under different numerics.
        reference_y, _ = model._forward_convolutional(x, return_state=False)

    # Subtract in float64 so low-precision rounding in the difference itself
    # does not distort the reported error.
    max_abs_diff = (
        (profiled_y.double() - reference_y.double()).abs().max().item()
        if profiled_y is not None
        else math.nan
    )
    stage_summary = {name: summarize(values) for name, values in stage_runs.items()}
    stage_total_mean = sum(item["mean_ms"] for item in stage_summary.values())
    report: dict[str, Any] = {
        "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")},
        "fft_len": fft_len,
        "target_dtype": str(target_dtype).replace("torch.", ""),
        "total_wall": summarize(total_ms),
        "stage_total_mean_ms": stage_total_mean,
        "stages": stage_summary,
        "validation": {"max_abs_diff_vs_forward_convolutional": max_abs_diff},
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }
    # default=str covers non-JSON values such as the Path in args.output.
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())