# kernel_image_resize — how every op works (study notes)
This explains the whole package end to end: the resampling math, the data layout, and
every kernel op, plus the benchmark findings. It is meant for pen & paper — there is a
fully worked numeric example you can reproduce by hand in the "Worked example" section.
The package does one thing: **resize + rescale + normalize**, the op sequence a
`transformers` fast image processor runs (`TorchvisionBackend`: resize, then
`(x*rescale - mean)/std`), as a single GPU pipeline. Input: raw CHW `uint8` images
(any size, ragged). Output: `(N, C, out_h, out_w)` normalized `float32`.
---
## 1. The one idea behind resizing
Resizing does not "pick" pixels; each **output** pixel is a **weighted average of a small
window of input pixels**. Two things define that average:
1. **Where** in the input an output pixel lands (its center).
2. **Which** input pixels are in its window, and with **what weights**.
The weight of an input pixel falls off with distance from the center, following a filter
curve:
- **bilinear** → a triangle, width 1 on each side (so 2 input pixels per axis normally).
- **bicubic** → a cubic bump, width 2 on each side (so 4 input pixels per axis normally).
When you **shrink** an image, you must also **blur first** or you get aliasing. That is
what `antialias=True` does: it widens the window so each output pixel averages more input
pixels (a low-pass filter before throwing pixels away). Widening is proportional to the
shrink factor, so shrinking 3× turns a 4-tap bicubic into ~13 taps.
---
## 2. The resampling-weight formula (the heart of everything)
All kernels use the same formula, which matches PyTorch's aten `UpSampleKernel`
(`align_corners=False`, "half-pixel" convention). For one axis:
```
scale       = in_size / out_size                        # > 1 means shrinking
interp_half = 1 (bilinear) or 2 (bicubic)               # half-width of the filter
cubic_a     = -0.75 (no antialias) or -0.5 (antialias)  # the cubic curve's shape constant
# antialias only widens the window, and only when shrinking:
if antialias and scale > 1:
    eff = scale                                         # window widens by the shrink factor
else:
    eff = 1                                             # plain 2-tap / 4-tap interpolation
support = interp_half * eff                             # half-window width, in INPUT pixels
inv     = 1 / eff                                       # squashes the filter curve to match the widened window
# for output index i:
center    = scale * (i + 0.5)                           # input coordinate this output maps to
first_tap = floor(center - support + 0.5)               # leftmost input pixel in the window
# for each tap t = 0, 1, 2, ... (up to MAX_TAPS):
    tap_pos = first_tap + t                             # an input pixel index
    arg     = (tap_pos - center + 0.5) * inv            # distance from center, squashed
    weight  = filter(arg)                               # triangle or cubic, see below
```
The filter (`_resample_weight` in `_fused.py`), with `x = |arg|`:
```
bilinear: max(1 - x, 0)                         # triangle, zero past 1
bicubic:  x <= 1    : (a+2)x^3 - (a+3)x^2 + 1
          1 < x < 2 : a x^3 - 5a x^2 + 8a x - 4a
          else      : 0
```
Two edge rules (both kernels do this identically):
- **non-antialias**: clamp the tap index into `[0, in_size-1]` → replicates the border
pixel. The filter weights of a standard 2/4-tap interpolation already sum to 1.
- **antialias**: instead set the weight to **0** for taps that fall off the image
(`tap_pos < 0` or `>= in_size`), then **renormalize** by dividing by the sum of the
realized weights. This keeps the average correct at the edges.
That renormalization is why every kernel computes a `weight_sum` and divides by it. For
the non-antialias case `weight_sum == 1`, so the division is a harmless no-op.
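To check the formula and both edge rules by machine as well as by hand, here is a minimal pure-Python sketch for one output index on one axis. The names `resample_filter` and `axis_taps` are illustrative, not taken from the package, and the tap count assumes the `ceil(support) * 2 + 1` bound described in section 4.

```python
import math

def resample_filter(arg, interp, a):
    """Triangle (bilinear) or cubic (bicubic) filter at a signed distance."""
    x = abs(arg)
    if interp == "bilinear":
        return max(1.0 - x, 0.0)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0

def axis_taps(i, in_size, out_size, interp="bilinear", antialias=False):
    """Return ([(input_index, normalized_weight), ...], raw_weight_sum)."""
    scale = in_size / out_size
    interp_half = 1.0 if interp == "bilinear" else 2.0
    a = -0.5 if antialias else -0.75
    eff = scale if (antialias and scale > 1.0) else 1.0
    support = interp_half * eff
    inv = 1.0 / eff
    center = scale * (i + 0.5)
    first_tap = math.floor(center - support + 0.5)
    num_taps = math.ceil(support) * 2 + 1   # assumed fixed loop bound
    taps = []
    for t in range(num_taps):
        tap_pos = first_tap + t
        weight = resample_filter((tap_pos - center + 0.5) * inv, interp, a)
        # antialias rule: off-image taps contribute nothing
        if antialias and (tap_pos < 0 or tap_pos >= in_size):
            weight = 0.0
        # non-antialias rule: clamp replicates the border pixel
        index = min(max(tap_pos, 0), in_size - 1)
        taps.append((index, weight))
    weight_sum = sum(w for _, w in taps)
    return [(index, w / weight_sum) for index, w in taps], weight_sum
```

For example, `axis_taps(0, 4, 2, "bilinear", True)` drops the off-image tap at index -1 and renormalizes the remaining weights so they sum to 1.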
---
## 3. Worked example (do this by hand)
**bilinear, no antialias, one axis, in_size=4, out_size=2.**
```
scale = 4/2 = 2, interp_half = 1, eff = 1, support = 1, inv = 1
```
Output pixel `i = 0`:
```
center = 2 * (0 + 0.5) = 1.0
first_tap = floor(1.0 - 1 + 0.5) = floor(0.5) = 0
t=0: tap_pos=0, arg=(0-1.0+0.5)= -0.5 -> weight = 1-0.5 = 0.5
t=1: tap_pos=1, arg=(1-1.0+0.5)= 0.5 -> weight = 1-0.5 = 0.5
t=2: tap_pos=2, arg=(2-1.0+0.5)= 1.5 -> weight = max(1-1.5,0) = 0
weight_sum = 1.0
output[0] = (0.5*in[0] + 0.5*in[1]) / 1.0 # halfway between in[0] and in[1]
```
Output pixel `i = 1`:
```
center = 2 * 1.5 = 3.0
first_tap = floor(3.0 - 1 + 0.5) = floor(2.5) = 2
t=0: tap_pos=2, arg=-0.5 -> 0.5
t=1: tap_pos=3, arg= 0.5 -> 0.5
t=2: tap_pos=4, arg= 1.5 -> 0 (index 4 would clamp to 3, but weight is 0 anyway)
output[1] = 0.5*in[2] + 0.5*in[3]
```
This 1-D operation is exactly one pass of the separable kernel. The 2-D result is the same
formula applied on both axes (rows and columns).
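The hand calculation can be cross-checked with a short pure-Python sketch. It hard-codes the bilinear, no-antialias case from this section (support 1, three loop taps); `bilinear_resize_1d` and `bilinear_resize_2d` are illustrative names, not package functions.

```python
import math

def bilinear_resize_1d(values, out_size):
    """Bilinear, no antialias, one axis, following the section 2 formula."""
    in_size = len(values)
    scale = in_size / out_size
    out = []
    for i in range(out_size):
        center = scale * (i + 0.5)
        first_tap = math.floor(center - 1 + 0.5)
        acc = 0.0
        weight_sum = 0.0
        for t in range(3):
            tap_pos = first_tap + t
            weight = max(1.0 - abs(tap_pos - center + 0.5), 0.0)
            pixel = values[min(max(tap_pos, 0), in_size - 1)]  # clamp = replicate border
            acc += weight * pixel
            weight_sum += weight
        out.append(acc / weight_sum)
    return out

def bilinear_resize_2d(image, out_h, out_w):
    """Same 1-D pass applied to every row, then to every column."""
    rows = [bilinear_resize_1d(row, out_w) for row in image]
    cols = [bilinear_resize_1d(col, out_h) for col in zip(*rows)]
    return [list(row) for row in zip(*cols)]

print(bilinear_resize_1d([10.0, 20.0, 30.0, 40.0], 2))  # [15.0, 35.0]
```

With `in = [10, 20, 30, 40]` the two outputs are the midpoints of `in[0], in[1]` and `in[2], in[3]`, matching the tap tables above.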
---
## 4. Data layout (host side, `_pack.py`)
Ragged images (different H×W) cannot be stacked into one tensor, so they are flattened and
concatenated into one buffer, with side tables describing each image.
`pack_images(images, dtype)` →
```
input_pixels : 1-D buffer, all images flattened (C,H,W row-major) and concatenated
offsets[n] : element index where image n starts
heights[n], widths[n] : that image's H and W
channels : C (shared by all images)
```
Address of input pixel `(channel, row, col)` of image `n`:
```
input_pixels[ offsets[n] + channel*(H*W) + row*W + col ]
```
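The packing scheme and the address formula can be sketched with plain nested lists. This is an illustration of the layout only: the real `pack_images` works on tensors and takes a `dtype`, and the names `pack_images_sketch` and `load_pixel` are hypothetical.

```python
def pack_images_sketch(images):
    """images: list of nested lists shaped [C][H][W]; sizes may differ per image."""
    buffer, offsets, heights, widths = [], [], [], []
    channels = len(images[0])
    for image in images:
        assert len(image) == channels, "all images must share C"
        offsets.append(len(buffer))          # element index where this image starts
        heights.append(len(image[0]))
        widths.append(len(image[0][0]))
        for plane in image:                  # channel-major
            for row in plane:                # then row-major inside a channel
                buffer.extend(row)
    return buffer, offsets, heights, widths, channels

def load_pixel(buffer, offsets, heights, widths, n, channel, row, col):
    """Address formula: offsets[n] + channel*(H*W) + row*W + col."""
    h, w = heights[n], widths[n]
    return buffer[offsets[n] + channel * (h * w) + row * w + col]
```

Each image keeps its own `H` and `W`, so the same formula works for every entry of a ragged batch.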
The separable path packs as `uint8` (1 byte/pixel, a quarter of the memory traffic of `float32`).
`fold_mean_std(mean, std, rescale)` → folds the rescale factor into the normalization
constants so the kernel does a single `(x - m)/s`:
```
m = mean / rescale s = std / rescale
(x - m)/s == (x*rescale - mean)/std # identical to the processor's fused normalize
```
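The identity is easy to confirm numerically. A minimal sketch, assuming per-channel lists for `mean` and `std`; `fold_mean_std_sketch` and the two helper functions are illustrative names, not the package's code.

```python
def fold_mean_std_sketch(mean, std, rescale):
    """Fold the rescale factor into the normalization constants."""
    m = [value / rescale for value in mean]
    s = [value / rescale for value in std]
    return m, s

def normalize_reference(x, mean, std, rescale):
    """What the processor computes: rescale, then normalize."""
    return (x * rescale - mean) / std

def normalize_folded(x, m, s):
    """What the kernel computes: one subtract and one divide on the raw value."""
    return (x - m) / s
```

With `rescale = 1/255`, `mean = 0.5` and `std = 0.5`, the folded constants are `m = s = 127.5`, so a raw value of 255 maps to 1.0 and 0 maps to -1.0.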
`max_taps(images, out_size, axis, interp, antialias)` → the **widest** window in the batch
= `ceil(support) * 2 + 1`. A Triton loop bound must be a compile-time constant, so every
program loops this fixed count; taps beyond a given pixel's real window get ~0 weight.
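The tap bound follows directly from the section 2 definitions. A sketch under simplifying assumptions: it takes the input sizes along one axis as plain integers instead of the `images` and `axis` arguments of the real `max_taps`, and the name `max_taps_sketch` is hypothetical.

```python
import math

def max_taps_sketch(in_sizes, out_size, interp="bicubic", antialias=True):
    """Widest window over a batch, for one axis: ceil(support) * 2 + 1."""
    interp_half = 1 if interp == "bilinear" else 2
    widest = 0
    for in_size in in_sizes:
        scale = in_size / out_size
        eff = scale if (antialias and scale > 1) else 1.0  # widen only when shrinking
        support = interp_half * eff
        widest = max(widest, math.ceil(support) * 2 + 1)
    return widest

print(max_taps_sketch([1152], 384))  # 13
```

A 3× shrink with bicubic and antialias gives `support = 6`, hence the 13 taps quoted in section 1; without antialias the same call returns 5.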
`as_image_list` → accepts a stacked `(N,C,H,W)` tensor or a list, always returns a list.
---
## 5. Fused kernel (`_fused.py`, `backend="fused"`)
One launch. **One program = one image + a BLOCK of its output pixels.** Each output pixel
reads the **full 2-D window** directly: `MAX_TAPS_H × MAX_TAPS_W` input pixels.
```
grid = (num_images, ceil(out_h*out_w / BLOCK))
per lane (one output pixel):
    oy, ox = (flat_index // out_w, flat_index % out_w)
    center_y, center_x, first_tap_y, first_tap_x    # section 2, both axes
    # weight_sum factorizes across axes (separable math, even though the LOADS are 2-D):
    sum_wy = Σ_ty filter_y ; sum_wx = Σ_tx filter_x ; denom = sum_wy * sum_wx
    for channel:
        acc = 0
        for ty in 0..MAX_TAPS_H:                    # <-- the 2-D window: TAPS_H * TAPS_W loads
            for tx in 0..MAX_TAPS_W:
                weight = filter_y(ty) * filter_x(tx)
                pixel  = input_pixels[channel, clamp(tap_y), clamp(tap_x)]
                acc   += weight * pixel
        acc = acc / denom
        out[image, channel, oy, ox] = (acc - mean[channel]) / std[channel]
```
Cost per output pixel: `TAPS_H * TAPS_W` loads (e.g. 13×13 = **169**). Correct and simple,
but the 2-D load count is what makes it slow — hence the separable version.
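The grid shape and the flat-index split can be sketched on the host side. This only models the indexing; it assumes lanes past the last output pixel are masked out, which is the usual Triton pattern but is not stated in these notes. `fused_grid` and `active_lanes` are illustrative names.

```python
import math

def fused_grid(num_images, out_h, out_w, block=256):
    """(images, blocks of output pixels per image)."""
    return (num_images, math.ceil(out_h * out_w / block))

def active_lanes(block_id, out_h, out_w, block=256):
    """(oy, ox) for each lane of one program that maps to a real output pixel."""
    total = out_h * out_w
    coords = []
    for lane in range(block):
        flat_index = block_id * block + lane
        if flat_index < total:               # assumed mask for the ragged last block
            coords.append((flat_index // out_w, flat_index % out_w))
    return coords

print(fused_grid(32, 384, 384))  # (32, 576)
```

For the benchmark configuration (N=32, 384×384 output, `block=256`) that is 576 programs per image, each lane doing `TAPS_H * TAPS_W` loads per channel.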
---
## 6. Separable kernel (`_separable.py`, `backend="separable"`, the default)
Same math, but the 2-D window is done as **two 1-D passes**, with a float intermediate
buffer in between. Each lane loads only a 1-D window: `TAPS_W` in pass 1 and `TAPS_H` in
pass 2 (e.g. 13+13 = **26**, versus 169 for the fused kernel).
```
input_pixels (uint8, C×H×W) --pass1--> intermediate (float, C×H×out_w) --pass2--> output (float, C×out_h×out_w)
                            resize W                                   resize H + normalize
```
The **intermediate** is the key object: same **height** as the input, but already the
**final width**. ("Tall and narrow.") It is also ragged in height, so it gets its own
offset table (built in `separable_resize_normalize`, same scheme as `pack_images`).
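The intermediate's offset table can be sketched as a running sum of `C * H * out_w` per image. This is an assumption about how `separable_resize_normalize` lays it out, based only on the "same scheme as `pack_images`" remark; both function names are hypothetical.

```python
def intermediate_offsets(heights, channels, out_w):
    """Start index of each image's (C, H, out_w) block in the flat intermediate."""
    offsets, total = [], 0
    for h in heights:
        offsets.append(total)
        total += channels * h * out_w        # height is still ragged, width is final
    return offsets, total

def intermediate_index(offsets, heights, out_w, n, channel, row, out_col):
    """Same addressing as the input buffer, with W replaced by out_w."""
    h = heights[n]
    return offsets[n] + channel * (h * out_w) + row * out_w + out_col
```

For two images of height 2 and 3 with `C=3` and `out_w=4`, the offsets are `[0, 24]` and the buffer holds 60 floats.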
### Pass 1 — `_horizontal_resize_kernel` (resize width only)
```
grid = (num_images, ceil(H*out_w / BLOCK))      # work = every input row × every output col
per lane:
    input_row = flat_index // out_w             # row index, UNCHANGED by this pass
    out_col   = flat_index %  out_w             # output column being computed
    center_x, first_tap_x, col_weight_sum       # section 2, COLUMN axis only
    for channel:
        acc = 0
        for tap in 0..MAX_TAPS_COL:             # <-- 1-D: only TAPS_W loads
            weight = filter_x(tap)
            pixel  = input_pixels[channel, input_row, clamp(tap_col)]   # uint8 -> float
            acc   += weight * pixel
        acc = acc / col_weight_sum
        intermediate[channel, input_row, out_col] = acc                 # NO normalize yet
```
Reads original `uint8` bytes; writes `float32`. No normalization here.
### Pass 2 — `_vertical_resize_normalize_kernel` (resize height, then normalize)
```
grid = (num_images, ceil(out_h*out_w / BLOCK))  # work = every output pixel
per lane:
    out_row = flat_index // out_w
    out_col = flat_index %  out_w
    center_y, first_tap_y, row_weight_sum       # section 2, ROW axis only
    for channel:
        acc = 0
        for tap in 0..MAX_TAPS_ROW:             # <-- 1-D: only TAPS_H loads
            weight = filter_y(tap)
            pixel  = intermediate[channel, clamp(tap_row), out_col]     # float
            acc   += weight * pixel
        acc = acc / row_weight_sum
        out[image, channel, out_row, out_col] = (acc - mean[channel]) / std[channel]   # normalize here
```
Two launches (an implicit sync between them), so pass 2 always sees pass 1's finished
output.
### Why separable wins
`TAPS_W + TAPS_H` loads instead of `TAPS_W * TAPS_H`. For a 13×13 window that is 26 vs 169.
This is exactly the algorithm PIL and torchvision use. The catch: an extra full-size float
intermediate buffer (more memory traffic), but the read-count reduction dominates.
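That the two passes give the same numbers as the 2-D window follows from the weights being a product `wy * wx`. A small pure-Python check for one channel, taking already-normalized tap lists as input; `resize_2d_window` and `resize_two_pass` are illustrative names, not the Triton kernels.

```python
def resize_2d_window(image, taps_y, taps_x):
    """Fused style: each output pixel sums its full 2-D window."""
    out = []
    for ty in taps_y:                        # one tap list per output row
        out_row = []
        for tx in taps_x:                    # one tap list per output column
            acc = 0.0
            for r, wy in ty:
                for c, wx in tx:
                    acc += wy * wx * image[r][c]
            out_row.append(acc)
        out.append(out_row)
    return out

def resize_two_pass(image, taps_y, taps_x):
    """Separable style: resize width into an intermediate, then resize height."""
    intermediate = [[sum(wx * row[c] for c, wx in tx) for tx in taps_x]
                    for row in image]        # H x out_w, "tall and narrow"
    return [[sum(wy * intermediate[r][ox] for r, wy in ty)
             for ox in range(len(taps_x))]
            for ty in taps_y]                # out_h x out_w
```

Both functions return the same values up to floating-point rounding, because `Σ_y Σ_x wy*wx*p[y][x] = Σ_y wy * (Σ_x wx*p[y][x])`.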
Parity note: the intermediate here is **float32**; torchvision keeps a **fixed-point
uint8** intermediate. So the separable output is parity-*close* to torchvision, not
bit-identical — and the float version is actually the more accurate one.
---
## 7. Public API (`__init__.py`)
```
resize_normalize(images, size, image_mean, image_std,
rescale_factor=1/255, resample="bilinear", antialias=False,
backend="separable", block=256)
```
- `images`: stacked `(N,C,H,W)` tensor or a list of CHW tensors.
- `size`: int (square), `(H,W)`, or `{"height","width"}`.
- `resample`: `"bilinear"`/`"bicubic"`, or a PIL resample int (0/2→bilinear, 3→bicubic).
- `backend`: `"separable"` (default, fastest) or `"fused"` (2-D reference).
- `resize_normalize_ragged`: same kernels, list-only.
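The accepted forms of `size` and `resample` can be illustrated with two small helpers. These are hypothetical sketches of the argument handling described in the list above, not the package's own code; rejecting unknown values with `ValueError` is an assumption.

```python
def normalize_size(size):
    """int (square), (H, W), or {"height", "width"} -> (out_h, out_w)."""
    if isinstance(size, int):
        return size, size
    if isinstance(size, dict):
        return size["height"], size["width"]
    out_h, out_w = size
    return out_h, out_w

def normalize_resample(resample):
    """"bilinear"/"bicubic" pass through; PIL ints 0/2 -> bilinear, 3 -> bicubic."""
    if isinstance(resample, str):
        if resample not in ("bilinear", "bicubic"):
            raise ValueError(f"unsupported resample: {resample!r}")
        return resample
    pil_to_name = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
    if resample not in pil_to_name:
        raise ValueError(f"unsupported PIL resample int: {resample!r}")
    return pil_to_name[resample]
```

So `size=384`, `size=(384, 384)` and `size={"height": 384, "width": 384}` would all describe the same output shape.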
---
## 8. Benchmark findings (A100, CUDA_VISIBLE_DEVICES=1)
### Standalone resize+normalize — SigLIP-so400m config, N=32 ragged 384–1024², out 384², bicubic+AA
```
torchvision eager loop   :  2.91 ms   (per-image float loop)
torchvision compiled     :  5.70 ms   (torch.compile dynamic, per-image; slower than eager)
torchvision compiled pkt :  2.55 ms   (one graph over a padded stack; timing only)
fused triton (2D)        : 11.49 ms   (taps*taps; the slow reference)
separable triton (uint8) :  1.29 ms   (taps+taps) <-- fastest
real processor           :  3.92 ms
```
**Separable is ~3× faster than the real processor** (1.29 ms vs 3.92 ms), with parity ≤1e-4
vs torchvision-float. The fused 2-D kernel loses for the algorithmic reason above (169 vs 26
loads). `torch.compile` does not help: per-image it is *slower* than eager (dispatch overhead
over 32 ragged shapes), and even as one packed graph it only roughly matches the eager loop
(2.55 ms vs 2.91 ms), because inductor's interpolate is no faster than aten resize.
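The ratios quoted here are plain arithmetic on the table above. A small sketch that recomputes them from the reported timings; the dictionary keys are shorthand labels chosen for this example.

```python
# Timings in ms, copied from the standalone benchmark table above.
timings_ms = {
    "eager": 2.91,
    "compiled": 5.70,
    "compiled_packed": 2.55,
    "fused": 11.49,
    "separable": 1.29,
    "processor": 3.92,
}

def speedup(baseline, candidate):
    """How many times faster `candidate` is than `baseline`."""
    return timings_ms[baseline] / timings_ms[candidate]

for name in ("processor", "eager", "fused"):
    print(f"separable vs {name}: {speedup(name, 'separable'):.2f}x")
```

The separable kernel comes out at about 3.0× the real processor, about 2.3× the eager torchvision loop, and about 8.9× the fused kernel.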
### End-to-end inference — Siglip2-base-patch16-224, **bf16** forward
```
(ms)        preprocess   forward(fixed input)   preprocess+forward
processor      3.99            12.86                 14.44
separable      0.93            13.02                 13.76   <-- ~5% faster e2e
fused          2.00            13.01                 14.79
compiled       6.14            12.89                 14.00
feature parity (separable/fused/compiled vs processor): 9.38e-2 = 1.2% of feature max
```
- `forward(fixed input)` is identical (~12.9 ms) for all → **no inference regression**; the
model does not care which preprocessor made the tensor.
- The 1.2% feature drift is the float-vs-uint8 resize difference, identical across all
float backends → not a bug. The float path is the more accurate one.
- End-to-end win is ~5% with a bf16 forward (was ~0.5% with fp32, where the forward was
~80 ms). **The win scales with how preprocessing-bound you are.**
### Data path from JPEG bytes — 552 KB/img
```
CPU decode + torchvision resize : 177.5 ms   (status quo)
CPU decode + separable kernel   : 176.4 ms   (kernel saves ~1 ms; decode dominates)
GPU decode (nvJPEG) + kernel    :  14.8 ms   (fully on-GPU)
```
- ~175 ms of the 177 ms is **CPU JPEG decode + host→device copy**. Resize/normalize is ~1%.
- The 12× win (177→15) is **GPU decode (nvJPEG)**, i.e. `torchvision.io.decode_jpeg(device="cuda")`
— *not* the kernel. The kernel is the resize/normalize component of that GPU pipeline.
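The two headline numbers for the data path can be recomputed from the table above. A short sketch; the variable names are chosen for this example only.

```python
# Timings in ms, copied from the JPEG data-path table above.
cpu_decode_torchvision_ms = 177.5
cpu_decode_kernel_ms = 176.4
gpu_decode_kernel_ms = 14.8

# Swapping only the resize/normalize stage while decode stays on the CPU.
kernel_saving_ms = cpu_decode_torchvision_ms - cpu_decode_kernel_ms
kernel_saving_share = kernel_saving_ms / cpu_decode_torchvision_ms

# Moving decode to the GPU as well.
gpu_pipeline_speedup = cpu_decode_torchvision_ms / gpu_decode_kernel_ms

print(f"kernel alone saves {kernel_saving_ms:.1f} ms ({kernel_saving_share:.1%})")
print(f"GPU decode pipeline is {gpu_pipeline_speedup:.1f}x faster")
```

The kernel swap alone recovers about 1.1 ms (well under 1% of the path), while the GPU-decode pipeline is about 12× faster overall.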
---
## 9. What is true / what to claim
- The kernel is **correct** (≤1e-4 vs torchvision-float, more accurate than the processor's
uint8 path) and feeds the model with **no inference regression**.
- It is **~3× faster than the real processor at the resize/normalize stage**, a real, parity-clean win.
- It does **not** speed up preprocessing 12×. Decode dominates the data path; the GPU-decode
lever is nvJPEG, a torchvision feature, not this kernel.
- The kernel matters end-to-end only once you are **not decode-bound**: in a GPU-decode
pipeline it keeps resize/normalize minimal (~10% of that pipeline), and its standalone
preprocess win shows up when the forward is small (bf16, small model, large batch).
- Honest one-liner: *"GPU-native resize+normalize, 3× faster than the fast processor at that
stage, drop-in for a GPU-decode pipeline."*