#!/usr/bin/env python3
"""Latency and peak-VRAM benchmark for the SRT-Adapter v8a.
Measures forward-pass latency and peak GPU memory for:
1. Backbone-only (Qwen2.5-7B, no adapter)
2. Backbone + SRT-Adapter (full forward including all 4 readouts)
Reports tokens/sec and peak GiB at sequence lengths 64, 256, 512, with batch=1.
Usage:
python scripts/benchmark_latency.py
python scripts/benchmark_latency.py --warmup 5 --iters 20
"""
from __future__ import annotations

import argparse
import json
import sys
import time
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str((ROOT / "src").resolve()))

from srt.adapter import SRTAdapter  # noqa: E402
from srt.config import (  # noqa: E402
    SRTConfig, MAHConfig, RRMConfig, BENConfig, CommunityConfig, LossConfig,
)


def build_config(p: Path) -> SRTConfig:
    raw = json.loads(p.read_text())
    return SRTConfig(
        backbone_id=raw["backbone_id"],
        backbone_dtype=raw["backbone_dtype"],
        mah_layer_indices=list(raw["mah_layer_indices"]),
        rrm_inject_indices=list(raw["rrm_inject_indices"]),
        community_layer_idx=raw["community_layer_idx"],
        num_mah_layers=raw["num_mah_layers"],
        mah=MAHConfig(**raw["mah"]),
        rrm=RRMConfig(**raw["rrm"]),
        ben=BENConfig(**raw["ben"]),
        community=CommunityConfig(**raw["community"]),
        loss=LossConfig(**{k: v for k, v in raw["loss"].items()
                           if k in LossConfig.__dataclass_fields__}),
    )


def time_forward(fn, warmup: int, iters: int) -> tuple[float, float]:
    """Return (mean_sec, peak_bytes) over `iters` after `warmup`."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / iters
    peak = torch.cuda.max_memory_allocated()
    return elapsed, peak


def fmt_gib(b: int) -> str:
    return f"{b / (1024 ** 3):.2f} GiB"


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default=str(ROOT / "config.json"))
    ap.add_argument("--adapter", default=str(ROOT / "adapter.safetensors"))
    ap.add_argument("--seq-lens", type=int, nargs="+", default=[64, 256, 512])
    ap.add_argument("--warmup", type=int, default=3)
    ap.add_argument("--iters", type=int, default=10)
    args = ap.parse_args()

    if not torch.cuda.is_available():
        print("CUDA required.", file=sys.stderr)
        sys.exit(1)
    device = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"torch={torch.__version__}, dtype=bfloat16\n")

    config = build_config(Path(args.config))
    tok = AutoTokenizer.from_pretrained(config.backbone_id)

    # Build a synthetic input pool we can slice
    long_text = "The quick brown fox jumps over the lazy dog. " * 200
    enc_full = tok(long_text, return_tensors="pt").to(device)

    rows = []
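    # Each variant is timed with gradient-free forward passes: `fn` closes over the
    # current input slices and is consumed immediately by time_forward, which resets
    # the peak-memory counter after warmup, so each reported peak covers the resident
    # weights plus the activations of the timed iterations.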
    # 1. Backbone-only baseline
    print("=== Backbone-only (Qwen2.5-7B, bfloat16) ===")
    backbone = AutoModelForCausalLM.from_pretrained(
        config.backbone_id, torch_dtype=torch.bfloat16
    ).to(device)
    backbone.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                backbone(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec*1000:7.2f} ms/fwd  {tps:8.1f} tok/s  peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_only", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})
    del backbone
    torch.cuda.empty_cache()
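    # SRTAdapter presumably instantiates its own frozen backbone from `config`, which
    # is why the baseline copy is freed above; the checkpoint is loaded with
    # strict=False since it is expected to hold only the adapter weights.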
    # 2. Backbone + adapter
    print("\n=== Backbone + SRT-Adapter v8a ===")
    model = SRTAdapter(config).to(device)
    if args.adapter.endswith(".safetensors"):
        from safetensors.torch import load_file
        state = load_file(args.adapter, device=device)
    else:
        state = torch.load(args.adapter, map_location=device)
    model.load_state_dict(state, strict=False)
    model.eval()
    for T in args.seq_lens:
        ids = enc_full.input_ids[:, :T]
        mask = enc_full.attention_mask[:, :T]
        torch.cuda.empty_cache()

        def fn():
            with torch.no_grad():
                model(input_ids=ids, attention_mask=mask)

        sec, peak = time_forward(fn, args.warmup, args.iters)
        tps = T / sec
        print(f"  T={T:4d}  {sec*1000:7.2f} ms/fwd  {tps:8.1f} tok/s  peak={fmt_gib(peak)}")
        rows.append({"variant": "backbone_plus_adapter", "seq_len": T,
                     "ms_per_forward": sec * 1000, "tokens_per_sec": tps,
                     "peak_vram_gib": peak / (1024 ** 3)})
    # Adapter overhead summary
    print("\n=== Adapter overhead ===")
    by = {(r["variant"], r["seq_len"]): r for r in rows}
    for T in args.seq_lens:
        b = by[("backbone_only", T)]
        a = by[("backbone_plus_adapter", T)]
        latency_overhead = (a["ms_per_forward"] / b["ms_per_forward"]) - 1.0
        vram_overhead = a["peak_vram_gib"] - b["peak_vram_gib"]
        print(f"  T={T:4d}  latency +{latency_overhead*100:5.1f}%  "
              f"vram +{vram_overhead:.2f} GiB")
    out = ROOT / "benchmarks" / "latency_vram.json"
    out.parent.mkdir(parents=True, exist_ok=True)  # create benchmarks/ if it is missing
    out.write_text(json.dumps({
        "gpu": gpu_name,
        "torch_version": torch.__version__,
        "warmup": args.warmup,
        "iters": args.iters,
        "rows": rows,
    }, indent=2))
    print(f"\nwrote {out}")


if __name__ == "__main__":
    main()