Molbap's picture
Molbap HF Staff
Upload folder using huggingface_hub
e199518 verified
Raw
History Blame Contribute Delete
14.3 kB
"""Benchmark resize+normalize: separable / fused triton vs torchvision vs the real processor.
PYTHONPATH=../torch-ext python benchmark.py --processor google/siglip-so400m-patch14-384
PYTHONPATH=../torch-ext python benchmark.py --n 32 --out 384 384 --interp bicubic --antialias
Prints parity (vs torchvision-float) per backend, then ms/iter for each path. Needs CUDA.
"""
import argparse
import sys
import time
# Hide `kernels` from transformers: this worktree builds kernels.LayerRepository without a version,
# which newer `kernels` rejects at import. Preprocessing needs no hub layer kernels.
# A None entry in sys.modules makes every later `import kernels` raise ImportError, so this
# assignment has to stay above any import that may pull in transformers (in this file
# transformers itself is only imported lazily, inside the functions that need it).
sys.modules["kernels"] = None
import torch
import torchvision.transforms.v2.functional as tvF
from torchvision.io import ImageReadMode, decode_jpeg, encode_jpeg
from torchvision.transforms import InterpolationMode
from kernel_image_resize import resize_normalize
from kernel_image_resize._pack import PIL_RESAMPLE_TO_INTERP, max_taps
# Maps the benchmark's interpolation names (the `--interp` choices) to torchvision's
# InterpolationMode; looked up by the eager torchvision reference path.
_TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}
def make_ragged_images(n, device, min_res, max_res, seed=0):
    """Generate `n` random uint8 images of shape (3, H, W) with independently drawn H and W.

    Heights and widths are sampled from [min_res, max_res] (both inclusive) with a CPU
    generator seeded by `seed`, so identical arguments always produce the same batch.
    Each image is created on the CPU and then moved to `device`.

    Returns:
        A list of `n` tensors; sizes generally differ from image to image.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)

    def _draw_side():
        # randint's upper bound is exclusive, hence the +1 to make max_res reachable.
        return int(torch.randint(min_res, max_res + 1, (1,), generator=rng).item())

    batch = []
    for _ in range(n):
        # Draw order (height, width, pixels) determines the random stream for a given seed.
        height = _draw_side()
        width = _draw_side()
        pixels = torch.randint(0, 256, (3, height, width), generator=rng, dtype=torch.uint8)
        batch.append(pixels.to(device))
    return batch
def torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias):
    """Eager float reference: resize each image with torchvision, then rescale and normalize.

    Every image is cast to float, resized to (out_h, out_w) using the torchvision mode
    registered for `interp`, and mapped through `(x * rescale - mean) / std`, with `mean`
    and `std` viewed as (3, 1, 1) on the device of the first image.

    Returns:
        The per-image results stacked into a single tensor.
    """
    tv_mode = _TV_INTERP[interp]
    device = images[0].device
    offset = torch.tensor(mean, device=device).view(3, 1, 1)
    scale = torch.tensor(std, device=device).view(3, 1, 1)
    target = [out_h, out_w]

    def _process(img):
        resized = tvF.resize(img.float(), target, interpolation=tv_mode, antialias=antialias)
        return (resized * rescale - offset) / scale

    return torch.stack([_process(img) for img in images])
def build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
    """Build a per-image resize+normalize runner backed by `torch.compile(dynamic=True)`.

    The compiled function takes one CHW image, resizes it to (out_h, out_w) with
    `F.interpolate` and applies `(x * rescale - mean) / std`. Any `interp` other than
    "bicubic" is run as bilinear.

    Returns:
        A callable that takes a list of images, runs the compiled function on each one,
        and stacks the results.
    """
    import torch.nn.functional as F

    offset = torch.tensor(mean, device=device).view(3, 1, 1)
    scale = torch.tensor(std, device=device).view(3, 1, 1)
    mode = "bilinear" if interp != "bicubic" else "bicubic"

    def _resize_normalize_one(img):
        # F.interpolate wants a batch dimension; add it for the call and drop it afterwards.
        batched = img.unsqueeze(0).float()
        resized = F.interpolate(
            batched, size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False
        )
        return (resized.squeeze(0) * rescale - offset) / scale

    compiled_one = torch.compile(_resize_normalize_one, dynamic=True)

    def run(images):
        per_image = []
        for img in images:
            per_image.append(compiled_one(img))
        return torch.stack(per_image)

    return run
def pad_stack(images):
    """Zero-pad ragged CHW images to the largest height/width in the batch and stack them.

    Each image is copied into the top-left corner of its slot; the remainder stays zero.
    The result has shape (N, C, Hmax, Wmax) and the dtype and device of the first image.
    """
    first = images[0]
    heights = [img.shape[1] for img in images]
    widths = [img.shape[2] for img in images]
    stacked = first.new_zeros((len(images), first.shape[0], max(heights), max(widths)))
    for idx, (img, h, w) in enumerate(zip(images, heights, widths)):
        stacked[idx, :, :h, :w] = img
    return stacked
def build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
    """Build a `torch.compile`d batched resize+normalize over one stacked (N, C, H, W) tensor.

    The returned callable casts the stack to float, resizes it to (out_h, out_w) with
    `F.interpolate` and applies `(x * rescale - mean) / std`, with `mean`/`std` viewed as
    (1, 3, 1, 1). Any `interp` other than "bicubic" is run as bilinear.
    """
    import torch.nn.functional as F

    offset = torch.tensor(mean, device=device).view(1, 3, 1, 1)
    scale = torch.tensor(std, device=device).view(1, 3, 1, 1)
    mode = "bilinear" if interp != "bicubic" else "bicubic"

    def _resize_normalize_batch(stacked):
        resized = F.interpolate(
            stacked.float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False
        )
        return (resized * rescale - offset) / scale

    compiled_batch = torch.compile(_resize_normalize_batch, dynamic=True)
    return compiled_batch
def run_inference(model_id, images, block, iters, device):
    """End-to-end comparison: preprocess (processor / separable / fused / compiled) -> vision features.

    Loads the processor settings and the model for `model_id`, runs the vision tower in
    bfloat16, prints how far each alternative preprocessing path moves the features away
    from the processor baseline, then prints ms/iter for preprocessing alone, the forward
    pass alone and the two combined.
    """
    from transformers import AutoModel

    proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(model_id)
    model = AutoModel.from_pretrained(model_id).to(device=device, dtype=torch.bfloat16).eval()
    vision = model.vision_model
    kernel_kwargs = {
        "size": (out_h, out_w),
        "image_mean": mean,
        "image_std": std,
        "rescale_factor": rescale,
        "resample": interp,
        "antialias": antialias,
        "block": block,
    }

    def features(pixel_values):
        with torch.no_grad():
            out = vision(pixel_values=pixel_values.to(model.dtype))
        pooled = getattr(out, "pooler_output", None)
        if pooled is None:
            return out.last_hidden_state
        return pooled

    compiled_one = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)

    def via_processor():
        return proc(images, return_tensors="pt", device=device)["pixel_values"]

    def via_separable():
        return resize_normalize(images, backend="separable", **kernel_kwargs)

    def via_fused():
        return resize_normalize(images, backend="fused", **kernel_kwargs)

    def via_compiled():
        return compiled_one(images)

    # Insertion order is the row order of the timing table below.
    methods = {
        "processor": via_processor,
        "separable": via_separable,
        "fused": via_fused,
        "compiled": via_compiled,
    }

    # Two untimed calls to warm up the compiled artifact.
    for _ in range(2):
        methods["compiled"]()
    torch.cuda.synchronize()

    print(f"\n[infer] {model_id} out={out_h}x{out_w} forward dtype=bfloat16")
    base = features(methods["processor"]())
    base_scale = base.abs().max().item()
    for name in ("separable", "fused", "compiled"):
        candidate = features(methods[name]())
        d = (candidate - base).abs().max().item()
        print(f"[infer parity] features {name} vs processor: max|Δ| = {d:.2e} ({d / base_scale:.1%} of feature max)")

    # forward is timed on a FIXED precomputed tensor, so it is method-independent by construction;
    # if it varies across rows, the preprocessor's output (dtype/contiguity) is hurting the model.
    print("[infer] ms/iter: preprocess forward(fixed input) preprocess+forward")
    for name, preprocess in methods.items():
        pixel_values = preprocess()

        def forward_only(pixel_values=pixel_values):
            return features(pixel_values)

        def end_to_end(preprocess=preprocess):
            return features(preprocess())

        pre = _time(preprocess, iters, device)
        fwd = _time(forward_only, iters, device)
        e2e = _time(end_to_end, iters, device)
        print(f" {name:10s} {pre:8.3f} {fwd:8.3f} {e2e:8.3f}")
def run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, block, iters, device):
    """Print a data-path timing table that starts from JPEG bytes.

    The CPU images are encoded to JPEG (quality 95) once, then four pipelines are timed:
    per-image `decode_jpeg` followed by a move to `device`, or one batched `decode_jpeg`
    call with `device=device`, each followed by either the torchvision reference or the
    separable kernel. Only wall-clock time is reported, no parity, because the decoders
    differ at the pixel level (nvJPEG vs libjpeg).
    """
    encoded = [encode_jpeg(img, quality=95) for img in images_cpu]
    total_bytes = sum(buf.numel() for buf in encoded)
    avg_kb = total_bytes / len(encoded) / 1024
    kernel_kwargs = {
        "size": (out_h, out_w),
        "image_mean": mean,
        "image_std": std,
        "rescale_factor": rescale,
        "resample": interp,
        "antialias": antialias,
        "block": block,
    }

    def _decode_per_image():
        return [decode_jpeg(buf, mode=ImageReadMode.RGB).to(device) for buf in encoded]

    def _decode_batched_on_device():
        return decode_jpeg(encoded, mode=ImageReadMode.RGB, device=device)

    def _kernel(imgs):
        return resize_normalize(imgs, backend="separable", **kernel_kwargs)

    def _torchvision(imgs):
        return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias)

    def cpu_decode_torchvision():
        return _torchvision(_decode_per_image())

    def cpu_decode_kernel():
        return _kernel(_decode_per_image())

    def gpu_decode_torchvision():
        return _torchvision(_decode_batched_on_device())

    def gpu_decode_kernel():
        return _kernel(_decode_batched_on_device())

    print(f"\n[decode] N={len(encoded)} avg={avg_kb:.0f} KB/img out={out_h}x{out_w} (from JPEG bytes, ms/iter)")
    # Each row is measured right before it is printed, in this fixed order.
    elapsed = _time(cpu_decode_torchvision, iters, device)
    print(f" CPU decode + torchvision resize : {elapsed:8.3f} [status quo data path]")
    elapsed = _time(cpu_decode_kernel, iters, device)
    print(f" CPU decode + separable kernel : {elapsed:8.3f}")
    elapsed = _time(gpu_decode_torchvision, iters, device)
    print(f" GPU decode (nvJPEG) + tv resize : {elapsed:8.3f} [GPU pipeline, tv resize]")
    elapsed = _time(gpu_decode_kernel, iters, device)
    print(f" GPU decode (nvJPEG) + kernel : {elapsed:8.3f} [GPU pipeline, kernel resize]")
def load_processor_config(name):
    """Load an image processor and extract the fixed resize/normalize settings it uses.

    Args:
        name: model id or path handed to `AutoImageProcessor.from_pretrained`.

    Returns:
        `(proc, (out_h, out_w, mean, std, rescale, interp, antialias))`. `proc` is the
        loaded processor, `mean`/`std` are lists, `rescale` is 1.0 when the processor has
        `do_rescale` disabled, `interp` is the `PIL_RESAMPLE_TO_INTERP` entry for the
        processor's resample filter, and `antialias` defaults to True when the processor
        has no such attribute.

    Raises:
        ValueError: if the processor's size is not a fixed height/width pair, or if its
            resample filter has no entry in `PIL_RESAMPLE_TO_INTERP`.
    """
    from transformers import AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(name, backend="torchvision")
    size = proc.size
    if "height" not in size or "width" not in size:
        raise ValueError(f"{name}: size={size} is not a fixed (height, width)")
    out_h, out_w = size["height"], size["width"]

    resample = int(proc.resample)
    interp = PIL_RESAMPLE_TO_INTERP.get(resample)
    if interp is None:
        # Fail here with a clear message: a None interpolation would otherwise only show up
        # later as an unrelated KeyError in the torchvision reference, or be silently
        # benchmarked as bilinear by the compiled reference paths.
        raise ValueError(
            f"{name}: resample={resample} has no supported interpolation "
            f"(known PIL resample codes: {list(PIL_RESAMPLE_TO_INTERP)})"
        )

    rescale = float(proc.rescale_factor) if getattr(proc, "do_rescale", True) else 1.0
    antialias = bool(getattr(proc, "antialias", True))
    return proc, (out_h, out_w, list(proc.image_mean), list(proc.image_std), rescale, interp, antialias)
def _time(fn, iters, device):
    """Return the mean time of `fn()` in milliseconds over `iters` calls.

    `fn` is first called three times untimed as warmup. On a CUDA device the timed loop is
    bracketed by CUDA events (with a synchronize before and after); on any other device it
    is measured with `time.perf_counter`.
    """
    warmup_calls = 3
    for _ in range(warmup_calls):
        fn()

    if device.type != "cuda":
        began = time.perf_counter()
        for _ in range(iters):
            fn()
        elapsed_s = time.perf_counter() - began
        return elapsed_s / iters * 1e3

    # Drain the warmup work so it is not attributed to the timed region.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # elapsed_time is read only after the synchronize that follows end.record().
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
def _build_arg_parser():
    """Create the command-line parser for the benchmark script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--processor", default=None)
    parser.add_argument("--n", type=int, default=32)
    parser.add_argument("--out", type=int, nargs=2, default=[384, 384], metavar=("H", "W"))
    parser.add_argument("--interp", choices=["bilinear", "bicubic"], default="bicubic")
    parser.add_argument("--antialias", action="store_true")
    parser.add_argument("--min-res", type=int, default=384)
    parser.add_argument("--max-res", type=int, default=1024)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--block", type=int, default=256)
    parser.add_argument("--tol", type=float, default=3e-3)
    parser.add_argument(
        "--infer",
        action="store_true",
        help="end-to-end Siglip2 inference comparison (bf16 forward)",
    )
    parser.add_argument(
        "--model",
        default="google/siglip2-base-patch16-224",
        help="model for --infer",
    )
    parser.add_argument(
        "--decode",
        action="store_true",
        help="JPEG-decode data-path table (CPU vs GPU/nvJPEG) and stop",
    )
    return parser


def main():
    """Command-line entry point for the resize+normalize benchmark.

    Takes the resize/normalize settings from `--processor` when given, otherwise from the
    command-line flags with fixed mean/std/rescale defaults. With `--decode` it prints only
    the JPEG data-path table and stops. Otherwise it prints parity of the fused and
    separable backends against the torchvision float reference, then ms/iter for every
    path, optionally followed by the real processor timing and the `--infer` comparison.
    Prints a message and returns without benchmarking when CUDA is unavailable.
    """
    args = _build_arg_parser().parse_args()

    if not torch.cuda.is_available():
        print("benchmark needs CUDA.")
        return
    device = torch.device("cuda")

    proc = None
    if not args.processor:
        out_h, out_w = args.out
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        rescale = 1.0 / 255.0
        interp, antialias = args.interp, args.antialias
    else:
        proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(args.processor)
        print(f"processor={args.processor} -> out={out_h}x{out_w} interp={interp} antialias={antialias}")

    images = make_ragged_images(args.n, device, args.min_res, args.max_res)
    taps_h = max_taps(images, out_h, 1, interp, antialias)
    taps_w = max_taps(images, out_w, 2, interp, antialias)
    taps = (taps_h, taps_w)
    print(f"N={args.n} in∈[{args.min_res},{args.max_res}]² ragged out={out_h}x{out_w} "
          f"interp={interp} antialias={antialias} max_taps={taps} iters={args.iters}\n")

    if args.decode:
        images_cpu = make_ragged_images(args.n, torch.device("cpu"), args.min_res, args.max_res)
        run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, args.block, args.iters, device)
        return

    def eager_reference():
        return torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias)

    def bench(fn):
        return _time(fn, args.iters, device)

    ref = eager_reference()
    common = {
        "size": (out_h, out_w),
        "image_mean": mean,
        "image_std": std,
        "rescale_factor": rescale,
        "resample": interp,
        "antialias": antialias,
        "block": args.block,
    }
    for backend in ("fused", "separable"):
        got = resize_normalize(images, backend=backend, **common)
        d = (got - ref).abs().max().item()
        print(f"[parity] {backend:9s} vs torchvision(float): max|Δ| = {d:.2e} "
              f"({'PASS' if d < args.tol else 'FAIL'} @ tol={args.tol})")
    print()

    compiled_run = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
    packed = pad_stack(images)
    packed_compiled_run = build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)

    # Trigger compilation of both compiled paths (two calls each) and record how long it took.
    warmup_began = time.perf_counter()
    for _ in range(2):
        compiled_run(images)
    for _ in range(2):
        packed_compiled_run(packed)
    torch.cuda.synchronize()
    t_warmup = (time.perf_counter() - warmup_began) * 1e3

    # Measurement order is fixed: eager, compiled, packed compiled, fused, separable.
    t_eager = bench(eager_reference)
    t_comp = bench(lambda: compiled_run(images))
    t_comp_packed = bench(lambda: packed_compiled_run(packed))
    t_fused = bench(lambda: resize_normalize(images, backend="fused", **common))
    t_sep = bench(lambda: resize_normalize(images, backend="separable", **common))

    print("Resize+normalize only (no decode/H2D), ms/iter:")
    print(f" torchvision eager loop : {t_eager:8.3f} [per-image float loop]")
    print(f" torchvision compiled : {t_comp:8.3f} [torch.compile dynamic per-image; warmup {t_warmup:.0f} ms excluded]")
    print(f" torchvision compiled pkt: {t_comp_packed:8.3f} [one graph over padded (N,C,Hmax,Wmax) stack; timing only, padding alters output]")
    print(f" fused triton (2D) : {t_fused:8.3f} [taps*taps]")
    print(f" separable triton (uint8): {t_sep:8.3f} [taps+taps]")

    if proc is not None:
        t_pr = bench(lambda: proc(images, return_tensors="pt", device=device)["pixel_values"])
        print(f"\n {args.processor} : {t_pr:8.3f} ms/iter")
        print(f" -> separable is {t_sep / t_pr:.2f}x the real processor")

    if args.infer:
        run_inference(args.model, images, args.block, args.iters, device)


if __name__ == "__main__":
    main()