"""Benchmark resize+normalize: separable / fused Triton kernels vs torchvision vs the real processor.

Usage:
    PYTHONPATH=../torch-ext python benchmark.py --processor google/siglip-so400m-patch14-384
    PYTHONPATH=../torch-ext python benchmark.py --n 32 --out 384 384 --interp bicubic --antialias
    PYTHONPATH=../torch-ext python benchmark.py --decode   # JPEG decode data-path table, then stop
    PYTHONPATH=../torch-ext python benchmark.py --infer    # also run the end-to-end model comparison

By default, prints parity against the torchvision float reference for each kernel backend,
then ms/iter for each resize+normalize path. Needs CUDA.
"""
|
|
| import argparse |
| import sys |
| import time |
|
|
# A `None` entry in sys.modules makes any later `import kernels` raise
# ModuleNotFoundError. It must run before `kernel_image_resize` is imported below.
# NOTE(review): presumably this stops kernel_image_resize from resolving through an
# installed `kernels` package, so the local ../torch-ext build is benchmarked - confirm.
sys.modules["kernels"] = None
|
|
| import torch |
| import torchvision.transforms.v2.functional as tvF |
| from torchvision.io import ImageReadMode, decode_jpeg, encode_jpeg |
| from torchvision.transforms import InterpolationMode |
|
|
| from kernel_image_resize import resize_normalize |
| from kernel_image_resize._pack import PIL_RESAMPLE_TO_INTERP, max_taps |
|
|
|
|
# Maps each supported interpolation name ("bilinear" / "bicubic") to torchvision's enum.
_TV_INTERP = dict(bilinear=InterpolationMode.BILINEAR, bicubic=InterpolationMode.BICUBIC)
|
|
|
|
def make_ragged_images(n, device, min_res, max_res, seed=0):
    """Return ``n`` random uint8 images of shape (3, h, w) on ``device``.

    ``h`` and ``w`` are each drawn uniformly from ``[min_res, max_res]``. All randomness
    comes from a CPU generator seeded with ``seed``, so the pixels are identical whatever
    ``device`` is.
    """
    gen = torch.Generator(device="cpu").manual_seed(seed)

    def _draw_side():
        return int(torch.randint(min_res, max_res + 1, (1,), generator=gen).item())

    def _draw_image():
        # Draw order (height, width, pixels) matters for reproducibility.
        height = _draw_side()
        width = _draw_side()
        pixels = torch.randint(0, 256, (3, height, width), generator=gen, dtype=torch.uint8)
        return pixels.to(device)

    return [_draw_image() for _ in range(n)]
|
|
|
|
def torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias):
    """Float reference: resize each image with torchvision, then normalize.

    Each image is cast to float before ``tvF.resize`` to ``[out_h, out_w]``. It is then
    mapped to ``(x * rescale - mean) / std``, with per-channel ``mean``/``std`` viewed as
    (3, 1, 1). The results are stacked into one (N, 3, out_h, out_w) tensor for
    3-channel inputs.
    """
    mode = _TV_INTERP[interp]
    dev = images[0].device
    mean_t, std_t = (torch.tensor(values, device=dev).view(3, 1, 1) for values in (mean, std))

    def _resize_normalize(image):
        resized = tvF.resize(image.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
        return (resized * rescale - mean_t) / std_t

    return torch.stack([_resize_normalize(image) for image in images])
|
|
|
|
def build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
    """Return ``run(images)``: a torch.compile(dynamic=True) per-image float resize+normalize.

    Each image is resized with ``F.interpolate`` (align_corners=False), normalized as
    ``(x * rescale - mean) / std``, and the results are stacked. Any ``interp`` other than
    "bicubic" is treated as "bilinear".
    """
    import torch.nn.functional as F

    shift = torch.tensor(mean, device=device).view(3, 1, 1)
    scale = torch.tensor(std, device=device).view(3, 1, 1)
    if interp == "bicubic":
        mode = "bicubic"
    else:
        mode = "bilinear"

    def _resize_normalize_one(img):
        batched = img.unsqueeze(0).float()
        resized = F.interpolate(batched, size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
        return (resized.squeeze(0) * rescale - shift) / scale

    # dynamic=True so ragged input sizes do not trigger a recompile per new shape.
    compiled_one = torch.compile(_resize_normalize_one, dynamic=True)

    def run(images):
        return torch.stack([compiled_one(image) for image in images])

    return run
|
|
|
|
def pad_stack(images):
    """Zero-pad ragged CHW images to the batch-max H/W and stack them.

    Returns an (N, C, Hmax, Wmax) tensor with the dtype/device of ``images[0]``. Each
    image sits in the top-left corner of its slot; the rest of the slot is zero.
    """
    first = images[0]
    heights = [im.shape[1] for im in images]
    widths = [im.shape[2] for im in images]
    canvas = first.new_zeros((len(images), first.shape[0], max(heights), max(widths)))
    for idx, (im, h, w) in enumerate(zip(images, heights, widths)):
        canvas[idx, :, :h, :w] = im
    return canvas
|
|
|
|
def build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
    """Return a torch.compile(dynamic=True) function that resizes+normalizes a stacked batch.

    The returned callable takes one (N, C, H, W) tensor, resizes it with a single
    ``F.interpolate`` call (align_corners=False) and applies ``(x * rescale - mean) / std``
    with mean/std viewed as (1, 3, 1, 1). Any ``interp`` other than "bicubic" is treated
    as "bilinear".
    """
    import torch.nn.functional as F

    shift = torch.tensor(mean, device=device).view(1, 3, 1, 1)
    scale = torch.tensor(std, device=device).view(1, 3, 1, 1)
    if interp == "bicubic":
        mode = "bicubic"
    else:
        mode = "bilinear"

    def _resize_normalize_batch(stacked):
        resized = F.interpolate(stacked.float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
        return (resized * rescale - shift) / scale

    return torch.compile(_resize_normalize_batch, dynamic=True)
|
|
|
|
def run_inference(model_id, images, block, iters, device):
    """End-to-end comparison: preprocess with each method, then extract vision features.

    The preprocessing methods are:
    - the HF processor;
    - the "separable" and "fused" kernel backends;
    - the torch.compile'd per-image reference.

    Features come from ``model.vision_model`` run in bfloat16; the pooled output is used
    when the model provides one, otherwise the last hidden state. The function prints each
    method's max feature difference against the processor. It then prints ms/iter for
    three things: preprocessing alone, the forward pass on a fixed input, and both together.
    """
    from transformers import AutoModel

    proc, cfg = load_processor_config(model_id)
    out_h, out_w, mean, std, rescale, interp, antialias = cfg
    model = AutoModel.from_pretrained(model_id).to(device=device, dtype=torch.bfloat16).eval()
    vision = model.vision_model
    kernel_kwargs = {
        "size": (out_h, out_w),
        "image_mean": mean,
        "image_std": std,
        "rescale_factor": rescale,
        "resample": interp,
        "antialias": antialias,
        "block": block,
    }

    @torch.no_grad()
    def features(pixel_values):
        outputs = vision(pixel_values=pixel_values.to(model.dtype))
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            return outputs.last_hidden_state
        return pooled

    compiled_one = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)

    def _kernel(backend):
        return lambda: resize_normalize(images, backend=backend, **kernel_kwargs)

    methods = {
        "processor": lambda: proc(images, return_tensors="pt", device=device)["pixel_values"],
        "separable": _kernel("separable"),
        "fused": _kernel("fused"),
        "compiled": lambda: compiled_one(images),
    }
    # Run the compiled path twice so torch.compile's tracing/compilation is not measured.
    for _ in range(2):
        methods["compiled"]()
    torch.cuda.synchronize()

    print(f"\n[infer] {model_id} out={out_h}x{out_w} forward dtype=bfloat16")
    reference = features(methods["processor"]())
    ref_scale = reference.abs().max().item()
    for name in ("separable", "fused", "compiled"):
        d = (features(methods[name]()) - reference).abs().max().item()
        print(f"[infer parity] features {name} vs processor: max|Δ| = {d:.2e} ({d / ref_scale:.1%} of feature max)")

    print("[infer] ms/iter: preprocess forward(fixed input) preprocess+forward")
    for name in methods:
        preprocess = methods[name]
        fixed_input = preprocess()
        pre = _time(preprocess, iters, device)
        # Default args bind the current loop values (avoids late-binding closures).
        fwd = _time(lambda x=fixed_input: features(x), iters, device)
        e2e = _time(lambda p=preprocess: features(p()), iters, device)
        print(f" {name:10s} {pre:8.3f} {fwd:8.3f} {e2e:8.3f}")
|
|
|
|
def run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, block, iters, device):
    """Print a data-path timing table that starts from JPEG bytes.

    Every image in ``images_cpu`` is JPEG-encoded once (quality 95). Each timed row then
    decodes all payloads in one of two ways:
    - on the CPU with ``decode_jpeg``, followed by a per-image copy to ``device``;
    - as one batched ``decode_jpeg(..., device=device)`` call (nvJPEG).

    The decoded images are then resized and normalized with either the torchvision float
    reference or the separable kernel. The two decoders differ at the pixel level (nvJPEG
    vs libjpeg), so the table reports wall-clock only, not parity.
    """
    payloads = [encode_jpeg(image, quality=95) for image in images_cpu]
    mean_kb = sum(p.numel() for p in payloads) / len(payloads) / 1024
    kernel_kwargs = {
        "size": (out_h, out_w),
        "image_mean": mean,
        "image_std": std,
        "rescale_factor": rescale,
        "resample": interp,
        "antialias": antialias,
        "block": block,
    }

    def _cpu_decode():
        return [decode_jpeg(p, mode=ImageReadMode.RGB).to(device) for p in payloads]

    def _gpu_decode():
        return decode_jpeg(payloads, mode=ImageReadMode.RGB, device=device)

    def _kernel_resize(decoded):
        return resize_normalize(decoded, backend="separable", **kernel_kwargs)

    def _tv_resize(decoded):
        return torchvision_reference(decoded, out_h, out_w, mean, std, rescale, interp, antialias)

    def _pipeline(decode, resize):
        return lambda: resize(decode())

    print(f"\n[decode] N={len(payloads)} avg={mean_kb:.0f} KB/img out={out_h}x{out_w} (from JPEG bytes, ms/iter)")
    t = _time(_pipeline(_cpu_decode, _tv_resize), iters, device)
    print(f" CPU decode + torchvision resize : {t:8.3f} [status quo data path]")
    t = _time(_pipeline(_cpu_decode, _kernel_resize), iters, device)
    print(f" CPU decode + separable kernel : {t:8.3f}")
    t = _time(_pipeline(_gpu_decode, _tv_resize), iters, device)
    print(f" GPU decode (nvJPEG) + tv resize : {t:8.3f} [GPU pipeline, tv resize]")
    t = _time(_pipeline(_gpu_decode, _kernel_resize), iters, device)
    print(f" GPU decode (nvJPEG) + kernel : {t:8.3f} [GPU pipeline, kernel resize]")
|
|
|
|
def load_processor_config(name):
    """Load a HF image processor and extract its fixed-size resize+normalize settings.

    Args:
        name: model id or path passed to ``AutoImageProcessor.from_pretrained``.

    Returns:
        ``(proc, (out_h, out_w, mean, std, rescale, interp, antialias))``. ``rescale`` is
        1.0 when the processor has ``do_rescale`` False. ``mean``/``std`` are the identity
        (0 / 1 per channel) when it has ``do_normalize`` False, so the benchmarked kernels
        apply the same transform as the processor.

    Raises:
        ValueError: if ``proc.size`` is not a fixed (height, width), or if ``proc.resample``
            has no entry in ``PIL_RESAMPLE_TO_INTERP``.
    """
    from transformers import AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(name, backend="torchvision")
    size = proc.size
    if "height" not in size or "width" not in size:
        raise ValueError(f"{name}: size={size} is not a fixed (height, width)")
    out_h, out_w = size["height"], size["width"]

    interp = PIL_RESAMPLE_TO_INTERP.get(int(proc.resample))
    if interp is None:
        # Passing None downstream would KeyError in the torchvision reference, silently
        # become "bilinear" in the compiled references, and be handed to the kernel as
        # resample=None, so parity numbers would be meaningless.
        raise ValueError(f"{name}: resample={proc.resample!r} has no supported interpolation mode")

    rescale = float(proc.rescale_factor) if getattr(proc, "do_rescale", True) else 1.0
    if getattr(proc, "do_normalize", True):
        mean, std = list(proc.image_mean), list(proc.image_std)
    else:
        # The processor skips normalization; identity mean/std reproduces that.
        mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
    antialias = bool(getattr(proc, "antialias", True))
    return proc, (out_h, out_w, mean, std, rescale, interp, antialias)
|
|
|
|
| def _time(fn, iters, device): |
| for _ in range(3): |
| fn() |
| if device.type == "cuda": |
| torch.cuda.synchronize() |
| start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) |
| start.record() |
| for _ in range(iters): |
| fn() |
| end.record() |
| torch.cuda.synchronize() |
| return start.elapsed_time(end) / iters |
| t0 = time.perf_counter() |
| for _ in range(iters): |
| fn() |
| return (time.perf_counter() - t0) / iters * 1e3 |
|
|
|
|
def main():
    """Parse CLI flags, check each kernel backend's parity, and print timing tables.

    The default run prints parity against the torchvision float reference, then
    resize+normalize ms/iter for every path. ``--decode`` prints the JPEG data-path table
    instead and stops. ``--infer`` additionally runs the end-to-end model comparison.
    Invalid sizes/counts are rejected with a usage error. Without CUDA, a message is
    printed and the function returns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--processor", default=None)
    parser.add_argument("--n", type=int, default=32)
    parser.add_argument("--out", type=int, nargs=2, default=[384, 384], metavar=("H", "W"))
    parser.add_argument("--interp", choices=["bilinear", "bicubic"], default="bicubic")
    parser.add_argument("--antialias", action="store_true")
    parser.add_argument("--min-res", type=int, default=384)
    parser.add_argument("--max-res", type=int, default=1024)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--block", type=int, default=256)
    parser.add_argument("--tol", type=float, default=3e-3)
    parser.add_argument("--infer", action="store_true", help="end-to-end Siglip2 inference comparison (bf16 forward)")
    parser.add_argument("--model", default="google/siglip2-base-patch16-224", help="model for --infer")
    parser.add_argument("--decode", action="store_true", help="JPEG-decode data-path table (CPU vs GPU/nvJPEG) and stop")
    args = parser.parse_args()

    # Fail fast with a usage message instead of deep inside torch or the timing code:
    # - torch.randint needs min_res <= max_res;
    # - the references index images[0];
    # - _time divides by iters.
    if args.n < 1:
        parser.error("--n must be >= 1")
    if args.iters < 1:
        parser.error("--iters must be >= 1")
    if not 1 <= args.min_res <= args.max_res:
        parser.error("need 1 <= --min-res <= --max-res")
    if min(args.out) < 1:
        parser.error("--out H W must both be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        print("benchmark needs CUDA.")
        return

    # Resize/normalize settings come from a real HF processor, or from the CLI with fixed
    # default mean/std.
    proc = None
    if args.processor:
        proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(args.processor)
        print(f"processor={args.processor} -> out={out_h}x{out_w} interp={interp} antialias={antialias}")
    else:
        out_h, out_w = args.out
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        rescale = 1.0 / 255.0
        interp, antialias = args.interp, args.antialias

    images = make_ragged_images(args.n, device, args.min_res, args.max_res)
    taps = (max_taps(images, out_h, 1, interp, antialias), max_taps(images, out_w, 2, interp, antialias))
    print(f"N={args.n} in∈[{args.min_res},{args.max_res}]² ragged out={out_h}x{out_w} "
          f"interp={interp} antialias={antialias} max_taps={taps} iters={args.iters}\n")

    if args.decode:
        # Same default seed as `images`, so these are the same pixels, kept on the CPU
        # for JPEG encoding.
        images_cpu = make_ragged_images(args.n, torch.device("cpu"), args.min_res, args.max_res)
        run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, args.block, args.iters, device)
        return

    # Parity of each kernel backend against the float torchvision reference.
    ref = torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias)
    common = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
                  resample=interp, antialias=antialias, block=args.block)
    for backend in ("fused", "separable"):
        got = resize_normalize(images, backend=backend, **common)
        d = (got - ref).abs().max().item()
        print(f"[parity] {backend:9s} vs torchvision(float): max|Δ| = {d:.2e} "
              f"({'PASS' if d < args.tol else 'FAIL'} @ tol={args.tol})")
    print()

    # Build both torch.compile references and pay their compilation cost up front.
    # The warm-up time is reported separately instead of being folded into ms/iter.
    compiled_run = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
    packed = pad_stack(images)
    packed_compiled_run = build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
    t0 = time.perf_counter()
    compiled_run(images)
    compiled_run(images)
    packed_compiled_run(packed)
    packed_compiled_run(packed)
    torch.cuda.synchronize()
    t_warmup = (time.perf_counter() - t0) * 1e3

    t_eager = _time(lambda: torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias), args.iters, device)
    t_comp = _time(lambda: compiled_run(images), args.iters, device)
    t_comp_packed = _time(lambda: packed_compiled_run(packed), args.iters, device)
    t_fused = _time(lambda: resize_normalize(images, backend="fused", **common), args.iters, device)
    t_sep = _time(lambda: resize_normalize(images, backend="separable", **common), args.iters, device)
    print("Resize+normalize only (no decode/H2D), ms/iter:")
    print(f" torchvision eager loop : {t_eager:8.3f} [per-image float loop]")
    print(f" torchvision compiled : {t_comp:8.3f} [torch.compile dynamic per-image; warmup {t_warmup:.0f} ms excluded]")
    print(f" torchvision compiled pkt: {t_comp_packed:8.3f} [one graph over padded (N,C,Hmax,Wmax) stack; timing only, padding alters output]")
    print(f" fused triton (2D) : {t_fused:8.3f} [taps*taps]")
    print(f" separable triton (uint8): {t_sep:8.3f} [taps+taps]")

    if proc is not None:
        t_pr = _time(lambda: proc(images, return_tensors="pt", device=device)["pixel_values"], args.iters, device)
        print(f"\n {args.processor} : {t_pr:8.3f} ms/iter")
        print(f" -> separable is {t_sep / t_pr:.2f}x the real processor")

    if args.infer:
        run_inference(args.model, images, args.block, args.iters, device)
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|