TaoNet-mini-T2 / code /Taotern_SSM /scripts /profile_dplr_finite_tail_ablation.py
StarMist0012's picture
Add files using upload-large-folder tool
388fd6e verified
"""Profile the DPLR direct path with and without the finite-tail correction.
This diagnostic does not change model behavior. It answers whether the exact
finite convolution term

    C @ response - z^L (C @ A^L) @ response

is a promising speed target or a mathematically important part we should keep.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from pathlib import Path
from typing import Any
import torch
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from gamma_space_model import S4TernaryDPLRSSM
# Accepted ``--dtype`` spellings: each short alias is followed by its long form
# and both resolve to the same torch dtype.
DTYPES = dict(
    fp32=torch.float32,
    float32=torch.float32,
    bf16=torch.bfloat16,
    bfloat16=torch.bfloat16,
    fp16=torch.float16,
    float16=torch.float16,
)
def synchronize(device: torch.device) -> None:
    """Wait for pending work on ``device`` when it is a CUDA device.

    For any device whose ``type`` is not ``"cuda"`` this returns immediately
    without doing anything.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def summarize(latencies: list[float], tokens: int) -> dict[str, float]:
    """Reduce per-run wall-clock latencies to summary statistics.

    Args:
        latencies: Per-run durations in seconds; must contain at least one
            sample.
        tokens: Number of tokens processed by a single run.

    Returns:
        A dict with ``mean_ms`` and ``min_ms`` (both in milliseconds) and
        ``tokens_per_s`` (``tokens`` divided by the mean duration).

    Raises:
        ValueError: If ``latencies`` is empty, since no mean or minimum exists.
    """
    if not latencies:
        raise ValueError("summarize() needs at least one latency sample")
    mean_s = sum(latencies) / len(latencies)
    return {
        "mean_ms": mean_s * 1000.0,
        "min_ms": min(latencies) * 1000.0,
        # Clamp the divisor so a zero-length measurement cannot divide by zero.
        "tokens_per_s": tokens / max(mean_s, 1e-12),
    }
def dplr_direct(
    model: S4TernaryDPLRSSM,
    x: torch.Tensor,
    *,
    finite_tail: bool,
) -> torch.Tensor:
    """Run the DPLR frequency-domain forward pass, with or without the tail term.

    The input is moved to the frequency domain with a zero-padded real FFT.
    Per frequency, the model's discrete parameters are applied as a diagonal
    inverse minus a low-rank correction (a scalar reciprocal when
    ``model.rank == 1``, a batched matrix inverse otherwise). The readout is
    then transformed back and cropped to the input length.

    Args:
        model: Supplies ``C``, ``D``, ``rank`` and the private helpers
            ``_discrete_params``, ``_dense_discrete_A_from_params`` and
            ``_frequency_roots``.
        x: 3-D input, unpacked as ``(batch, seq_len, channels)``.
        finite_tail: When true, subtract
            ``roots_power * (C @ A^seq_len) @ response`` from the readout.
            When false that term is skipped entirely.

    Returns:
        The output cast back to ``x``'s dtype, with the sequence on axis 1.
    """
    # Unpacking (rather than indexing) keeps the failure on non-3-D inputs.
    _, seq_len, _ = x.shape
    input_dtype = x.dtype
    # Half-precision inputs are evaluated in float32; other dtypes are kept.
    reduced_precision = (torch.float16, torch.bfloat16)
    work_dtype = torch.float32 if input_dtype in reduced_precision else input_dtype
    work_complex = torch.complex128 if work_dtype == torch.float64 else torch.complex64
    # Smallest power of two strictly greater than 2 * seq_len - 1 (minimum 2).
    fft_len = 1 << max(1, (2 * seq_len - 1).bit_length())
    n_freq = fft_len // 2 + 1
    device = x.device

    with torch.autocast(device_type=device.type, enabled=False):
        channels_first = x.transpose(1, 2).to(dtype=work_dtype)
        spectrum = torch.fft.rfft(channels_first, n=fft_len)

        diag, U, V, B_disc = model._discrete_params(dtype=work_dtype, device=device)
        # Built in both modes so the ablated variant differs only by the tail term.
        A_dense = model._dense_discrete_A_from_params(diag, U, V)
        readout = model.C.to(device=device, dtype=work_dtype)
        skip = model.D.to(device=device, dtype=work_dtype)
        # NOTE(review): presumably the per-frequency z and z**seq_len - confirm.
        roots, roots_power = model._frequency_roots(seq_len, fft_len, work_dtype, device)

        diag_c = diag.to(dtype=work_complex)
        U_c = U.to(dtype=work_complex)
        V_c = V.to(dtype=work_complex)
        B_c = B_disc.to(dtype=work_complex)
        C_c = readout.to(dtype=work_complex)
        # Frequency-major layout: (freq, batch, channel).
        u_freq = spectrum.permute(2, 0, 1).to(dtype=work_complex)

        # Diagonal part: elementwise 1 / (1 - z * diag) for every frequency.
        diag_inverse = (1.0 - roots[:, None] * diag_c[None, :]).reciprocal()
        driven = torch.einsum("nd,fbd->fbn", B_c, u_freq)
        solved_input = diag_inverse[:, None, :] * driven
        scaled_u = roots[:, None, None] * U_c[None, :, :]
        solved_u = diag_inverse[:, :, None] * scaled_u
        v_solved_u = torch.einsum("nr,fns->frs", V_c, solved_u)
        v_solved_input = torch.einsum("nr,fbn->fbr", V_c, solved_input)

        if model.rank != 1:
            # General rank: invert the small (rank x rank) system per frequency.
            identity = torch.eye(model.rank, device=device, dtype=work_complex)
            capacitance = torch.linalg.inv(identity.expand(n_freq, -1, -1) + v_solved_u)
            low_rank_fix = torch.einsum(
                "fns,frs,fbr->fbn", solved_u, capacitance, v_solved_input
            )
        else:
            # Rank 1: the inverse collapses to a scalar reciprocal per frequency.
            capacitance = (1.0 + v_solved_u[:, 0, 0]).reciprocal()
            low_rank_fix = (
                solved_u[:, None, :, 0]
                * capacitance.view(n_freq, 1, 1)
                * v_solved_input[:, :, 0].unsqueeze(-1)
            )

        response = solved_input - low_rank_fix
        y_freq = torch.einsum("on,fbn->fbo", C_c, response)

        if finite_tail:
            # Finite-length correction: z^L (C @ A^L) @ response.
            A_power = torch.linalg.matrix_power(A_dense, seq_len)
            tail_readout = torch.matmul(C_c, A_power.to(dtype=work_complex))
            tail = roots_power.view(n_freq, 1, 1) * torch.einsum(
                "on,fbn->fbo", tail_readout, response
            )
            y_freq = y_freq - tail

        # Add the D-scaled copy of the input spectrum.
        y_freq = y_freq + u_freq * skip.to(dtype=work_complex).view(1, 1, -1)
        y_time = torch.fft.irfft(y_freq.permute(1, 2, 0), n=fft_len)
        # Drop the zero-padding and restore the sequence-on-axis-1 layout.
        y = y_time[..., :seq_len]
        return y.transpose(1, 2).to(dtype=input_dtype)
def time_variant(
    fn,
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
    tokens: int,
) -> dict[str, float]:
    """Time repeated calls of ``fn`` and summarize the latencies.

    Args:
        fn: Zero-argument callable to measure; its return value is ignored.
        device: Device passed to ``synchronize`` around each measurement.
        warmup: Number of untimed calls made before measuring.
        repeats: Number of timed calls.
        tokens: Tokens processed per call, forwarded to ``summarize``.

    Returns:
        The statistics dict produced by ``summarize``.
    """
    for _ in range(warmup):
        fn()
    synchronize(device)

    def timed_call() -> float:
        # Synchronize on both sides so only this call's work is inside the window.
        synchronize(device)
        started = time.perf_counter()
        fn()
        synchronize(device)
        return time.perf_counter() - started

    return summarize([timed_call() for _ in range(repeats)], tokens)
def main() -> int:
    """Benchmark exact vs. finite-tail-ablated DPLR passes and report as JSON.

    Parses the command line, builds an ``S4TernaryDPLRSSM`` and a random
    input, then measures how far the ablated output drifts from the exact one
    and how far the exact one is from ``model._forward_convolutional``. It then
    times forward and forward+backward for both variants. The report is
    printed and, when ``--output`` is given, also written to that path
    (parent directories are created as needed).

    Returns:
        ``0`` on success. Invalid size arguments exit via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--d-model", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--repeats", type=int, default=10)
    parser.add_argument("--output", type=Path, default=None)
    args = parser.parse_args()

    # Reject sizes that would otherwise fail late and obscurely: zero repeats
    # leave no samples to average, and an empty input has no max() to report.
    for flag, value in (
        ("--batch-size", args.batch_size),
        ("--seq-len", args.seq_len),
        ("--d-model", args.d_model),
        ("--repeats", args.repeats),
    ):
        if value < 1:
            parser.error(f"{flag} must be a positive integer, got {value}")

    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]
    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()
    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)
    tokens = args.batch_size * args.seq_len

    def exact_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=True)

    def ablated_forward() -> torch.Tensor:
        return dplr_direct(model, x, finite_tail=False)

    def with_backward(forward):
        """Wrap ``forward`` into a step that also backpropagates a scalar loss."""

        def step() -> None:
            model.zero_grad(set_to_none=True)
            y = forward()
            y.square().mean().backward()

        return step

    def bench(fn) -> dict[str, float]:
        """Time ``fn`` with the warmup/repeat settings from the command line."""
        return time_variant(
            fn,
            device=device,
            warmup=args.warmup,
            repeats=args.repeats,
            tokens=tokens,
        )

    # Accuracy comparison first; gradients are not needed for it.
    with torch.no_grad():
        y_exact = exact_forward()
        y_ablated = ablated_forward()
        y_reference, _ = model._forward_convolutional(x, return_state=False)
    diff = (y_exact.float() - y_ablated.float()).abs()
    reference_diff = (y_exact.float() - y_reference.float()).abs()
    exact_norm = y_exact.float().norm().item()
    diff_norm = diff.norm().item()

    report: dict[str, Any] = {
        "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")},
        "forward": {
            "exact": bench(exact_forward),
            "finite_tail_ablated": bench(ablated_forward),
        },
        "forward_backward": {
            "exact": bench(with_backward(exact_forward)),
            "finite_tail_ablated": bench(with_backward(ablated_forward)),
        },
        "difference": {
            "max_abs": diff.max().item(),
            "mean_abs": diff.mean().item(),
            "exact_norm": exact_norm,
            "diff_norm": diff_norm,
            # Guard the ratio against an all-zero exact output.
            "relative_l2": diff_norm / max(exact_norm, 1e-12),
            "exact_vs_production_max_abs": reference_diff.max().item(),
        },
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }

    # default=str serializes the non-JSON values in the config (e.g. the Path).
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    sys.exit(main())