# kernel_image_resize — how every op works (study notes)
This explains the whole package end to end: the resampling math, the data layout, and
every kernel op, plus the benchmark findings. It is meant for pen & paper — there is a
fully worked numeric example you can reproduce by hand in the "Worked example" section.
The package does one thing: **resize + rescale + normalize**, the op sequence a
`transformers` fast image processor runs (`TorchvisionBackend`: resize, then
`(x*rescale - mean)/std`), as a single GPU pipeline. Input: raw CHW `uint8` images
(any size, ragged). Output: `(N, C, out_h, out_w)` normalized `float32`.
---
## 1. The one idea behind resizing
Resizing does not "pick" pixels; each **output** pixel is a **weighted average of a small
window of input pixels**. Two things define that average:
1. **Where** in the input an output pixel lands (its center).
2. **Which** input pixels are in its window, and with **what weights**.
The weight of an input pixel falls off with distance from the center, following a filter
curve:
- **bilinear** → a triangle, width 1 on each side (so 2 input pixels per axis normally).
- **bicubic** → a cubic bump, width 2 on each side (so 4 input pixels per axis normally).
When you **shrink** an image, you must also **blur first** or you get aliasing. That is
what `antialias=True` does: it widens the window so each output pixel averages more input
pixels (a low-pass filter before throwing pixels away). Widening is proportional to the
shrink factor, so shrinking 3× turns a 4-tap bicubic into ~13 taps.
---
## 2. The resampling-weight formula (the heart of everything)
All kernels use the same formula, which matches PyTorch's aten `UpSampleKernel`
(`align_corners=False`, "half-pixel" convention). For one axis:
```
scale       = in_size / out_size                        # > 1 means shrinking
interp_half = 1 (bilinear) or 2 (bicubic)               # half-width of the filter
cubic_a     = -0.75 (no antialias) or -0.5 (antialias)  # the cubic curve's shape constant
# antialias only widens the window, and only when shrinking:
if antialias and scale > 1:
    eff = scale                                # window widens by the shrink factor
else:
    eff = 1                                    # plain 2-tap / 4-tap interpolation
support = interp_half * eff                    # half-window width, in INPUT pixels
inv     = 1 / eff                              # squashes the filter curve to match the widened window
# for output index i:
center    = scale * (i + 0.5)                  # input coordinate this output maps to
first_tap = floor(center - support + 0.5)      # leftmost input pixel in the window
# for each tap t = 0, 1, 2, ... (up to MAX_TAPS):
tap_pos = first_tap + t                        # an input pixel index
arg     = (tap_pos - center + 0.5) * inv       # distance from center, squashed
weight  = filter(arg)                          # triangle or cubic, see below
```
The filter (`_resample_weight` in `_fused.py`), with `x = |arg|`:
```
bilinear:  max(1 - x, 0)                           # triangle, zero past 1
bicubic:   x <= 1    : (a+2)x^3 - (a+3)x^2 + 1
           1 < x < 2 : a x^3 - 5a x^2 + 8a x - 4a
           else      : 0
```
Two edge rules (both kernels do this identically):
- **non-antialias**: clamp the tap index into `[0, in_size-1]` → replicates the border
pixel. The filter weights of a standard 2/4-tap interpolation already sum to 1.
- **antialias**: instead set the weight to **0** for taps that fall off the image
(`tap_pos < 0` or `>= in_size`), then **renormalize** by dividing by the sum of the
realized weights. This keeps the average correct at the edges.
That renormalization is why every kernel computes a `weight_sum` and divides by it. For
the non-antialias case `weight_sum == 1`, so the division is a harmless no-op.
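The formula and both edge rules can be sketched in plain Python as a check on hand calculations. This is a minimal sketch, not the package's Triton code: `resample_filter` and `axis_weights` are hypothetical names, and the loop bound reuses the `ceil(support) * 2 + 1` tap count described for `max_taps`.

```python
import math

def resample_filter(arg, interp="bilinear", a=-0.75):
    """Triangle or cubic filter, evaluated at a signed distance `arg`."""
    x = abs(arg)
    if interp == "bilinear":
        return max(1.0 - x, 0.0)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

def axis_weights(i, in_size, out_size, interp="bilinear", antialias=False):
    """Return [(input_index, weight), ...] for output index i on one axis."""
    scale = in_size / out_size
    interp_half = 1 if interp == "bilinear" else 2
    a = -0.5 if antialias else -0.75
    eff = scale if (antialias and scale > 1.0) else 1.0
    support = interp_half * eff
    inv = 1.0 / eff
    center = scale * (i + 0.5)
    first_tap = math.floor(center - support + 0.5)
    taps = []
    for t in range(math.ceil(support) * 2 + 1):
        tap_pos = first_tap + t
        weight = resample_filter((tap_pos - center + 0.5) * inv, interp, a)
        if antialias:
            # Antialias rule: taps that fall off the image get weight 0.
            if tap_pos < 0 or tap_pos >= in_size:
                weight = 0.0
            index = tap_pos
        else:
            # Non-antialias rule: clamp the index, replicating the border pixel.
            index = min(max(tap_pos, 0), in_size - 1)
        taps.append((index, weight))
    weight_sum = sum(w for _, w in taps)
    # Renormalize and drop the zero-weight taps.
    return [(idx, w / weight_sum) for idx, w in taps if w != 0.0]

print(axis_weights(0, 4, 2))  # [(0, 0.5), (1, 0.5)]
```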
---
## 3. Worked example (do this by hand)
**bilinear, no antialias, one axis, in_size=4, out_size=2.**
```
scale = 4/2 = 2, interp_half = 1, eff = 1, support = 1, inv = 1
```
Output pixel `i = 0`:
```
center    = 2 * (0 + 0.5) = 1.0
first_tap = floor(1.0 - 1 + 0.5) = floor(0.5) = 0
  t=0: tap_pos=0, arg=(0-1.0+0.5)=-0.5 -> weight = 1-0.5 = 0.5
  t=1: tap_pos=1, arg=(1-1.0+0.5)= 0.5 -> weight = 1-0.5 = 0.5
  t=2: tap_pos=2, arg=(2-1.0+0.5)= 1.5 -> weight = max(1-1.5,0) = 0
weight_sum = 1.0
output[0]  = (0.5*in[0] + 0.5*in[1]) / 1.0    # halfway between in[0] and in[1]
```
Output pixel `i = 1`:
```
center    = 2 * (1 + 0.5) = 3.0
first_tap = floor(3.0 - 1 + 0.5) = floor(2.5) = 2
  t=0: tap_pos=2, arg=(2-3.0+0.5)=-0.5 -> weight = 0.5
  t=1: tap_pos=3, arg=(3-3.0+0.5)= 0.5 -> weight = 0.5
  t=2: tap_pos=4, arg=(4-3.0+0.5)= 1.5 -> weight = 0   (index 4 would clamp to 3, but weight is 0 anyway)
weight_sum = 1.0
output[1]  = (0.5*in[2] + 0.5*in[3]) / 1.0
```
This 1-D operation is exactly one pass of the separable kernel. The 2-D result is the same
formula applied on both axes (rows and columns).
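To check the hand calculation with concrete numbers, the 1-D pass can be written out in a few lines of Python. This is a sketch for the bilinear, no-antialias case only; `resize_1d_bilinear` is a hypothetical helper and the input values are made up.

```python
import math

def resize_1d_bilinear(row, out_size):
    """One bilinear, non-antialiased pass along a single axis."""
    in_size = len(row)
    scale = in_size / out_size
    out = []
    for i in range(out_size):
        center = scale * (i + 0.5)
        first_tap = math.floor(center - 1 + 0.5)      # support = 1 for bilinear
        acc, weight_sum = 0.0, 0.0
        for t in range(3):                            # ceil(support) * 2 + 1 taps
            tap_pos = first_tap + t
            weight = max(1.0 - abs(tap_pos - center + 0.5), 0.0)
            clamped = min(max(tap_pos, 0), in_size - 1)  # replicate the border
            acc += weight * row[clamped]
            weight_sum += weight
        out.append(acc / weight_sum)
    return out

print(resize_1d_bilinear([10.0, 20.0, 30.0, 40.0], 2))  # [15.0, 35.0]
```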
---
## 4. Data layout (host side, `_pack.py`)
Ragged images (different H×W) cannot be stacked into one tensor, so they are flattened and
concatenated into one buffer, with side tables describing each image.
`pack_images(images, dtype)` →
```
input_pixels          : 1-D buffer, all images flattened (C,H,W row-major) and concatenated
offsets[n]            : element index where image n starts
heights[n], widths[n] : that image's H and W
channels              : C (shared by all images)
```
Address of input pixel `(channel, row, col)` of image `n`:
```
input_pixels[ offsets[n] + channel*(H*W) + row*W + col ]
```
The separable path packs as `uint8` (1 byte per element instead of 4 for `float32`, so the input reads move a quarter of the bytes).
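The packing scheme and the address formula can be illustrated with nested Python lists standing in for CHW tensors. This is a sketch under that assumption, not the code in `_pack.py`; `pack_images_sketch` and `pixel_at` are hypothetical names.

```python
def pack_images_sketch(images):
    """images: list of nested lists shaped [C][H][W]; sizes may differ per image."""
    channels = len(images[0])
    buffer, offsets, heights, widths = [], [], [], []
    for img in images:
        assert len(img) == channels, "all images must share C"
        h, w = len(img[0]), len(img[0][0])
        offsets.append(len(buffer))        # element index where this image starts
        heights.append(h)
        widths.append(w)
        for plane in img:                  # C, then H, then W: row-major
            for row in plane:
                buffer.extend(row)
    return buffer, offsets, heights, widths, channels

def pixel_at(buffer, offsets, heights, widths, n, channel, row, col):
    """Apply the address formula for pixel (channel, row, col) of image n."""
    h, w = heights[n], widths[n]
    return buffer[offsets[n] + channel * (h * w) + row * w + col]

# Two ragged single-channel images: 2x3 and 1x2.
images = [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8]]]]
buffer, offsets, heights, widths, channels = pack_images_sketch(images)
print(offsets)                                              # [0, 6]
print(pixel_at(buffer, offsets, heights, widths, 1, 0, 0, 1))  # 8
```

The side tables are what let a single kernel launch handle every image: a program only needs its image index `n` to find the start, height and width of its input.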
`fold_mean_std(mean, std, rescale)` → folds the rescale factor into the normalization
constants so the kernel does a single `(x - m)/s`:
```
m = mean / rescale
s = std  / rescale
(x - m)/s == (x*rescale - mean)/std     # identical to the processor's fused normalize
```
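The folding identity is easy to confirm numerically. The sketch below uses a hypothetical `fold_mean_std_sketch` and illustrative mean/std values; in floating point the two forms agree up to rounding rather than bit for bit.

```python
import math

def fold_mean_std_sketch(mean, std, rescale):
    """Fold the rescale factor into per-channel mean/std (hypothetical helper)."""
    folded_mean = [m / rescale for m in mean]
    folded_std = [s / rescale for s in std]
    return folded_mean, folded_std

mean, std, rescale = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 1 / 255  # illustrative values
m, s = fold_mean_std_sketch(mean, std, rescale)

# Compare the folded form against the two-step rescale-then-normalize form.
for x in (0, 64, 128, 255):                      # raw uint8 pixel values
    for c in range(3):
        folded = (x - m[c]) / s[c]
        two_step = (x * rescale - mean[c]) / std[c]
        assert math.isclose(folded, two_step, rel_tol=1e-9, abs_tol=1e-12)
```

Folding saves one multiply per pixel per channel inside the kernel, at the cost of a one-time division on the host.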
`max_taps(images, out_size, axis, interp, antialias)` → the **widest** window in the batch
= `ceil(support) * 2 + 1`. A Triton loop bound must be a compile-time constant, so every
program loops this fixed count; taps beyond a given pixel's real window get ~0 weight.
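A rough sketch of the tap-count bound follows. It assumes the per-axis formula from section 2 and takes input sizes directly, whereas the real `max_taps` takes the images and an axis; `taps_for_axis` and `max_taps_sketch` are hypothetical names.

```python
import math

def taps_for_axis(in_size, out_size, interp="bicubic", antialias=True):
    """Loop bound for one image on one axis: ceil(support) * 2 + 1."""
    scale = in_size / out_size
    interp_half = 1 if interp == "bilinear" else 2
    eff = scale if (antialias and scale > 1.0) else 1.0
    support = interp_half * eff
    return math.ceil(support) * 2 + 1

def max_taps_sketch(in_sizes, out_size, interp="bicubic", antialias=True):
    """Widest window over a ragged batch; every program loops this many taps."""
    return max(taps_for_axis(s, out_size, interp, antialias) for s in in_sizes)

print(taps_for_axis(1152, 384))                # 13  (bicubic + antialias, 3x shrink)
print(taps_for_axis(384, 384))                 # 5   (no shrink)
print(max_taps_sketch([384, 768, 1152], 384))  # 13
```

Because the bound is set by the largest image in the batch, one very large image makes every program loop over more (mostly zero-weight) taps.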
`as_image_list` → accepts a stacked `(N,C,H,W)` tensor or a list, always returns a list.
---
## 5. Fused kernel (`_fused.py`, `backend="fused"`)
One launch. **One program = one image + a BLOCK of its output pixels.** Each output pixel
reads the **full 2-D window** directly: `MAX_TAPS_H × MAX_TAPS_W` input pixels.
```
grid = (num_images, ceil(out_h*out_w / BLOCK))
per lane (one output pixel):
    oy, ox = (flat_index // out_w, flat_index % out_w)
    center_y, center_x, first_tap_y, first_tap_x     # section 2, both axes
    # weight_sum factorizes across axes (separable math, even though the LOADS are 2-D):
    sum_wy = Σ_ty filter_y ;  sum_wx = Σ_tx filter_x ;  denom = sum_wy * sum_wx
    for channel:
        acc = 0
        for ty in 0..MAX_TAPS_H:                     # <-- the 2-D window: TAPS_H * TAPS_W loads
            for tx in 0..MAX_TAPS_W:
                weight = filter_y(ty) * filter_x(tx)
                pixel  = input_pixels[channel, clamp(tap_y), clamp(tap_x)]
                acc   += weight * pixel
        acc = acc / denom
        out[image, channel, oy, ox] = (acc - mean[channel]) / std[channel]
```
Cost per output pixel: `TAPS_H * TAPS_W` loads (e.g. 13×13 = **169**). Correct and simple,
but the 2-D load count is what makes it slow — hence the separable version.
---
## 6. Separable kernel (`_separable.py`, `backend="separable"`, the default)
Same math, but the 2-D window is done as **two 1-D passes**, with a float intermediate
buffer in between. Loads per output pixel: `TAPS_W + TAPS_H` (e.g. 13+13 = **26**).
```
input_pixels (uint8, C×H×W) --pass1--> intermediate (float, C×H×out_w) --pass2--> output (float, C×out_h×out_w)
                             resize W                             resize H + normalize
```
The **intermediate** is the key object: same **height** as the input, but already the
**final width**. ("Tall and narrow.") It is also ragged in height, so it gets its own
offset table (built in `separable_resize_normalize`, same scheme as `pack_images`).
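Since each image's intermediate has shape `C × H × out_w`, its offset table can be built the same way as the input one. A minimal sketch, assuming the same row-major addressing as `pack_images`; `intermediate_offsets` and `intermediate_index` are hypothetical names.

```python
def intermediate_offsets(heights, channels, out_w):
    """Start offset of each image's (C, H, out_w) intermediate, plus total size."""
    offsets, total = [], 0
    for h in heights:
        offsets.append(total)
        total += channels * h * out_w      # height stays ragged, width is final
    return offsets, total

def intermediate_index(offsets, heights, out_w, n, channel, row, col):
    """Flat index of element (channel, row, col) in image n's intermediate."""
    return offsets[n] + channel * (heights[n] * out_w) + row * out_w + col

offsets, total = intermediate_offsets(heights=[4, 2], channels=3, out_w=5)
print(offsets, total)                                        # [0, 60] 90
print(intermediate_index(offsets, [4, 2], 5, 1, 2, 1, 4))    # 89
```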
### Pass 1 — `_horizontal_resize_kernel` (resize width only)
```
grid = (num_images, ceil(H*out_w / BLOCK))    # work = every input row × every output col
per lane:
    input_row = flat_index // out_w           # row index, UNCHANGED by this pass
    out_col   = flat_index %  out_w           # output column being computed
    center_x, first_tap_x, col_weight_sum     # section 2, COLUMN axis only
    for channel:
        acc = 0
        for tap in 0..MAX_TAPS_COL:           # <-- 1-D: only TAPS_W loads
            weight = filter_x(tap)
            pixel  = input_pixels[channel, input_row, clamp(tap_col)]   # uint8 -> float
            acc   += weight * pixel
        acc = acc / col_weight_sum
        intermediate[channel, input_row, out_col] = acc                 # NO normalize yet
```
Reads original `uint8` bytes; writes `float32`. No normalization here.
### Pass 2 — `_vertical_resize_normalize_kernel` (resize height, then normalize)
```
grid = (num_images, ceil(out_h*out_w / BLOCK))    # work = every output pixel
per lane:
    out_row = flat_index // out_w
    out_col = flat_index %  out_w
    center_y, first_tap_y, row_weight_sum         # section 2, ROW axis only
    for channel:
        acc = 0
        for tap in 0..MAX_TAPS_ROW:               # <-- 1-D: only TAPS_H loads
            weight = filter_y(tap)
            pixel  = intermediate[channel, clamp(tap_row), out_col]     # float
            acc   += weight * pixel
        acc = acc / row_weight_sum
        out[image, channel, out_row, out_col] = (acc - mean[channel]) / std[channel]   # normalize here
```
Two launches (an implicit sync between them), so pass 2 always sees pass 1's finished
output.
### Why separable wins
`TAPS_W + TAPS_H` loads instead of `TAPS_W * TAPS_H`. For a 13×13 window that is 26 vs 169.
This is exactly the algorithm PIL and torchvision use. The catch: an extra full-size float
intermediate buffer (more memory traffic), but the read-count reduction dominates.
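That the two passes give the same answer as the full 2-D window can be checked on a small single-channel image. The sketch below is plain Python restricted to antialiased bilinear weights; `axis_taps`, `resize_window_2d` and `resize_two_pass` are hypothetical names, and the image values are arbitrary.

```python
import math

def axis_taps(i, in_size, out_size):
    """Antialiased bilinear taps for output index i: [(input_index, weight), ...]."""
    scale = in_size / out_size
    eff = scale if scale > 1.0 else 1.0            # widen only when shrinking
    support = 1.0 * eff                            # interp_half = 1 for bilinear
    center = scale * (i + 0.5)
    first_tap = math.floor(center - support + 0.5)
    taps = []
    for t in range(math.ceil(support) * 2 + 1):
        pos = first_tap + t
        weight = max(1.0 - abs((pos - center + 0.5) / eff), 0.0)
        if 0 <= pos < in_size and weight > 0.0:    # off-image taps contribute nothing
            taps.append((pos, weight))
    total = sum(w for _, w in taps)
    return [(pos, w / total) for pos, w in taps]   # renormalize

def resize_window_2d(img, out_h, out_w):
    """Fused style: each output pixel reads its full 2-D window."""
    rows_taps = [axis_taps(oy, len(img), out_h) for oy in range(out_h)]
    cols_taps = [axis_taps(ox, len(img[0]), out_w) for ox in range(out_w)]
    return [[sum(wy * wx * img[y][x] for y, wy in ty for x, wx in tx)
             for tx in cols_taps] for ty in rows_taps]

def resize_two_pass(img, out_h, out_w):
    """Separable style: width pass into an intermediate, then height pass."""
    rows_taps = [axis_taps(oy, len(img), out_h) for oy in range(out_h)]
    cols_taps = [axis_taps(ox, len(img[0]), out_w) for ox in range(out_w)]
    # Pass 1: the intermediate keeps every input row but has out_w columns.
    inter = [[sum(wx * row[x] for x, wx in tx) for tx in cols_taps] for row in img]
    # Pass 2: height only, reading the intermediate.
    return [[sum(wy * inter[y][ox] for y, wy in ty) for ox in range(out_w)]
            for ty in rows_taps]

img = [[float((7 * y + 3 * x) % 11) for x in range(12)] for y in range(9)]
a = resize_window_2d(img, 3, 4)
b = resize_two_pass(img, 3, 4)
max_diff = max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))
ty, tx = axis_taps(1, 9, 3), axis_taps(1, 12, 4)
print(max_diff < 1e-9, len(ty) * len(tx), len(ty) + len(tx))  # True 25 10
```

For this interior pixel the window form reads 25 values where the two passes read 5 each, which is the same product-versus-sum gap as 169 vs 26 at 13 taps per axis.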
Parity note: the intermediate here is **float32**; torchvision keeps a **fixed-point
uint8** intermediate. So the separable output is parity-*close* to torchvision, not
bit-identical — and the float version is actually the more accurate one.
---
## 7. Public API (`__init__.py`)
```
resize_normalize(images, size, image_mean, image_std,
                 rescale_factor=1/255, resample="bilinear", antialias=False,
                 backend="separable", block=256)
```
- `images`: stacked `(N,C,H,W)` tensor or a list of CHW tensors.
- `size`: int (square), `(H,W)`, or `{"height","width"}`.
- `resample`: `"bilinear"`/`"bicubic"`, or a PIL resample int (0/2→bilinear, 3→bicubic).
- `backend`: `"separable"` (default, fastest) or `"fused"` (2-D reference).
- `resize_normalize_ragged`: same kernels, list-only.
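The accepted argument forms listed above can be made concrete with a small normalizer. This is only an illustration of those forms, not the package's own parsing code; `parse_size` and `parse_resample` are hypothetical helpers.

```python
def parse_size(size):
    """Accept an int (square), an (H, W) pair, or a {"height", "width"} dict."""
    if isinstance(size, int):
        return size, size
    if isinstance(size, dict):
        return int(size["height"]), int(size["width"])
    height, width = size
    return int(height), int(width)

def parse_resample(resample):
    """Accept a filter name or a PIL resample int, using the mapping listed above."""
    if isinstance(resample, str):
        if resample not in ("bilinear", "bicubic"):
            raise ValueError(f"unsupported resample: {resample!r}")
        return resample
    pil_ints = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
    if resample not in pil_ints:
        raise ValueError(f"unsupported PIL resample int: {resample!r}")
    return pil_ints[resample]

print(parse_size(384))                              # (384, 384)
print(parse_size({"height": 224, "width": 256}))    # (224, 256)
print(parse_resample(3))                            # bicubic
```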
---
## 8. Benchmark findings (A100, CUDA_VISIBLE_DEVICES=1)
### Standalone resize+normalize — SigLIP-so400m config, N=32 ragged 384–1024², out 384², bicubic+AA
```
torchvision eager loop  :  2.91 ms   (per-image float loop)
torchvision compiled    :  5.70 ms   (torch.compile dynamic, per-image; slower than eager)
torchvision compiled pkt:  2.55 ms   (one graph over a padded stack; timing only)
fused triton (2D)       : 11.49 ms   (taps*taps; the slow reference)
separable triton (uint8):  1.29 ms   (taps+taps)   <-- fastest
real processor          :  3.92 ms
```
**Separable is ~3× faster than the real processor** (1.29 ms vs 3.92 ms), with parity ≤1e-4 vs torchvision-float. The fused 2-D
kernel loses for the algorithmic reason above (169 vs 26 loads). `torch.compile` does not help:
per-image it is *slower* than eager (dispatch overhead over 32 ragged shapes), and even as one packed
graph it is only slightly ahead of the eager loop (2.55 ms vs 2.91 ms), because inductor's interpolate
is no faster than aten resize.
### End-to-end inference — Siglip2-base-patch16-224, **bf16** forward
```
(all times in ms)
            preprocess   forward(fixed input)   preprocess+forward
processor      3.99            12.86                  14.44
separable      0.93            13.02                  13.76    <-- ~5% faster e2e
fused          2.00            13.01                  14.79
compiled       6.14            12.89                  14.00
feature parity (separable/fused/compiled vs processor): 9.38e-2 = 1.2% of feature max
```
- `forward(fixed input)` is identical (~12.9 ms) for all → **no inference regression**; the
model does not care which preprocessor made the tensor.
- The 1.2% feature drift is the float-vs-uint8 resize difference, identical across all
float backends → not a bug. The float path is the more accurate one.
- End-to-end win is ~5% with a bf16 forward (was ~0.5% with fp32, where the forward was
~80 ms). **The win scales with how preprocessing-bound you are.**
### Data path from JPEG bytes — 552 KB/img
```
CPU decode + torchvision resize : 177.5 ms   (status quo)
CPU decode + separable kernel   : 176.4 ms   (kernel saves ~1 ms; decode dominates)
GPU decode (nvJPEG) + kernel    :  14.8 ms   (fully on-GPU)
```
- ~175 ms of the 177 ms is **CPU JPEG decode + host→device copy**. Resize/normalize is ~1%.
- The 12× win (177→15) is **GPU decode (nvJPEG)**, i.e. `torchvision.io.decode_jpeg(device="cuda")`,
  *not* the kernel. The kernel is the resize/normalize component of that GPU pipeline.
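The ratios quoted in this section follow directly from the tables. A small arithmetic sketch (numbers copied from the tables above; `speedup` is a hypothetical helper) shows how the stage-level win shrinks once a larger fixed cost sits next to it:

```python
def speedup(before_ms, after_ms):
    """How many times faster `after_ms` is than `before_ms`."""
    return before_ms / after_ms

# Standalone resize+normalize stage: real processor vs separable kernel.
stage = speedup(3.92, 1.29)
# End-to-end preprocess+forward with a bf16 forward: processor vs separable.
e2e_saving = (14.44 - 13.76) / 14.44
# JPEG data path: swapping only the resize kernel vs also moving decode to the GPU.
kernel_only = speedup(177.5, 176.4)
gpu_decode = speedup(177.5, 14.8)

print(f"stage speedup      : {stage:.2f}x")        # 3.04x
print(f"end-to-end saving  : {e2e_saving:.1%}")    # 4.7%
print(f"kernel-only speedup: {kernel_only:.3f}x")  # 1.006x
print(f"GPU-decode speedup : {gpu_decode:.1f}x")   # 12.0x
```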
---
## 9. What is true / what to claim
- The kernel is **correct** (≤1e-4 vs torchvision-float, more accurate than the processor's
uint8 path) and feeds the model with **no inference regression**.
- It is **~3× faster than the real processor at the resize/normalize stage**, a real, parity-clean win.
- It does **not** speed up preprocessing 12×. Decode dominates the data path; the GPU-decode
lever is nvJPEG, a torchvision feature, not this kernel.
- The kernel matters end-to-end only once you are **not decode-bound**: in a GPU-decode
pipeline it keeps resize/normalize minimal (~10% of that pipeline), and its standalone
preprocess win shows up when the forward is small (bf16, small model, large batch).
- Honest one-liner: *"GPU-native resize+normalize, 3× the fast processor at that stage,
drop-in for a GPU-decode pipeline."*