# Source: TaoNet-mini-T2 / code / Taotern_SSM / scripts / profile_dplr_frequency_path.py
# Uploaded by StarMist0012 ("Add files using upload-large-folder tool"), commit 388fd6e (verified).
"""Profile the DPLR convolutional frequency path.
This is a small remote-friendly profiler for choosing TileLang/Triton kernel
targets. It focuses on S4TernaryDPLRSSM rather than the older Gamma fallback
because this is the SSM core used by the TaoNet comparison work.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import torch
# Put the parent of the directory containing this file at the front of
# sys.path, so the project-local `gamma_space_model` import resolves when the
# script is executed directly rather than as an installed package.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from gamma_space_model import S4TernaryDPLRSSM
# Accepted --dtype spellings (short and long form) mapped to torch dtypes.
DTYPES = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("fp32", "float32")),
        (torch.bfloat16, ("bf16", "bfloat16")),
        (torch.float16, ("fp16", "float16")),
    )
    for alias in aliases
}
def synchronize(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` has finished; no-op on other devices."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def memory_stats(device: torch.device) -> dict[str, float | None]:
    """Return peak allocated/reserved CUDA memory in MiB.

    The values come from ``torch.cuda.max_memory_allocated`` and
    ``torch.cuda.max_memory_reserved``. For any non-CUDA device, both entries
    are ``None``.
    """
    keys = ("peak_allocated_mb", "peak_reserved_mb")
    if device.type == "cuda":
        bytes_per_mb = 1024**2
        readings = (
            torch.cuda.max_memory_allocated(device),
            torch.cuda.max_memory_reserved(device),
        )
        return {key: raw / bytes_per_mb for key, raw in zip(keys, readings)}
    return dict.fromkeys(keys)
def run_timed(fn, *, device: torch.device, warmup: int, repeats: int) -> dict[str, float]:
    """Time ``fn`` after ``warmup`` untimed calls and return latency statistics.

    Each timed call is bracketed by ``synchronize(device)``, so asynchronous
    CUDA work launched by ``fn`` is included in the wall-clock measurement.
    On CUDA, the peak-memory counters are reset before every timed call, so a
    later ``memory_stats(device)`` reflects only the final repeat.

    Args:
        fn: Zero-argument callable to benchmark.
        device: Device whose queued work is synchronized around each call.
        warmup: Number of untimed calls made first (values <= 0 skip warmup).
        repeats: Number of timed calls; must be at least 1.

    Returns:
        ``{"mean_ms": ..., "min_ms": ...}`` over the timed calls, in milliseconds.

    Raises:
        ValueError: If ``repeats`` is less than 1, since no statistics exist.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    for _ in range(warmup):
        fn()
    synchronize(device)
    latencies = []
    for _ in range(repeats):
        if device.type == "cuda":
            # Per-repeat reset: the peak read afterwards belongs to the last call only.
            torch.cuda.reset_peak_memory_stats(device)
        synchronize(device)
        start = time.perf_counter()
        fn()
        synchronize(device)
        latencies.append(time.perf_counter() - start)
    return {
        "mean_ms": sum(latencies) / len(latencies) * 1000.0,
        "min_ms": min(latencies) * 1000.0,
    }
def profiler_table(
    prof: torch.profiler.profile,
    row_limit: int,
    sort_by: str | None = None,
) -> list[dict[str, Any]]:
    """Render the profiler's aggregated op table as ``[{"row": line}, ...]``.

    Args:
        prof: A finished ``torch.profiler.profile`` session.
        row_limit: Maximum number of ops in the table (forwarded to ``table``).
        sort_by: Column to sort by. When ``None`` (the default), the table is
            sorted by ``cuda_time_total`` if CUDA activity was recorded and by
            ``cpu_time_total`` otherwise. A CPU-only profile has no CUDA
            timings to sort on, so ``row_limit`` would keep an arbitrary
            subset of ops.

    Returns:
        One dict per line of the rendered table, including header and separator lines.
    """
    if sort_by is None:
        recorded = getattr(prof, "activities", None) or ()
        has_cuda = torch.profiler.ProfilerActivity.CUDA in recorded
        sort_by = "cuda_time_total" if has_cuda else "cpu_time_total"
    table = prof.key_averages().table(sort_by=sort_by, row_limit=row_limit)
    return [{"row": line} for line in table.splitlines()]
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    # Build the CLI, parse argv (sys.argv when None), and reject sizes that cannot yield a timing.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", choices=sorted(DTYPES), default="bf16")
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--d-model", type=int, default=64)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--rank", type=int, default=1)
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--row-limit", type=int, default=20)
    parser.add_argument("--method", choices=["forward", "direct", "transfer"], default="forward")
    parser.add_argument("--output", type=Path, default=None)
    args = parser.parse_args(argv)
    # Zero or negative sizes either crash (division by zero in the timing
    # stats) or produce an empty benchmark, so fail fast with a usage error.
    for name in ("batch_size", "seq_len", "d_model", "hidden_dim", "repeats"):
        value = getattr(args, name)
        if value < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1, got {value}")
    return args


def _tokens_per_second(tokens: int, mean_ms: float) -> float:
    # Convert a mean latency in milliseconds into throughput, guarding against a zero latency.
    return tokens / max(mean_ms / 1000.0, 1e-12)


def main() -> int:
    """Benchmark (and optionally profile) one S4TernaryDPLRSSM frequency path.

    Times a no-grad forward pass and a forward+backward pass for the path
    selected by ``--method``. It prints a JSON report and, if ``--output`` is
    given, also writes the report to that file.

    Returns:
        Process exit status (always 0; invalid arguments exit via argparse).
    """
    args = _parse_args()
    device = torch.device(args.device)
    dtype = DTYPES[args.dtype]

    model = S4TernaryDPLRSSM(
        state_dim=args.d_model,
        hidden_dim=args.hidden_dim,
        rank=args.rank,
        kernel_mode="conv",
        kernel_threshold=1,
    ).to(device=device)
    model.train()
    x = torch.randn(args.batch_size, args.seq_len, args.d_model, device=device, dtype=dtype)

    # Autocast is enabled only for fp16/bf16 on CUDA.
    # NOTE(review): on CPU with bf16/fp16, x reaches the model in that dtype
    # while the parameters were never cast; confirm S4TernaryDPLRSSM handles it.
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}

    def autocast_context():
        if not autocast_enabled:
            return nullcontext()
        return torch.autocast(device_type=device.type, dtype=dtype, enabled=True)

    def apply_model() -> torch.Tensor:
        if args.method == "forward":
            y, _ = model(x, return_state=False)
            return y
        # "direct" / "transfer" drive the model's frequency-domain helpers
        # directly. The FFTs are done in fp32 when x is half precision.
        fft_dtype = torch.float32 if x.dtype in {torch.float16, torch.bfloat16} else x.dtype
        # Power of two >= 2 * seq_len - 1: presumably the usual zero padding
        # that keeps the first seq_len outputs free of circular wrap-around.
        fft_len = 1 << max(1, (2 * args.seq_len - 1).bit_length())
        with torch.autocast(device_type=device.type, enabled=False):
            u_channels = x.transpose(1, 2).to(dtype=fft_dtype)
            u_f = torch.fft.rfft(u_channels, n=fft_len)
            if args.method == "direct":
                y_f = model._apply_frequency_response(
                    u_f=u_f,
                    seq_len=args.seq_len,
                    fft_len=fft_len,
                    dtype=fft_dtype,
                    device=device,
                )
            else:
                transfer = model._compute_frequency_response(
                    seq_len=args.seq_len,
                    fft_len=fft_len,
                    dtype=fft_dtype,
                    device=device,
                    use_cache=False,
                )
                y_f = torch.einsum("foi,bif->bof", transfer, u_f)
            y = torch.fft.irfft(y_f, n=fft_len)[..., : args.seq_len]
        return y.transpose(1, 2).to(dtype=x.dtype)

    def forward_only() -> None:
        with torch.no_grad(), autocast_context():
            y = apply_model()
            # .item() copies a scalar to the host, so the output must be fully computed.
            y.sum().item()

    def forward_backward() -> None:
        model.zero_grad(set_to_none=True)
        with autocast_context():
            y = apply_model()
            loss = y.square().mean()
        # Backward runs outside the autocast region, as the AMP docs recommend.
        loss.backward()

    forward_stats = run_timed(
        forward_only,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    # run_timed resets the CUDA peak counters before every repeat, so read
    # them now, before the forward+backward benchmark overwrites them.
    forward_memory = memory_stats(device)
    forward_backward_stats = run_timed(
        forward_backward,
        device=device,
        warmup=args.warmup,
        repeats=args.repeats,
    )
    forward_backward_memory = memory_stats(device)

    tokens = args.batch_size * args.seq_len
    report: dict[str, Any] = {
        "config": vars(args) | {"device": str(device), "dtype": str(dtype).replace("torch.", "")},
        "forward": {
            **forward_stats,
            "tokens_per_s": _tokens_per_second(tokens, forward_stats["mean_ms"]),
            **forward_memory,
        },
        "forward_backward": {
            **forward_backward_stats,
            "tokens_per_s": _tokens_per_second(tokens, forward_backward_stats["mean_ms"]),
            **forward_backward_memory,
        },
        "frequency_grid_cache_entries": len(model._frequency_grid_cache),
    }

    if args.profile:
        activities = [torch.profiler.ProfilerActivity.CPU]
        if device.type == "cuda":
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
            forward_backward()
        # Summarize only after the profiling context has exited.
        report["profiler_table"] = profiler_table(prof, args.row_limit)

    # default=str serializes non-JSON config values such as the --output Path.
    text = json.dumps(report, indent=2, sort_keys=True, default=str)
    print(text)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text(text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s status as the process exit code.
    sys.exit(main())