
kernel_image_resize — how every op works (study notes)

This explains the whole package end to end: the resampling math, the data layout, and every kernel op, plus the benchmark findings. It is meant for pen & paper — there is a fully worked numeric example you can reproduce by hand in the "Worked example" section.

The package does one thing: resize + rescale + normalize, the op sequence a transformers fast image processor runs (TorchvisionBackend: resize, then (x*rescale - mean)/std), as a single GPU pipeline. Input: raw CHW uint8 images (any size, ragged). Output: (N, C, out_h, out_w) normalized float32.


1. The one idea behind resizing

Resizing does not "pick" pixels; each output pixel is a weighted average of a small window of input pixels. Two things define that average:

  1. Where in the input an output pixel lands (its center).
  2. Which input pixels are in its window, and with what weights.

The weight of an input pixel falls off with distance from the center, following a filter curve:

  • bilinear → a triangle, width 1 on each side (so 2 input pixels per axis normally).
  • bicubic → a cubic bump, width 2 on each side (so 4 input pixels per axis normally).

When you shrink an image, you must also blur first or you get aliasing. That is what antialias=True does: it widens the window so each output pixel averages more input pixels (a low-pass filter before throwing pixels away). The widening is proportional to the shrink factor: shrinking 3× grows bicubic's half-width from 2 to 6 input pixels, which turns the 4-tap window into ~13 taps.


2. The resampling-weight formula (the heart of everything)

All kernels use the same formula, which matches PyTorch's aten UpSampleKernel (align_corners=False, "half-pixel" convention). For one axis:

scale       = in_size / out_size                       # > 1 means shrinking
interp_half = 1 (bilinear)  or  2 (bicubic)            # half-width of the filter
cubic_a     = -0.75 (no antialias)  or  -0.5 (antialias)   # the cubic curve's shape constant

# antialias only widens the window, and only when shrinking:
if antialias and scale > 1:
    eff = scale                # window widens by the shrink factor
else:
    eff = 1                    # plain 2-tap / 4-tap interpolation
support  = interp_half * eff   # half-window width, in INPUT pixels
inv      = 1 / eff             # squashes the filter curve to match the widened window

# for output index i:
center     = scale * (i + 0.5)                 # input coordinate this output maps to
first_tap  = floor(center - support + 0.5)     # leftmost input pixel in the window

# for each tap t = 0, 1, 2, ... (up to MAX_TAPS):
tap_pos    = first_tap + t                     # an input pixel index
arg        = (tap_pos - center + 0.5) * inv    # distance from center, squashed
weight     = filter(arg)                       # triangle or cubic, see below

The filter (_resample_weight in _fused.py), with x = |arg|:

bilinear:   max(1 - x, 0)                                            # triangle, zero past 1

bicubic:    x <= 1 :  (a+2)x^3 - (a+3)x^2 + 1
            1<x<2 :   a x^3 - 5a x^2 + 8a x - 4a
            else  :   0

Two edge rules (both kernels do this identically):

  • non-antialias: clamp the tap index into [0, in_size-1] → replicates the border pixel. The filter weights of a standard 2/4-tap interpolation already sum to 1.
  • antialias: instead set the weight to 0 for taps that fall off the image (tap_pos < 0 or >= in_size), then renormalize by dividing by the sum of the realized weights. This keeps the average correct at the edges.

That renormalization is why every kernel computes a weight_sum and divides by it. For the non-antialias case weight_sum == 1, so the division is a harmless no-op.
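
The formula, the filter, and both edge rules can be put together in a few lines of plain Python. This is a minimal CPU sketch for checking numbers by hand, not the Triton code from _fused.py; the function names `resample_weight` and `tap_weights` are illustrative, and the loop bound reuses the ceil(support) * 2 + 1 tap count.

```python
import math

def resample_weight(arg, interp, a):
    """Filter value at signed distance `arg`: triangle or cubic."""
    x = abs(arg)
    if interp == "bilinear":
        return max(1.0 - x, 0.0)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

def tap_weights(i, in_size, out_size, interp="bilinear", antialias=False):
    """Return [(input_index, normalized_weight), ...] for output index i."""
    scale = in_size / out_size
    interp_half = 1 if interp == "bilinear" else 2
    a = -0.5 if antialias else -0.75
    eff = scale if (antialias and scale > 1) else 1.0
    support = interp_half * eff
    inv = 1.0 / eff
    center = scale * (i + 0.5)
    first_tap = math.floor(center - support + 0.5)

    taps = []
    for t in range(math.ceil(support) * 2 + 1):      # fixed loop bound
        tap_pos = first_tap + t
        weight = resample_weight((tap_pos - center + 0.5) * inv, interp, a)
        if antialias and not (0 <= tap_pos < in_size):
            weight = 0.0                              # off-image tap contributes nothing
        index = min(max(tap_pos, 0), in_size - 1)     # clamp so the load stays in bounds
        taps.append((index, weight))

    weight_sum = sum(w for _, w in taps)
    return [(index, w / weight_sum) for index, w in taps]

# Antialias bilinear, 6 -> 2, first output pixel: the tap at -1 is zeroed
# and the remaining weights are renormalized so they sum to 1.
print(tap_weights(0, 6, 2, "bilinear", antialias=True))
```

In the antialias call the raw weights on input pixels 0..3 are 2/3, 1, 2/3, 1/3, so after dividing by their sum (8/3) they become 0.25, 0.375, 0.25, 0.125. In the non-antialias case the division leaves the weights unchanged.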


3. Worked example (do this by hand)

bilinear, no antialias, one axis, in_size=4, out_size=2.

scale = 4/2 = 2,  interp_half = 1,  eff = 1,  support = 1,  inv = 1

Output pixel i = 0:

center    = 2 * (0 + 0.5) = 1.0
first_tap = floor(1.0 - 1 + 0.5) = floor(0.5) = 0
t=0: tap_pos=0, arg=(0-1.0+0.5)= -0.5 -> weight = 1-0.5 = 0.5
t=1: tap_pos=1, arg=(1-1.0+0.5)=  0.5 -> weight = 1-0.5 = 0.5
t=2: tap_pos=2, arg=(2-1.0+0.5)=  1.5 -> weight = max(1-1.5,0) = 0
weight_sum = 1.0
output[0] = (0.5*in[0] + 0.5*in[1]) / 1.0       # halfway between in[0] and in[1]

Output pixel i = 1:

center    = 2 * (1 + 0.5) = 3.0
first_tap = floor(3.0 - 1 + 0.5) = floor(2.5) = 2
t=0: tap_pos=2, arg=-0.5 -> 0.5
t=1: tap_pos=3, arg= 0.5 -> 0.5
t=2: tap_pos=4, arg= 1.5 -> 0   (index 4 would clamp to 3, but weight is 0 anyway)
output[1] = 0.5*in[2] + 0.5*in[3]

This 1-D operation is exactly one pass of the separable kernel. The 2-D result is the same formula applied on both axes (rows and columns).
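
To check the hand calculation, here is a small sketch specialized to this example's case (bilinear, no antialias, so support = 1 and three taps). The helper name `resize_1d_bilinear` and the sample values are made up for illustration.

```python
import math

def resize_1d_bilinear(values, out_size):
    """Non-antialias bilinear resize of a 1-D list, clamping at the border."""
    in_size = len(values)
    scale = in_size / out_size
    out = []
    for i in range(out_size):
        center = scale * (i + 0.5)
        first_tap = math.floor(center - 1 + 0.5)      # support = 1
        acc = weight_sum = 0.0
        for t in range(3):                             # 3 taps cover this case
            tap_pos = first_tap + t
            weight = max(1.0 - abs(tap_pos - center + 0.5), 0.0)
            acc += weight * values[min(max(tap_pos, 0), in_size - 1)]
            weight_sum += weight
        out.append(acc / weight_sum)
    return out

row = [10.0, 20.0, 30.0, 40.0]
out_1d = resize_1d_bilinear(row, 2)
print(out_1d)      # [15.0, 35.0]

# 2-D: resize every row (width 4 -> 2), then every column (height 2 -> 1).
image = [[0.0, 4.0, 8.0, 12.0],
         [16.0, 20.0, 24.0, 28.0]]
narrow = [resize_1d_bilinear(r, 2) for r in image]
columns = [[r[c] for r in narrow] for c in range(2)]
out_2d = [resize_1d_bilinear(col, 1)[0] for col in columns]
print(out_2d)      # [10.0, 18.0]
```

The first print matches the hand result: 0.5*in[0] + 0.5*in[1] and 0.5*in[2] + 0.5*in[3].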


4. Data layout (host side, _pack.py)

Ragged images (different H×W) cannot be stacked into one tensor, so they are flattened and concatenated into one buffer, with side tables describing each image.

pack_images(images, dtype)

input_pixels : 1-D buffer, all images flattened (C,H,W row-major) and concatenated
offsets[n]   : element index where image n starts
heights[n], widths[n] : that image's H and W
channels     : C (shared by all images)

Address of input pixel (channel, row, col) of image n:

input_pixels[ offsets[n] + channel*(H*W) + row*W + col ]
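
The layout and the address formula can be tried out on nested Python lists. This is only a sketch of the scheme described above: `pack_images_sketch` and `load_pixel` are hypothetical names, and the real pack_images works on tensors and takes a dtype.

```python
def pack_images_sketch(images):
    """Flatten ragged CHW nested lists into one buffer plus side tables."""
    input_pixels, offsets, heights, widths = [], [], [], []
    channels = len(images[0])
    for img in images:
        assert len(img) == channels, "all images must share C"
        offsets.append(len(input_pixels))     # where this image starts
        heights.append(len(img[0]))
        widths.append(len(img[0][0]))
        for plane in img:                     # channel by channel
            for row in plane:                 # row-major inside a channel
                input_pixels.extend(row)
    return input_pixels, offsets, heights, widths, channels

def load_pixel(packed, n, channel, row, col):
    """Apply the address formula for pixel (channel, row, col) of image n."""
    input_pixels, offsets, heights, widths, _ = packed
    H, W = heights[n], widths[n]
    return input_pixels[offsets[n] + channel * (H * W) + row * W + col]

image_a = [[[1, 2, 3], [4, 5, 6]],           # channel 0, 2x3
           [[7, 8, 9], [10, 11, 12]]]        # channel 1
image_b = [[[13, 14]],                        # channel 0, 1x2
           [[15, 16]]]                        # channel 1
packed = pack_images_sketch([image_a, image_b])
print(packed[1])                              # offsets: [0, 12]
print(load_pixel(packed, 1, 1, 0, 1))         # 16
```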

The separable path packs as uint8 (1 byte/pixel, a quarter of the memory traffic of float32).

fold_mean_std(mean, std, rescale) → folds the rescale factor into the normalization constants so the kernel does a single (x - m)/s:

m = mean / rescale       s = std / rescale
(x - m)/s  ==  (x*rescale - mean)/std      # identical to the processor's fused normalize
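
The identity is easy to confirm numerically. The sketch below follows the two formulas above on plain Python lists; the name `fold_mean_std_sketch` and the sample mean/std values are illustrative assumptions, not taken from the package.

```python
def fold_mean_std_sketch(mean, std, rescale):
    """Fold the rescale factor into per-channel mean and std."""
    m = [mu / rescale for mu in mean]
    s = [sd / rescale for sd in std]
    return m, s

mean, std, rescale = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 1 / 255
m, s = fold_mean_std_sketch(mean, std, rescale)

x = 200.0                                          # a raw uint8 pixel value
folded = (x - m[0]) / s[0]                         # what the kernel computes
reference = (x * rescale - mean[0]) / std[0]       # rescale, then normalize
print(folded, reference)                           # equal up to float rounding
```

Folding saves one multiply per pixel and lets the kernel work directly on the raw 0..255 values.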

max_taps(images, out_size, axis, interp, antialias) → the widest window in the batch = ceil(support) * 2 + 1. A Triton loop bound must be a compile-time constant, so every program loops this fixed count; taps beyond a given pixel's real window get ~0 weight.
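
As a rough illustration of that bound, the following sketch computes the tap count for one axis from the input sizes in a batch. It takes a list of sizes rather than images, so its signature differs from the real max_taps; the names are hypothetical.

```python
import math

def taps_for_axis(in_size, out_size, interp="bicubic", antialias=True):
    """ceil(support) * 2 + 1 for one image along one axis."""
    scale = in_size / out_size
    interp_half = 1 if interp == "bilinear" else 2
    eff = scale if (antialias and scale > 1) else 1.0
    support = interp_half * eff
    return math.ceil(support) * 2 + 1

def max_taps_sketch(in_sizes, out_size, interp="bicubic", antialias=True):
    """Widest window over the batch; every program loops this many times."""
    return max(taps_for_axis(s, out_size, interp, antialias) for s in in_sizes)

print(taps_for_axis(1152, 384))                    # 3x shrink, bicubic + AA -> 13
print(taps_for_axis(384, 384))                     # no shrink -> 5
print(max_taps_sketch([384, 768, 1024], 384))      # 13, set by the largest image
```

Because the loop bound is set by the largest image, smaller images in the same batch run the same 13 iterations, and the extra taps fall outside their windows.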

as_image_list → accepts a stacked (N,C,H,W) tensor or a list, always returns a list.


5. Fused kernel (_fused.py, backend="fused")

One launch. One program = one image + a BLOCK of its output pixels. Each output pixel reads the full 2-D window directly: MAX_TAPS_H × MAX_TAPS_W input pixels.

grid = (num_images, ceil(out_h*out_w / BLOCK))

per lane (one output pixel):
  oy, ox            = (flat_index // out_w, flat_index % out_w)
  center_y, center_x, first_tap_y, first_tap_x       # section 2, both axes

  # weight_sum factorizes across axes (separable math, even though the LOADS are 2-D):
  sum_wy = Σ_ty filter_y      ;  sum_wx = Σ_tx filter_x      ;  denom = sum_wy * sum_wx

  for channel:
    acc = 0
    for ty in 0..MAX_TAPS_H:                 #  <-- the 2-D window: TAPS_H * TAPS_W loads
      for tx in 0..MAX_TAPS_W:
        weight = filter_y(ty) * filter_x(tx)
        pixel  = input_pixels[channel, clamp(tap_y), clamp(tap_x)]
        acc   += weight * pixel
    acc = acc / denom
    out[image, channel, oy, ox] = (acc - mean[channel]) / std[channel]

Cost per output pixel: TAPS_H * TAPS_W loads (e.g. 13×13 = 169). Correct and simple, but the 2-D load count is what makes it slow — hence the separable version.
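
The grid and the flat_index to (oy, ox) mapping can be mimicked on the CPU as follows. This is a sketch with hypothetical helper names; the tail check for the last, partially filled block is an assumption about how out-of-range lanes would be handled.

```python
import math

def fused_grid(num_images, out_h, out_w, block=256):
    """(programs along images, programs along output pixels)."""
    return (num_images, math.ceil(out_h * out_w / block))

def lanes(program_id, out_h, out_w, block=256):
    """Yield (oy, ox) for the valid lanes of one program."""
    start = program_id * block
    for flat_index in range(start, start + block):
        if flat_index < out_h * out_w:               # skip lanes past the last pixel
            yield flat_index // out_w, flat_index % out_w

print(fused_grid(32, 384, 384))                      # (32, 576)
print(list(lanes(1, 3, 4, block=8)))                 # [(2, 0), (2, 1), (2, 2), (2, 3)]
```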


6. Separable kernel (_separable.py, backend="separable", the default)

Same math, but the 2-D window is done as two 1-D passes, with a float intermediate buffer in between. Loads per output pixel: TAPS_W + TAPS_H (e.g. 13+13 = 26).

input_pixels (uint8, C×H×W)  --pass1-->  intermediate (float, C×H×out_w)  --pass2-->  output (float, C×out_h×out_w)
                              resize W                               resize H + normalize

The intermediate is the key object: same height as the input, but already the final width. ("Tall and narrow.") It is also ragged in height, so it gets its own offset table (built in separable_resize_normalize, same scheme as pack_images).

Pass 1 — _horizontal_resize_kernel (resize width only)

grid = (num_images, ceil(H*out_w / BLOCK))     # work = every input row × every output col
per lane:
  input_row = flat_index // out_w     # row index, UNCHANGED by this pass
  out_col   = flat_index %  out_w     # output column being computed

  center_x, first_tap_x, col_weight_sum     # section 2, COLUMN axis only

  for channel:
    acc = 0
    for tap in 0..MAX_TAPS_COL:                          #  <-- 1-D: only TAPS_W loads
      weight = filter_x(tap)
      pixel  = input_pixels[channel, input_row, clamp(tap_col)]   # uint8 -> float
      acc   += weight * pixel
    acc = acc / col_weight_sum
    intermediate[channel, input_row, out_col] = acc      # NO normalize yet

Reads original uint8 bytes; writes float32. No normalization here.

Pass 2 — _vertical_resize_normalize_kernel (resize height, then normalize)

grid = (num_images, ceil(out_h*out_w / BLOCK))     # work = every output pixel
per lane:
  out_row = flat_index // out_w
  out_col = flat_index %  out_w

  center_y, first_tap_y, row_weight_sum     # section 2, ROW axis only

  for channel:
    acc = 0
    for tap in 0..MAX_TAPS_ROW:                          #  <-- 1-D: only TAPS_H loads
      weight = filter_y(tap)
      pixel  = intermediate[channel, clamp(tap_row), out_col]      # float
      acc   += weight * pixel
    acc = acc / row_weight_sum
    out[image, channel, out_row, out_col] = (acc - mean[channel]) / std[channel]   # normalize here

Two launches. Kernels queued on the same CUDA stream run in order, which acts as an implicit sync, so pass 2 always sees pass 1's finished output.

Why separable wins

TAPS_W + TAPS_H loads instead of TAPS_W * TAPS_H. For a 13×13 window that is 26 vs 169. This is exactly the algorithm PIL and torchvision use. The catch: an extra full-size float intermediate buffer (more memory traffic), but the read-count reduction dominates.
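
The two-pass result equals the 2-D weighted sum because the weight of each tap is a product of a row weight and a column weight. The sketch below checks that on a 4×4 image resized to 2×2. For brevity it stores the weights as dense matrices (one row of weights per output index, zeros outside the window), which is an illustrative representation and not how the kernels store them.

```python
def fused_2d(image, row_w, col_w):
    """out[oy][ox] = sum over (ty, tx) of row_w[oy][ty] * col_w[ox][tx] * image[ty][tx]."""
    return [[sum(wy * wx * image[ty][tx]
                 for ty, wy in enumerate(wy_row)
                 for tx, wx in enumerate(wx_row))
             for wx_row in col_w]
            for wy_row in row_w]

def separable_2d(image, row_w, col_w):
    # Pass 1: resize width only; the intermediate keeps the input height.
    intermediate = [[sum(wx * row[tx] for tx, wx in enumerate(wx_row))
                     for wx_row in col_w]
                    for row in image]
    # Pass 2: resize height on the intermediate.
    return [[sum(wy * intermediate[ty][ox] for ty, wy in enumerate(wy_row))
             for ox in range(len(col_w))]
            for wy_row in row_w]

image = [[float(4 * r + c) for c in range(4)] for r in range(4)]
weights = [[0.5, 0.5, 0.0, 0.0],       # bilinear 4 -> 2, output index 0
           [0.0, 0.0, 0.5, 0.5]]       # output index 1

print(fused_2d(image, weights, weights))       # [[2.5, 4.5], [10.5, 12.5]]
print(separable_2d(image, weights, weights))   # same values
```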

Parity note: the intermediate here is float32; torchvision keeps a fixed-point uint8 intermediate. So the separable output is parity-close to torchvision, not bit-identical — and the float version is actually the more accurate one.


7. Public API (__init__.py)

resize_normalize(images, size, image_mean, image_std,
                 rescale_factor=1/255, resample="bilinear", antialias=False,
                 backend="separable", block=256)
  • images: stacked (N,C,H,W) tensor or a list of CHW tensors.
  • size: int (square), (H,W), or {"height","width"}.
  • resample: "bilinear"/"bicubic", or a PIL resample int (0/2→bilinear, 3→bicubic).
  • backend: "separable" (default, fastest) or "fused" (2-D reference).
  • resize_normalize_ragged: same kernels, list-only.
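
As an illustration of the accepted argument forms, the helpers below normalize `size` and `resample` the way the list above describes. They are hypothetical and not part of the package; the integer mapping simply copies the one stated above.

```python
def parse_size(size):
    """Normalize `size` to an (out_h, out_w) tuple."""
    if isinstance(size, int):
        return (size, size)                        # square
    if isinstance(size, dict):
        return (size["height"], size["width"])
    out_h, out_w = size                            # (H, W) sequence
    return (out_h, out_w)

def parse_resample(resample):
    """Normalize `resample` to "bilinear" or "bicubic"."""
    if isinstance(resample, str):
        if resample not in ("bilinear", "bicubic"):
            raise ValueError(f"unsupported resample: {resample!r}")
        return resample
    # Integer codes, mapped as listed above: 0/2 -> bilinear, 3 -> bicubic.
    return {0: "bilinear", 2: "bilinear", 3: "bicubic"}[resample]

print(parse_size(384))                                  # (384, 384)
print(parse_size({"height": 224, "width": 336}))        # (224, 336)
print(parse_resample(3))                                # bicubic
```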

8. Benchmark findings (A100, CUDA_VISIBLE_DEVICES=1)

Standalone resize+normalize — SigLIP-so400m config, N=32 ragged 384–1024², out 384², bicubic+AA

torchvision eager loop  :   2.91 ms   (per-image float loop)
torchvision compiled    :   5.70 ms   (torch.compile dynamic, per-image; slower than eager)
torchvision compiled pkt:   2.55 ms   (one graph over a padded stack; timing only)
fused triton (2D)       :  11.49 ms   (taps*taps; the slow reference)
separable triton (uint8):   1.29 ms   (taps+taps)   <-- fastest
real processor          :   3.92 ms

Separable is ~3× faster than the real processor (3.92 ms vs 1.29 ms), with parity ≤1e-4 vs torchvision-float. The fused 2-D kernel loses for the algorithmic reason above (169 vs 26 loads). torch.compile does not help: per-image it is slower than eager (dispatch overhead over 32 ragged shapes), and even as one packed graph it is only slightly ahead of the eager loop (2.55 ms vs 2.91 ms), because inductor's interpolate is no faster than aten resize.

End-to-end inference — Siglip2-base-patch16-224, bf16 forward

                preprocess   forward(fixed input)   preprocess+forward     (all times in ms)
processor          3.99           12.86                  14.44
separable          0.93           13.02                  13.76     <-- ~5% faster e2e
fused              2.00           13.01                  14.79
compiled           6.14           12.89                  14.00

feature parity (separable/fused/compiled vs processor): 9.38e-2 = 1.2% of feature max
  • forward(fixed input) is effectively identical for all four (12.86 to 13.02 ms) → no inference regression; the model does not care which preprocessor made the tensor.
  • The 1.2% feature drift is the float-vs-uint8 resize difference, identical across all float backends → not a bug. The float path is the more accurate one.
  • End-to-end win is ~5% with a bf16 forward (was ~0.5% with fp32, where the forward was ~80 ms). The win scales with how preprocessing-bound you are.

Data path from JPEG bytes — 552 KB/img

CPU decode + torchvision resize :  177.5 ms   (status quo)
CPU decode + separable kernel   :  176.4 ms   (kernel saves ~1 ms; decode dominates)
GPU decode (nvJPEG) + kernel    :   14.8 ms   (fully on-GPU)
  • ~175 ms of the 177 ms is CPU JPEG decode + host→device copy. Resize/normalize is ~1%.
  • The 12× win (177→15) is GPU decode (nvJPEG), i.e. torchvision.io.decode_jpeg(device="cuda"), not the kernel. The kernel is the resize/normalize component of that GPU pipeline.
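
The headline ratios in this section follow directly from the reported timings. The short script below redoes that arithmetic using only numbers from the tables above; the variable names are illustrative.

```python
# Timings in ms, copied from the tables in this section.
standalone = {"processor": 3.92, "separable": 1.29}
end_to_end = {"processor": 14.44, "separable": 13.76}
jpeg_path = {"cpu_torchvision": 177.5, "cpu_kernel": 176.4, "gpu_kernel": 14.8}

stage_speedup = standalone["processor"] / standalone["separable"]
e2e_gain = 1 - end_to_end["separable"] / end_to_end["processor"]
kernel_only_gain = 1 - jpeg_path["cpu_kernel"] / jpeg_path["cpu_torchvision"]
gpu_decode_speedup = jpeg_path["cpu_torchvision"] / jpeg_path["gpu_kernel"]

print(f"resize/normalize stage: {stage_speedup:.2f}x")          # about 3x
print(f"end-to-end (bf16 forward): {e2e_gain:.1%} faster")       # about 5%
print(f"kernel alone on CPU-decode path: {kernel_only_gain:.1%}")  # under 1%
print(f"GPU decode + kernel: {gpu_decode_speedup:.1f}x")         # about 12x
```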

9. What is true / what to claim

  • The kernel is correct (≤1e-4 vs torchvision-float, more accurate than the processor's uint8 path) and feeds the model with no inference regression.
  • It is ~3× faster than the real processor at the resize/normalize stage, a real, parity-clean win.
  • It does not speed up preprocessing 12×. Decode dominates the data path; the GPU-decode lever is nvJPEG, a torchvision feature, not this kernel.
  • The kernel matters end-to-end only once you are not decode-bound: in a GPU-decode pipeline it keeps resize/normalize minimal (~10% of that pipeline), and its standalone preprocess win shows up when the forward is small (bf16, small model, large batch).
  • Honest one-liner: "GPU-native resize+normalize, 3× the fast processor at that stage, drop-in for a GPU-decode pipeline."