"""Benchmark resize+normalize: separable / fused triton vs torchvision vs the real processor. PYTHONPATH=../torch-ext python benchmark.py --processor google/siglip-so400m-patch14-384 PYTHONPATH=../torch-ext python benchmark.py --n 32 --out 384 384 --interp bicubic --antialias Prints parity (vs torchvision-float) per backend, then ms/iter for each path. Needs CUDA. """ import argparse import sys import time # Hide `kernels` from transformers: this worktree builds kernels.LayerRepository without a version, # which newer `kernels` rejects at import. Preprocessing needs no hub layer kernels. sys.modules["kernels"] = None import torch import torchvision.transforms.v2.functional as tvF from torchvision.io import ImageReadMode, decode_jpeg, encode_jpeg from torchvision.transforms import InterpolationMode from kernel_image_resize import resize_normalize from kernel_image_resize._pack import PIL_RESAMPLE_TO_INTERP, max_taps _TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC} def make_ragged_images(n, device, min_res, max_res, seed=0): g = torch.Generator(device="cpu").manual_seed(seed) images = [] for _ in range(n): h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item()) w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item()) images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device)) return images def torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias): mode = _TV_INTERP[interp] mean_t = torch.tensor(mean, device=images[0].device).view(3, 1, 1) std_t = torch.tensor(std, device=images[0].device).view(3, 1, 1) outs = [] for img in images: r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias) outs.append((r * rescale - mean_t) / std_t) return torch.stack(outs) def build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device): """torch.compile(dynamic=True) of a per-image float resize+normalize.""" import torch.nn.functional as F mean_t = torch.tensor(mean, device=device).view(3, 1, 1) std_t = torch.tensor(std, device=device).view(3, 1, 1) mode = "bicubic" if interp == "bicubic" else "bilinear" def _one(img): r = F.interpolate(img.unsqueeze(0).float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False) return (r.squeeze(0) * rescale - mean_t) / std_t compiled = torch.compile(_one, dynamic=True) def run(images): return torch.stack([compiled(img) for img in images]) return run def pad_stack(images): """Pad ragged CHW images to the batch-max H/W and stack into (N, C, Hmax, Wmax).""" c = images[0].shape[0] max_h = max(img.shape[1] for img in images) max_w = max(img.shape[2] for img in images) out = torch.zeros(len(images), c, max_h, max_w, dtype=images[0].dtype, device=images[0].device) for i, img in enumerate(images): out[i, :, : img.shape[1], : img.shape[2]] = img return out def build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device): """torch.compile of a single batched resize+normalize over a stacked (N, C, H, W) tensor.""" import torch.nn.functional as F mean_t = torch.tensor(mean, device=device).view(1, 3, 1, 1) std_t = torch.tensor(std, device=device).view(1, 3, 1, 1) mode = "bicubic" if interp == "bicubic" else "bilinear" def _batch(stacked): r = F.interpolate(stacked.float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False) return (r * rescale - mean_t) / std_t return torch.compile(_batch, dynamic=True) def run_inference(model_id, images, block, iters, device): """End-to-end: preprocess (processor / separable / fused / compiled) -> vision features (bf16 forward). Checks each kernel feeds the model with no feature drift and times the full pipeline.""" from transformers import AutoModel proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(model_id) model = AutoModel.from_pretrained(model_id).to(device=device, dtype=torch.bfloat16).eval() vision = model.vision_model kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale, resample=interp, antialias=antialias, block=block) @torch.no_grad() def features(pixel_values): out = vision(pixel_values=pixel_values.to(model.dtype)) pooled = getattr(out, "pooler_output", None) return pooled if pooled is not None else out.last_hidden_state compiled_one = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device) methods = { "processor": lambda: proc(images, return_tensors="pt", device=device)["pixel_values"], "separable": lambda: resize_normalize(images, backend="separable", **kk), "fused": lambda: resize_normalize(images, backend="fused", **kk), "compiled": lambda: compiled_one(images), } methods["compiled"]() # warmup the compiled artifact methods["compiled"]() torch.cuda.synchronize() print(f"\n[infer] {model_id} out={out_h}x{out_w} forward dtype=bfloat16") base = features(methods["processor"]()) base_scale = base.abs().max().item() for name in ("separable", "fused", "compiled"): d = (features(methods[name]()) - base).abs().max().item() print(f"[infer parity] features {name} vs processor: max|Δ| = {d:.2e} ({d / base_scale:.1%} of feature max)") # forward is timed on a FIXED precomputed tensor, so it is method-independent by construction; # if it varies across rows, the preprocessor's output (dtype/contiguity) is hurting the model. print("[infer] ms/iter: preprocess forward(fixed input) preprocess+forward") for name, preprocess in methods.items(): pixel_values = preprocess() pre = _time(preprocess, iters, device) fwd = _time(lambda pixel_values=pixel_values: features(pixel_values), iters, device) e2e = _time(lambda preprocess=preprocess: features(preprocess()), iters, device) print(f" {name:10s} {pre:8.3f} {fwd:8.3f} {e2e:8.3f}") def run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, block, iters, device): """Data-path table from JPEG bytes: CPU decode (libjpeg) vs GPU decode (nvJPEG) + the kernel. decoders differ at the pixel level (nvJPEG vs libjpeg), so this measures wall-clock, not parity. """ jpeg = [encode_jpeg(img, quality=95) for img in images_cpu] avg_kb = sum(b.numel() for b in jpeg) / len(jpeg) / 1024 kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale, resample=interp, antialias=antialias, block=block) def cpu_decode_kernel(): imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg] return resize_normalize(imgs, backend="separable", **kk) def gpu_decode_kernel(): imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device) return resize_normalize(imgs, backend="separable", **kk) def gpu_decode_torchvision(): imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device) return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias) def cpu_decode_torchvision(): imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg] return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias) print(f"\n[decode] N={len(jpeg)} avg={avg_kb:.0f} KB/img out={out_h}x{out_w} (from JPEG bytes, ms/iter)") print(f" CPU decode + torchvision resize : {_time(cpu_decode_torchvision, iters, device):8.3f} [status quo data path]") print(f" CPU decode + separable kernel : {_time(cpu_decode_kernel, iters, device):8.3f}") print(f" GPU decode (nvJPEG) + tv resize : {_time(gpu_decode_torchvision, iters, device):8.3f} [GPU pipeline, tv resize]") print(f" GPU decode (nvJPEG) + kernel : {_time(gpu_decode_kernel, iters, device):8.3f} [GPU pipeline, kernel resize]") def load_processor_config(name): from transformers import AutoImageProcessor proc = AutoImageProcessor.from_pretrained(name, backend="torchvision") size = proc.size if "height" not in size or "width" not in size: raise ValueError(f"{name}: size={size} is not a fixed (height, width)") out_h, out_w = size["height"], size["width"] interp = PIL_RESAMPLE_TO_INTERP.get(int(proc.resample)) rescale = float(proc.rescale_factor) if getattr(proc, "do_rescale", True) else 1.0 antialias = bool(getattr(proc, "antialias", True)) return proc, (out_h, out_w, list(proc.image_mean), list(proc.image_std), rescale, interp, antialias) def _time(fn, iters, device): for _ in range(3): fn() if device.type == "cuda": torch.cuda.synchronize() start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) start.record() for _ in range(iters): fn() end.record() torch.cuda.synchronize() return start.elapsed_time(end) / iters t0 = time.perf_counter() for _ in range(iters): fn() return (time.perf_counter() - t0) / iters * 1e3 def main(): parser = argparse.ArgumentParser() parser.add_argument("--processor", default=None) parser.add_argument("--n", type=int, default=32) parser.add_argument("--out", type=int, nargs=2, default=[384, 384], metavar=("H", "W")) parser.add_argument("--interp", choices=["bilinear", "bicubic"], default="bicubic") parser.add_argument("--antialias", action="store_true") parser.add_argument("--min-res", type=int, default=384) parser.add_argument("--max-res", type=int, default=1024) parser.add_argument("--iters", type=int, default=50) parser.add_argument("--block", type=int, default=256) parser.add_argument("--tol", type=float, default=3e-3) parser.add_argument("--infer", action="store_true", help="end-to-end Siglip2 inference comparison (bf16 forward)") parser.add_argument("--model", default="google/siglip2-base-patch16-224", help="model for --infer") parser.add_argument("--decode", action="store_true", help="JPEG-decode data-path table (CPU vs GPU/nvJPEG) and stop") args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device.type != "cuda": print("benchmark needs CUDA.") return proc = None if args.processor: proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(args.processor) print(f"processor={args.processor} -> out={out_h}x{out_w} interp={interp} antialias={antialias}") else: out_h, out_w = args.out mean = [0.48145466, 0.4578275, 0.40821073] std = [0.26862954, 0.26130258, 0.27577711] rescale = 1.0 / 255.0 interp, antialias = args.interp, args.antialias images = make_ragged_images(args.n, device, args.min_res, args.max_res) taps = (max_taps(images, out_h, 1, interp, antialias), max_taps(images, out_w, 2, interp, antialias)) print(f"N={args.n} in∈[{args.min_res},{args.max_res}]² ragged out={out_h}x{out_w} " f"interp={interp} antialias={antialias} max_taps={taps} iters={args.iters}\n") if args.decode: images_cpu = make_ragged_images(args.n, torch.device("cpu"), args.min_res, args.max_res) run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, args.block, args.iters, device) return ref = torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias) common = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale, resample=interp, antialias=antialias, block=args.block) for backend in ("fused", "separable"): got = resize_normalize(images, backend=backend, **common) d = (got - ref).abs().max().item() print(f"[parity] {backend:9s} vs torchvision(float): max|Δ| = {d:.2e} " f"({'PASS' if d < args.tol else 'FAIL'} @ tol={args.tol})") print() compiled_run = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device) packed = pad_stack(images) packed_compiled_run = build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device) t0 = time.perf_counter() compiled_run(images) compiled_run(images) packed_compiled_run(packed) packed_compiled_run(packed) torch.cuda.synchronize() t_warmup = (time.perf_counter() - t0) * 1e3 t_eager = _time(lambda: torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias), args.iters, device) t_comp = _time(lambda: compiled_run(images), args.iters, device) t_comp_packed = _time(lambda: packed_compiled_run(packed), args.iters, device) t_fused = _time(lambda: resize_normalize(images, backend="fused", **common), args.iters, device) t_sep = _time(lambda: resize_normalize(images, backend="separable", **common), args.iters, device) print("Resize+normalize only (no decode/H2D), ms/iter:") print(f" torchvision eager loop : {t_eager:8.3f} [per-image float loop]") print(f" torchvision compiled : {t_comp:8.3f} [torch.compile dynamic per-image; warmup {t_warmup:.0f} ms excluded]") print(f" torchvision compiled pkt: {t_comp_packed:8.3f} [one graph over padded (N,C,Hmax,Wmax) stack; timing only, padding alters output]") print(f" fused triton (2D) : {t_fused:8.3f} [taps*taps]") print(f" separable triton (uint8): {t_sep:8.3f} [taps+taps]") if proc is not None: t_pr = _time(lambda: proc(images, return_tensors="pt", device=device)["pixel_values"], args.iters, device) print(f"\n {args.processor} : {t_pr:8.3f} ms/iter") print(f" -> separable is {t_sep / t_pr:.2f}x the real processor") if args.infer: run_inference(args.model, images, args.block, args.iters, device) if __name__ == "__main__": main()