Molbap HF Staff commited on
Commit
e199518
·
verified ·
1 Parent(s): b26bdad

Upload folder using huggingface_hub

Browse files
KERNEL_OPS.md ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # kernel_image_resize — how every op works (study notes)
2
+
3
+ This explains the whole package end to end: the resampling math, the data layout, and
4
+ every kernel op, plus the benchmark findings. It is meant for pen & paper — there is a
5
+ fully worked numeric example you can reproduce by hand in the "Worked example" section.
6
+
7
+ The package does one thing: **resize + rescale + normalize**, the op sequence a
8
+ `transformers` fast image processor runs (`TorchvisionBackend`: resize, then
9
+ `(x*rescale - mean)/std`), as a single GPU pipeline. Input: raw CHW `uint8` images
10
+ (any size, ragged). Output: `(N, C, out_h, out_w)` normalized `float32`.
11
+
12
+ ---
13
+
14
+ ## 1. The one idea behind resizing
15
+
16
+ Resizing does not "pick" pixels; each **output** pixel is a **weighted average of a small
17
+ window of input pixels**. Two things define that average:
18
+
19
+ 1. **Where** in the input an output pixel lands (its center).
20
+ 2. **Which** input pixels are in its window, and with **what weights**.
21
+
22
+ The weight of an input pixel falls off with distance from the center, following a filter
23
+ curve:
24
+
25
+ - **bilinear** → a triangle, width 1 on each side (so 2 input pixels per axis normally).
26
+ - **bicubic** → a cubic bump, width 2 on each side (so 4 input pixels per axis normally).
27
+
28
+ When you **shrink** an image, you must also **blur first** or you get aliasing. That is
29
+ what `antialias=True` does: it widens the window so each output pixel averages more input
30
+ pixels (a low-pass filter before throwing pixels away). Widening is proportional to the
31
+ shrink factor, so shrinking 3× turns a 4-tap bicubic into ~13 taps.
32
+
33
+ ---
34
+
35
+ ## 2. The resampling-weight formula (the heart of everything)
36
+
37
+ All kernels use the same formula, which matches PyTorch's aten `UpSampleKernel`
38
+ (`align_corners=False`, "half-pixel" convention). For one axis:
39
+
40
+ ```
41
+ scale = in_size / out_size # > 1 means shrinking
42
+ interp_half = 1 (bilinear) or 2 (bicubic) # half-width of the filter
43
+ cubic_a = -0.75 (no antialias) or -0.5 (antialias) # the cubic curve's shape constant
44
+
45
+ # antialias only widens the window, and only when shrinking:
46
+ if antialias and scale > 1:
47
+ eff = scale # window widens by the shrink factor
48
+ else:
49
+ eff = 1 # plain 2-tap / 4-tap interpolation
50
+ support = interp_half * eff # half-window width, in INPUT pixels
51
+ inv = 1 / eff # squashes the filter curve to match the widened window
52
+
53
+ # for output index i:
54
+ center = scale * (i + 0.5) # input coordinate this output maps to
55
+ first_tap = floor(center - support + 0.5) # leftmost input pixel in the window
56
+
57
+ # for each tap t = 0, 1, 2, ... (up to MAX_TAPS):
58
+ tap_pos = first_tap + t # an input pixel index
59
+ arg = (tap_pos - center + 0.5) * inv # distance from center, squashed
60
+ weight = filter(arg) # triangle or cubic, see below
61
+ ```
62
+
63
+ The filter (`_resample_weight` in `_fused.py`), with `x = |arg|`:
64
+
65
+ ```
66
+ bilinear: max(1 - x, 0) # triangle, zero past 1
67
+
68
+ bicubic: x <= 1 : (a+2)x^3 - (a+3)x^2 + 1
69
+ 1<x<2 : a x^3 - 5a x^2 + 8a x - 4a
70
+ else : 0
71
+ ```
72
+
73
+ Two edge rules (both kernels do this identically):
74
+
75
+ - **non-antialias**: clamp the tap index into `[0, in_size-1]` → replicates the border
76
+ pixel. The filter weights of a standard 2/4-tap interpolation already sum to 1.
77
+ - **antialias**: instead set the weight to **0** for taps that fall off the image
78
+ (`tap_pos < 0` or `>= in_size`), then **renormalize** by dividing by the sum of the
79
+ realized weights. This keeps the average correct at the edges.
80
+
81
+ That renormalization is why every kernel computes a `weight_sum` and divides by it. For
82
+ the non-antialias case `weight_sum == 1`, so the division is a harmless no-op.
83
+
84
+ ---
85
+
86
+ ## 3. Worked example (do this by hand)
87
+
88
+ **bilinear, no antialias, one axis, in_size=4, out_size=2.**
89
+
90
+ ```
91
+ scale = 4/2 = 2, interp_half = 1, eff = 1, support = 1, inv = 1
92
+ ```
93
+
94
+ Output pixel `i = 0`:
95
+ ```
96
+ center = 2 * (0 + 0.5) = 1.0
97
+ first_tap = floor(1.0 - 1 + 0.5) = floor(0.5) = 0
98
+ t=0: tap_pos=0, arg=(0-1.0+0.5)= -0.5 -> weight = 1-0.5 = 0.5
99
+ t=1: tap_pos=1, arg=(1-1.0+0.5)= 0.5 -> weight = 1-0.5 = 0.5
100
+ t=2: tap_pos=2, arg=(2-1.0+0.5)= 1.5 -> weight = max(1-1.5,0) = 0
101
+ weight_sum = 1.0
102
+ output[0] = (0.5*in[0] + 0.5*in[1]) / 1.0 # halfway between in[0] and in[1]
103
+ ```
104
+
105
+ Output pixel `i = 1`:
106
+ ```
107
+ center = 2 * 1.5 = 3.0
108
+ first_tap = floor(3.0 - 0.5) = 2
109
+ t=0: tap_pos=2, arg=-0.5 -> 0.5
110
+ t=1: tap_pos=3, arg= 0.5 -> 0.5
111
+ t=2: tap_pos=4, arg= 1.5 -> 0 (index 4 would clamp to 3, but weight is 0 anyway)
112
+ output[1] = 0.5*in[2] + 0.5*in[3]
113
+ ```
114
+
115
+ This 1-D operation is exactly one pass of the separable kernel. The 2-D result is the same
116
+ formula applied on both axes (rows and columns).
117
+
118
+ ---
119
+
120
+ ## 4. Data layout (host side, `_pack.py`)
121
+
122
+ Ragged images (different H×W) cannot be stacked into one tensor, so they are flattened and
123
+ concatenated into one buffer, with side tables describing each image.
124
+
125
+ `pack_images(images, dtype)` →
126
+ ```
127
+ input_pixels : 1-D buffer, all images flattened (C,H,W row-major) and concatenated
128
+ offsets[n] : element index where image n starts
129
+ heights[n], widths[n] : that image's H and W
130
+ channels : C (shared by all images)
131
+ ```
132
+ Address of input pixel `(channel, row, col)` of image `n`:
133
+ ```
134
+ input_pixels[ offsets[n] + channel*(H*W) + row*W + col ]
135
+ ```
136
+ The separable path packs as `uint8` (1 byte/pixel, half the memory traffic of float).
137
+
138
+ `fold_mean_std(mean, std, rescale)` → folds the rescale factor into the normalization
139
+ constants so the kernel does a single `(x - m)/s`:
140
+ ```
141
+ m = mean / rescale s = std / rescale
142
+ (x - m)/s == (x*rescale - mean)/std # identical to the processor's fused normalize
143
+ ```
144
+
145
+ `max_taps(images, out_size, axis, interp, antialias)` → the **widest** window in the batch
146
+ = `ceil(support) * 2 + 1`. A Triton loop bound must be a compile-time constant, so every
147
+ program loops this fixed count; taps beyond a given pixel's real window get ~0 weight.
148
+
149
+ `as_image_list` → accepts a stacked `(N,C,H,W)` tensor or a list, always returns a list.
150
+
151
+ ---
152
+
153
+ ## 5. Fused kernel (`_fused.py`, `backend="fused"`)
154
+
155
+ One launch. **One program = one image + a BLOCK of its output pixels.** Each output pixel
156
+ reads the **full 2-D window** directly: `MAX_TAPS_H × MAX_TAPS_W` input pixels.
157
+
158
+ ```
159
+ grid = (num_images, ceil(out_h*out_w / BLOCK))
160
+
161
+ per lane (one output pixel):
162
+ oy, ox = (flat_index // out_w, flat_index % out_w)
163
+ center_y, center_x, first_tap_y, first_tap_x # section 2, both axes
164
+
165
+ # weight_sum factorizes across axes (separable math, even though the LOADS are 2-D):
166
+ sum_wy = Σ_ty filter_y ; sum_wx = Σ_tx filter_x ; denom = sum_wy * sum_wx
167
+
168
+ for channel:
169
+ acc = 0
170
+ for ty in 0..MAX_TAPS_H: # <-- the 2-D window: TAPS_H * TAPS_W loads
171
+ for tx in 0..MAX_TAPS_W:
172
+ weight = filter_y(ty) * filter_x(tx)
173
+ pixel = input_pixels[channel, clamp(tap_y), clamp(tap_x)]
174
+ acc += weight * pixel
175
+ acc = acc / denom
176
+ out[image, channel, oy, ox] = (acc - mean[channel]) / std[channel]
177
+ ```
178
+
179
+ Cost per output pixel: `TAPS_H * TAPS_W` loads (e.g. 13×13 = **169**). Correct and simple,
180
+ but the 2-D load count is what makes it slow — hence the separable version.
181
+
182
+ ---
183
+
184
+ ## 6. Separable kernel (`_separable.py`, `backend="separable"`, the default)
185
+
186
+ Same math, but the 2-D window is done as **two 1-D passes**, with a float intermediate
187
+ buffer in between. Loads per output pixel: `TAPS_W + TAPS_H` (e.g. 13+13 = **26**).
188
+
189
+ ```
190
+ input_pixels (uint8, C×H×W) --pass1--> intermediate (float, C×H×out_w) --pass2--> output (float, C×out_h×out_w)
191
+ resize W resize H + normalize
192
+ ```
193
+
194
+ The **intermediate** is the key object: same **height** as the input, but already the
195
+ **final width**. ("Tall and narrow.") It is also ragged in height, so it gets its own
196
+ offset table (built in `separable_resize_normalize`, same scheme as `pack_images`).
197
+
198
+ ### Pass 1 — `_horizontal_resize_kernel` (resize width only)
199
+
200
+ ```
201
+ grid = (num_images, ceil(H*out_w / BLOCK)) # work = every input row × every output col
202
+ per lane:
203
+ input_row = flat_index // out_w # row index, UNCHANGED by this pass
204
+ out_col = flat_index % out_w # output column being computed
205
+
206
+ center_x, first_tap_x, col_weight_sum # section 2, COLUMN axis only
207
+
208
+ for channel:
209
+ acc = 0
210
+ for tap in 0..MAX_TAPS_COL: # <-- 1-D: only TAPS_W loads
211
+ weight = filter_x(tap)
212
+ pixel = input_pixels[channel, input_row, clamp(tap_col)] # uint8 -> float
213
+ acc += weight * pixel
214
+ acc = acc / col_weight_sum
215
+ intermediate[channel, input_row, out_col] = acc # NO normalize yet
216
+ ```
217
+
218
+ Reads original `uint8` bytes; writes `float32`. No normalization here.
219
+
220
+ ### Pass 2 — `_vertical_resize_normalize_kernel` (resize height, then normalize)
221
+
222
+ ```
223
+ grid = (num_images, ceil(out_h*out_w / BLOCK)) # work = every output pixel
224
+ per lane:
225
+ out_row = flat_index // out_w
226
+ out_col = flat_index % out_w
227
+
228
+ center_y, first_tap_y, row_weight_sum # section 2, ROW axis only
229
+
230
+ for channel:
231
+ acc = 0
232
+ for tap in 0..MAX_TAPS_ROW: # <-- 1-D: only TAPS_H loads
233
+ weight = filter_y(tap)
234
+ pixel = intermediate[channel, clamp(tap_row), out_col] # float
235
+ acc += weight * pixel
236
+ acc = acc / row_weight_sum
237
+ out[image, channel, out_row, out_col] = (acc - mean[channel]) / std[channel] # normalize here
238
+ ```
239
+
240
+ Two launches (an implicit sync between them), so pass 2 always sees pass 1's finished
241
+ output.
242
+
243
+ ### Why separable wins
244
+
245
+ `TAPS_W + TAPS_H` loads instead of `TAPS_W * TAPS_H`. For a 13×13 window that is 26 vs 169.
246
+ This is exactly the algorithm PIL and torchvision use. The catch: an extra full-size float
247
+ intermediate buffer (more memory traffic), but the read-count reduction dominates.
248
+
249
+ Parity note: the intermediate here is **float32**; torchvision keeps a **fixed-point
250
+ uint8** intermediate. So the separable output is parity-*close* to torchvision, not
251
+ bit-identical — and the float version is actually the more accurate one.
252
+
253
+ ---
254
+
255
+ ## 7. Public API (`__init__.py`)
256
+
257
+ ```
258
+ resize_normalize(images, size, image_mean, image_std,
259
+ rescale_factor=1/255, resample="bilinear", antialias=False,
260
+ backend="separable", block=256)
261
+ ```
262
+ - `images`: stacked `(N,C,H,W)` tensor or a list of CHW tensors.
263
+ - `size`: int (square), `(H,W)`, or `{"height","width"}`.
264
+ - `resample`: `"bilinear"`/`"bicubic"`, or a PIL resample int (0/2→bilinear, 3→bicubic).
265
+ - `backend`: `"separable"` (default, fastest) or `"fused"` (2-D reference).
266
+ - `resize_normalize_ragged`: same kernels, list-only.
267
+
268
+ ---
269
+
270
+ ## 8. Benchmark findings (A100, CUDA_VISIBLE_DEVICES=1)
271
+
272
+ ### Standalone resize+normalize — SigLIP-so400m config, N=32 ragged 384–1024², out 384², bicubic+AA
273
+ ```
274
+ torchvision eager loop : 2.91 ms (per-image float loop)
275
+ torchvision compiled : 5.70 ms (torch.compile dynamic, per-image; slower than eager)
276
+ torchvision compiled pkt: 2.55 ms (one graph over a padded stack; timing only)
277
+ fused triton (2D) : 11.49 ms (taps*taps; the slow reference)
278
+ separable triton (uint8): 1.29 ms (taps+taps) <-- fastest
279
+ real processor : 3.92 ms
280
+ ```
281
+ **Separable is ~3× the real processor**, parity ≤1e-4 vs torchvision-float. The fused 2-D
282
+ loses for the algorithmic reason above (169 vs 26 loads). `torch.compile` does not help:
283
+ per-image it is *slower* (dispatch overhead over 32 ragged shapes); even as one packed
284
+ graph it only matches the eager loop, because inductor's interpolate is no faster than aten
285
+ resize.
286
+
287
+ ### End-to-end inference — Siglip2-base-patch16-224, **bf16** forward
288
+ ```
289
+ preprocess forward(fixed input) preprocess+forward
290
+ processor 3.99 12.86 14.44
291
+ separable 0.93 13.02 13.76 <-- ~5% faster e2e
292
+ fused 2.00 13.01 14.79
293
+ compiled 6.14 12.89 14.00
294
+ feature parity (separable/fused/compiled vs processor): 9.38e-2 = 1.2% of feature max
295
+ ```
296
+ - `forward(fixed input)` is identical (~12.9 ms) for all → **no inference regression**; the
297
+ model does not care which preprocessor made the tensor.
298
+ - The 1.2% feature drift is the float-vs-uint8 resize difference, identical across all
299
+ float backends → not a bug. The float path is the more accurate one.
300
+ - End-to-end win is ~5% with a bf16 forward (was ~0.5% with fp32, where the forward was
301
+ ~80 ms). **The win scales with how preprocessing-bound you are.**
302
+
303
+ ### Data path from JPEG bytes — 552 KB/img
304
+ ```
305
+ CPU decode + torchvision resize : 177.5 ms (status quo)
306
+ CPU decode + separable kernel : 176.4 ms (kernel saves ~1 ms; decode dominates)
307
+ GPU decode (nvJPEG) + kernel : 14.8 ms (fully on-GPU)
308
+ ```
309
+ - ~175 ms of the 177 ms is **CPU JPEG decode + host→device copy**. Resize/normalize is ~1%.
310
+ - The 12× win (177→15) is **GPU decode (nvJPEG)**, i.e. `torchvision.io.decode_jpeg(device="cuda")`
311
+ — *not* the kernel. The kernel is the resize/normalize component of that GPU pipeline.
312
+
313
+ ---
314
+
315
+ ## 9. What is true / what to claim
316
+
317
+ - The kernel is **correct** (≤1e-4 vs torchvision-float, more accurate than the processor's
318
+ uint8 path) and feeds the model with **no inference regression**.
319
+ - It is **~3× the real processor at the resize/normalize stage** — a real, parity-clean win.
320
+ - It does **not** speed up preprocessing 12×. Decode dominates the data path; the GPU-decode
321
+ lever is nvJPEG, a torchvision feature, not this kernel.
322
+ - The kernel matters end-to-end only once you are **not decode-bound**: in a GPU-decode
323
+ pipeline it keeps resize/normalize minimal (~10% of that pipeline), and its standalone
324
+ preprocess win shows up when the forward is small (bf16, small model, large batch).
325
+ - Honest one-liner: *"GPU-native resize+normalize, 3× the fast processor at that stage,
326
+ drop-in for a GPU-decode pipeline."*
README.md ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - kernel
4
+ ---
5
+
6
+ # kernel_image_resize
7
+
8
+ A pure-Triton Hub kernel that fuses the **resize + rescale + normalize** preprocessing
9
+ pipeline run by ~150 `transformers` fast image processors (`TorchvisionBackend`: resize →
10
+ fold(rescale, normalize)) into a single GPU pass. It takes raw CHW `uint8` images and
11
+ returns the normalized `(N, C, out_h, out_w)` float tensor with no intermediate
12
+ full-resolution float buffer.
13
+
14
+ On a ragged SigLIP-so400m batch (A100, N=32, inputs 384–1024², out 384², bicubic+antialias)
15
+ the default backend runs in **1.29 ms/iter vs 3.90 ms for the fast processor (~3× faster)**
16
+ and 2.89 ms for torchvision's own per-image loop, at parity ≤1e-4 vs torchvision-float.
17
+
18
+ It ships as a `kernels` universal build variant (no compiled extension, just Triton), so it
19
+ loads on any CUDA PyTorch build via `get_kernel`.
20
+
21
+ ## Usage
22
+
23
+ ```python
24
+ import torch
25
+ from kernels import get_kernel
26
+
27
+ kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
28
+
29
+ # a list of different-H×W uint8 CHW images (the ragged case torchvision loops over)
30
+ images = [torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device="cuda")
31
+ for h, w in [(640, 480), (800, 600), (384, 1024)]]
32
+
33
+ pixel_values = kir.resize_normalize(
34
+ images,
35
+ size=384, # int (square), (H, W), or {"height", "width"}
36
+ image_mean=[0.5, 0.5, 0.5],
37
+ image_std=[0.5, 0.5, 0.5],
38
+ rescale_factor=1 / 255,
39
+ resample="bicubic", # or "bilinear", or a PIL resample int
40
+ antialias=True, # match the ViT/CLIP/SigLIP default
41
+ )
42
+ # -> (3, 3, 384, 384) float32, ready for the model
43
+ ```
44
+
45
+ `trust_remote_code=True` is required because this is a personal namespace (not the trusted
46
+ `kernels-community` org). `revision="main"` loads the current code; tag a `v1.0.0` release if
47
+ you want `version=1` loading instead.
48
+
49
+ `resize_normalize` accepts a stacked `(N, C, H, W)` tensor or a ragged list of CHW
50
+ tensors. `resize_normalize_ragged` is the same kernel, list-only.
51
+
52
+ ## With a transformers processor
53
+
54
+ There is no `use_kernels=True` hook for image processors — that machinery swaps `nn.Module`
55
+ layer forwards inside the model, not processor code. Use the kernel directly with the
56
+ processor's config (see `example_transformers.py` for a runnable version):
57
+
58
+ ```python
59
+ from kernels import get_kernel
60
+ kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
61
+ _PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
62
+
63
+ def preprocess_with_kernel(processor, images):
64
+ size = processor.size # must be fixed {"height", "width"}; no crop/pad
65
+ return kir.resize_normalize(
66
+ images, (size["height"], size["width"]),
67
+ processor.image_mean, processor.image_std,
68
+ rescale_factor=float(processor.rescale_factor),
69
+ resample=_PIL_RESAMPLE[int(processor.resample)],
70
+ antialias=bool(getattr(processor, "antialias", True)),
71
+ )
72
+ ```
73
+
74
+ ## Backends
75
+
76
+ - `backend="separable"` (default): two-pass `uint8` kernel doing `taps+taps` loads —
77
+ torchvision's own separable algorithm. Fastest (~3× the fast processor on the batch
78
+ above); parity ≤1e-4 vs torchvision-float. The float intermediate makes it more accurate
79
+ than, but not bit-identical to, torchvision's fixed-point `uint8` intermediate.
80
+ - `backend="fused"`: a single 2D launch, `taps×taps` loads per output pixel. Same parity,
81
+ kept as the reference path but ~9× slower than separable (the 2D float load count is the
82
+ reason a separable pass wins — see `benchmarks/benchmark.py`).
83
+
84
+ ## Parity notes
85
+
86
+ The resampling weights match PyTorch aten `UpSampleKernel`. Antialiased bicubic uses the
87
+ PIL cubic coefficient `a=-0.5`; non-antialiased bicubic uses Keys `a=-0.75`. The
88
+ antialias renormalize-truncate window applies on every axis, including upsampling dims.
89
+
90
+ ## Center crop / shortest-edge
91
+
92
+ Pass `crop_size` to resize then center-crop in one pass (the crop is folded into the
93
+ output-coordinate mapping, no extra buffer). `resize_mode="shortest_edge"` does
94
+ aspect-preserving resize (short side = `size`) then crop — the CLIP / DINOv2 pipeline.
95
+
96
+ ```python
97
+ # CLIP/DINOv2-style: resize shortest edge to 256, center-crop 224
98
+ pv = kir.resize_normalize(images, 256, mean, std, resample="bicubic", antialias=True,
99
+ crop_size=224, resize_mode="shortest_edge")
100
+ ```
101
+
102
+ `example_transformers.py` derives all of this from a processor's config automatically.
103
+
104
+ ## Scope
105
+
106
+ Resize (+ optional center crop) + rescale + normalize. It does **not** pad — padding
107
+ processors (many detection models) run a different pipeline. The `fused` backend is
108
+ resize-only; crop is handled by the `separable` backend.
benchmarks/benchmark.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmark resize+normalize: separable / fused triton vs torchvision vs the real processor.
2
+
3
+ PYTHONPATH=../torch-ext python benchmark.py --processor google/siglip-so400m-patch14-384
4
+ PYTHONPATH=../torch-ext python benchmark.py --n 32 --out 384 384 --interp bicubic --antialias
5
+
6
+ Prints parity (vs torchvision-float) per backend, then ms/iter for each path. Needs CUDA.
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+ import time
12
+
13
+ # Hide `kernels` from transformers: this worktree builds kernels.LayerRepository without a version,
14
+ # which newer `kernels` rejects at import. Preprocessing needs no hub layer kernels.
15
+ sys.modules["kernels"] = None
16
+
17
+ import torch
18
+ import torchvision.transforms.v2.functional as tvF
19
+ from torchvision.io import ImageReadMode, decode_jpeg, encode_jpeg
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ from kernel_image_resize import resize_normalize
23
+ from kernel_image_resize._pack import PIL_RESAMPLE_TO_INTERP, max_taps
24
+
25
+
26
+ _TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}
27
+
28
+
29
+ def make_ragged_images(n, device, min_res, max_res, seed=0):
30
+ g = torch.Generator(device="cpu").manual_seed(seed)
31
+ images = []
32
+ for _ in range(n):
33
+ h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
34
+ w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
35
+ images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
36
+ return images
37
+
38
+
39
+ def torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias):
40
+ mode = _TV_INTERP[interp]
41
+ mean_t = torch.tensor(mean, device=images[0].device).view(3, 1, 1)
42
+ std_t = torch.tensor(std, device=images[0].device).view(3, 1, 1)
43
+ outs = []
44
+ for img in images:
45
+ r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
46
+ outs.append((r * rescale - mean_t) / std_t)
47
+ return torch.stack(outs)
48
+
49
+
50
+ def build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
51
+ """torch.compile(dynamic=True) of a per-image float resize+normalize."""
52
+ import torch.nn.functional as F
53
+
54
+ mean_t = torch.tensor(mean, device=device).view(3, 1, 1)
55
+ std_t = torch.tensor(std, device=device).view(3, 1, 1)
56
+ mode = "bicubic" if interp == "bicubic" else "bilinear"
57
+
58
+ def _one(img):
59
+ r = F.interpolate(img.unsqueeze(0).float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
60
+ return (r.squeeze(0) * rescale - mean_t) / std_t
61
+
62
+ compiled = torch.compile(_one, dynamic=True)
63
+
64
+ def run(images):
65
+ return torch.stack([compiled(img) for img in images])
66
+
67
+ return run
68
+
69
+
70
+ def pad_stack(images):
71
+ """Pad ragged CHW images to the batch-max H/W and stack into (N, C, Hmax, Wmax)."""
72
+ c = images[0].shape[0]
73
+ max_h = max(img.shape[1] for img in images)
74
+ max_w = max(img.shape[2] for img in images)
75
+ out = torch.zeros(len(images), c, max_h, max_w, dtype=images[0].dtype, device=images[0].device)
76
+ for i, img in enumerate(images):
77
+ out[i, :, : img.shape[1], : img.shape[2]] = img
78
+ return out
79
+
80
+
81
+ def build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
82
+ """torch.compile of a single batched resize+normalize over a stacked (N, C, H, W) tensor."""
83
+ import torch.nn.functional as F
84
+
85
+ mean_t = torch.tensor(mean, device=device).view(1, 3, 1, 1)
86
+ std_t = torch.tensor(std, device=device).view(1, 3, 1, 1)
87
+ mode = "bicubic" if interp == "bicubic" else "bilinear"
88
+
89
+ def _batch(stacked):
90
+ r = F.interpolate(stacked.float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
91
+ return (r * rescale - mean_t) / std_t
92
+
93
+ return torch.compile(_batch, dynamic=True)
94
+
95
+
96
+ def run_inference(model_id, images, block, iters, device):
97
+ """End-to-end: preprocess (processor / separable / fused / compiled) -> vision features (bf16 forward).
98
+ Checks each kernel feeds the model with no feature drift and times the full pipeline."""
99
+ from transformers import AutoModel
100
+
101
+ proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(model_id)
102
+ model = AutoModel.from_pretrained(model_id).to(device=device, dtype=torch.bfloat16).eval()
103
+ vision = model.vision_model
104
+ kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
105
+ resample=interp, antialias=antialias, block=block)
106
+
107
+ @torch.no_grad()
108
+ def features(pixel_values):
109
+ out = vision(pixel_values=pixel_values.to(model.dtype))
110
+ pooled = getattr(out, "pooler_output", None)
111
+ return pooled if pooled is not None else out.last_hidden_state
112
+
113
+ compiled_one = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
114
+ methods = {
115
+ "processor": lambda: proc(images, return_tensors="pt", device=device)["pixel_values"],
116
+ "separable": lambda: resize_normalize(images, backend="separable", **kk),
117
+ "fused": lambda: resize_normalize(images, backend="fused", **kk),
118
+ "compiled": lambda: compiled_one(images),
119
+ }
120
+ methods["compiled"]() # warmup the compiled artifact
121
+ methods["compiled"]()
122
+ torch.cuda.synchronize()
123
+
124
+ print(f"\n[infer] {model_id} out={out_h}x{out_w} forward dtype=bfloat16")
125
+ base = features(methods["processor"]())
126
+ base_scale = base.abs().max().item()
127
+ for name in ("separable", "fused", "compiled"):
128
+ d = (features(methods[name]()) - base).abs().max().item()
129
+ print(f"[infer parity] features {name} vs processor: max|Δ| = {d:.2e} ({d / base_scale:.1%} of feature max)")
130
+
131
+ # forward is timed on a FIXED precomputed tensor, so it is method-independent by construction;
132
+ # if it varies across rows, the preprocessor's output (dtype/contiguity) is hurting the model.
133
+ print("[infer] ms/iter: preprocess forward(fixed input) preprocess+forward")
134
+ for name, preprocess in methods.items():
135
+ pixel_values = preprocess()
136
+ pre = _time(preprocess, iters, device)
137
+ fwd = _time(lambda pixel_values=pixel_values: features(pixel_values), iters, device)
138
+ e2e = _time(lambda preprocess=preprocess: features(preprocess()), iters, device)
139
+ print(f" {name:10s} {pre:8.3f} {fwd:8.3f} {e2e:8.3f}")
140
+
141
+
142
+ def run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, block, iters, device):
143
+ """Data-path table from JPEG bytes: CPU decode (libjpeg) vs GPU decode (nvJPEG) + the kernel.
144
+
145
+ decoders differ at the pixel level (nvJPEG vs libjpeg), so this measures wall-clock, not parity.
146
+ """
147
+ jpeg = [encode_jpeg(img, quality=95) for img in images_cpu]
148
+ avg_kb = sum(b.numel() for b in jpeg) / len(jpeg) / 1024
149
+ kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
150
+ resample=interp, antialias=antialias, block=block)
151
+
152
+ def cpu_decode_kernel():
153
+ imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg]
154
+ return resize_normalize(imgs, backend="separable", **kk)
155
+
156
+ def gpu_decode_kernel():
157
+ imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device)
158
+ return resize_normalize(imgs, backend="separable", **kk)
159
+
160
+ def gpu_decode_torchvision():
161
+ imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device)
162
+ return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias)
163
+
164
+ def cpu_decode_torchvision():
165
+ imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg]
166
+ return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias)
167
+
168
+ print(f"\n[decode] N={len(jpeg)} avg={avg_kb:.0f} KB/img out={out_h}x{out_w} (from JPEG bytes, ms/iter)")
169
+ print(f" CPU decode + torchvision resize : {_time(cpu_decode_torchvision, iters, device):8.3f} [status quo data path]")
170
+ print(f" CPU decode + separable kernel : {_time(cpu_decode_kernel, iters, device):8.3f}")
171
+ print(f" GPU decode (nvJPEG) + tv resize : {_time(gpu_decode_torchvision, iters, device):8.3f} [GPU pipeline, tv resize]")
172
+ print(f" GPU decode (nvJPEG) + kernel : {_time(gpu_decode_kernel, iters, device):8.3f} [GPU pipeline, kernel resize]")
173
+
174
+
175
+ def load_processor_config(name):
176
+ from transformers import AutoImageProcessor
177
+
178
+ proc = AutoImageProcessor.from_pretrained(name, backend="torchvision")
179
+ size = proc.size
180
+ if "height" not in size or "width" not in size:
181
+ raise ValueError(f"{name}: size={size} is not a fixed (height, width)")
182
+ out_h, out_w = size["height"], size["width"]
183
+ interp = PIL_RESAMPLE_TO_INTERP.get(int(proc.resample))
184
+ rescale = float(proc.rescale_factor) if getattr(proc, "do_rescale", True) else 1.0
185
+ antialias = bool(getattr(proc, "antialias", True))
186
+ return proc, (out_h, out_w, list(proc.image_mean), list(proc.image_std), rescale, interp, antialias)
187
+
188
+
189
+ def _time(fn, iters, device):
190
+ for _ in range(3):
191
+ fn()
192
+ if device.type == "cuda":
193
+ torch.cuda.synchronize()
194
+ start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
195
+ start.record()
196
+ for _ in range(iters):
197
+ fn()
198
+ end.record()
199
+ torch.cuda.synchronize()
200
+ return start.elapsed_time(end) / iters
201
+ t0 = time.perf_counter()
202
+ for _ in range(iters):
203
+ fn()
204
+ return (time.perf_counter() - t0) / iters * 1e3
205
+
206
+
207
+ def main():
208
+ parser = argparse.ArgumentParser()
209
+ parser.add_argument("--processor", default=None)
210
+ parser.add_argument("--n", type=int, default=32)
211
+ parser.add_argument("--out", type=int, nargs=2, default=[384, 384], metavar=("H", "W"))
212
+ parser.add_argument("--interp", choices=["bilinear", "bicubic"], default="bicubic")
213
+ parser.add_argument("--antialias", action="store_true")
214
+ parser.add_argument("--min-res", type=int, default=384)
215
+ parser.add_argument("--max-res", type=int, default=1024)
216
+ parser.add_argument("--iters", type=int, default=50)
217
+ parser.add_argument("--block", type=int, default=256)
218
+ parser.add_argument("--tol", type=float, default=3e-3)
219
+ parser.add_argument("--infer", action="store_true", help="end-to-end Siglip2 inference comparison (bf16 forward)")
220
+ parser.add_argument("--model", default="google/siglip2-base-patch16-224", help="model for --infer")
221
+ parser.add_argument("--decode", action="store_true", help="JPEG-decode data-path table (CPU vs GPU/nvJPEG) and stop")
222
+ args = parser.parse_args()
223
+
224
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
225
+ if device.type != "cuda":
226
+ print("benchmark needs CUDA.")
227
+ return
228
+
229
+ proc = None
230
+ if args.processor:
231
+ proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(args.processor)
232
+ print(f"processor={args.processor} -> out={out_h}x{out_w} interp={interp} antialias={antialias}")
233
+ else:
234
+ out_h, out_w = args.out
235
+ mean = [0.48145466, 0.4578275, 0.40821073]
236
+ std = [0.26862954, 0.26130258, 0.27577711]
237
+ rescale = 1.0 / 255.0
238
+ interp, antialias = args.interp, args.antialias
239
+
240
+ images = make_ragged_images(args.n, device, args.min_res, args.max_res)
241
+ taps = (max_taps(images, out_h, 1, interp, antialias), max_taps(images, out_w, 2, interp, antialias))
242
+ print(f"N={args.n} in∈[{args.min_res},{args.max_res}]² ragged out={out_h}x{out_w} "
243
+ f"interp={interp} antialias={antialias} max_taps={taps} iters={args.iters}\n")
244
+
245
+ if args.decode:
246
+ images_cpu = make_ragged_images(args.n, torch.device("cpu"), args.min_res, args.max_res)
247
+ run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, args.block, args.iters, device)
248
+ return
249
+
250
+ ref = torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias)
251
+ common = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
252
+ resample=interp, antialias=antialias, block=args.block)
253
+ for backend in ("fused", "separable"):
254
+ got = resize_normalize(images, backend=backend, **common)
255
+ d = (got - ref).abs().max().item()
256
+ print(f"[parity] {backend:9s} vs torchvision(float): max|Δ| = {d:.2e} "
257
+ f"({'PASS' if d < args.tol else 'FAIL'} @ tol={args.tol})")
258
+ print()
259
+
260
+ compiled_run = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
261
+ packed = pad_stack(images)
262
+ packed_compiled_run = build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
263
+ t0 = time.perf_counter()
264
+ compiled_run(images)
265
+ compiled_run(images)
266
+ packed_compiled_run(packed)
267
+ packed_compiled_run(packed)
268
+ torch.cuda.synchronize()
269
+ t_warmup = (time.perf_counter() - t0) * 1e3
270
+
271
+ t_eager = _time(lambda: torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias), args.iters, device)
272
+ t_comp = _time(lambda: compiled_run(images), args.iters, device)
273
+ t_comp_packed = _time(lambda: packed_compiled_run(packed), args.iters, device)
274
+ t_fused = _time(lambda: resize_normalize(images, backend="fused", **common), args.iters, device)
275
+ t_sep = _time(lambda: resize_normalize(images, backend="separable", **common), args.iters, device)
276
+ print("Resize+normalize only (no decode/H2D), ms/iter:")
277
+ print(f" torchvision eager loop : {t_eager:8.3f} [per-image float loop]")
278
+ print(f" torchvision compiled : {t_comp:8.3f} [torch.compile dynamic per-image; warmup {t_warmup:.0f} ms excluded]")
279
+ print(f" torchvision compiled pkt: {t_comp_packed:8.3f} [one graph over padded (N,C,Hmax,Wmax) stack; timing only, padding alters output]")
280
+ print(f" fused triton (2D) : {t_fused:8.3f} [taps*taps]")
281
+ print(f" separable triton (uint8): {t_sep:8.3f} [taps+taps]")
282
+
283
+ if proc is not None:
284
+ t_pr = _time(lambda: proc(images, return_tensors="pt", device=device)["pixel_values"], args.iters, device)
285
+ print(f"\n {args.processor} : {t_pr:8.3f} ms/iter")
286
+ print(f" -> separable is {t_sep / t_pr:.2f}x the real processor")
287
+
288
+ if args.infer:
289
+ run_inference(args.model, images, args.block, args.iters, device)
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
benchmarks/compat_check.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility + forward-parity sweep over the most-downloaded image models (one per architecture).
2
+
3
+ For each model: load its processor + model, decide whether the kernel can stand in for the
4
+ processor's resize(+crop)+normalize, and for supported ones run processor-vs-kernel pixel_values
5
+ through the SAME vision tower and report pixel + feature parity. Unsupported models list the reason.
6
+
7
+ Run on the DGX (CUDA + working transformers):
8
+ PYTHONPATH=../torch-ext python compat_check.py
9
+ """
10
+
11
+ import sys
12
+
13
+ # This transformers worktree constructs kernels.LayerRepository without a version/revision, which
14
+ # newer `kernels` rejects at import. We do not need hub LAYER kernels for preprocessing, so hide
15
+ # `kernels` from transformers — it falls back to its no-hub-kernels stub path and imports cleanly.
16
+ sys.modules["kernels"] = None
17
+
18
+ import torch
19
+
20
+ from kernel_image_resize import resize_normalize # local package, via PYTHONPATH=../torch-ext
21
+
22
+
23
+ _PIL_RESAMPLE = {2: "bilinear", 3: "bicubic"}
24
+
25
+ # Top image models by HF downloads (June 2026), deduplicated to one repo per architecture family.
26
+ MODELS = [
27
+ ("openai/clip-vit-base-patch32", 20528683), # clip
28
+ ("google/vit-base-patch16-224", 4910416), # vit
29
+ ("apple/mobilevit-small", 3488074), # mobilevit
30
+ ("facebook/dinov2-small", 2602780), # dinov2
31
+ ("google/siglip-so400m-patch14-384", 1379598), # siglip
32
+ ("facebook/dinov3-vitb16-pretrain-lvd1689m", 467337), # dinov3
33
+ ("microsoft/swinv2-tiny-patch4-window16-256", 385713), # swinv2
34
+ ("google/siglip2-base-patch16-224", 336824), # siglip2
35
+ ("microsoft/resnet-50", 307057), # resnet (convnext processor)
36
+ ("nvidia/segformer-b0-finetuned-ade-512-512", 262459), # segformer
37
+ ("facebook/convnextv2-tiny-22k-384", 48614), # convnextv2
38
+ ("google/mobilenet_v2_1.0_224", 48342), # mobilenet
39
+ ("facebook/convnext-tiny-224", 16984), # convnext
40
+ ("google/efficientnet-b0", 8577), # efficientnet
41
+ ("microsoft/beit-base-patch16-224-pt22k-ft22k", 7529), # beit
42
+ ]
43
+
44
+
45
+ def unsupported_reason(p):
46
+ """Return None if the kernel can stand in for this processor, else a short reason."""
47
+ if not getattr(p, "do_resize", True):
48
+ return "no resize"
49
+ if not getattr(p, "do_normalize", True):
50
+ return "no normalize (rescale only)"
51
+ if getattr(p, "do_flip_channel_order", False):
52
+ return "channel flip (BGR)"
53
+ if getattr(p, "do_pad", False):
54
+ return "pad"
55
+ if int(getattr(p, "resample", 2)) not in _PIL_RESAMPLE:
56
+ return f"resample {p.resample}"
57
+ size = getattr(p, "size", {}) or {}
58
+ crop = p.crop_size if getattr(p, "do_center_crop", False) else None
59
+ if "shortest_edge" in size:
60
+ return None if crop else "shortest_edge without crop (variable output)"
61
+ if "height" in size and "width" in size:
62
+ return None
63
+ return f"size {size}"
64
+
65
+
66
+ def preprocess_with_kernel(p, images):
67
+ size = p.size
68
+ resample = _PIL_RESAMPLE[int(p.resample)]
69
+ antialias = bool(getattr(p, "antialias", True))
70
+ rescale = float(p.rescale_factor) if getattr(p, "do_rescale", True) else 1.0
71
+ mean, std = p.image_mean, p.image_std
72
+ crop = p.crop_size if getattr(p, "do_center_crop", False) else None
73
+ common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)
74
+ if "shortest_edge" in size:
75
+ return resize_normalize(
76
+ images, size["shortest_edge"], mean, std,
77
+ crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common)
78
+ if crop is not None and (crop["height"] != size["height"] or crop["width"] != size["width"]):
79
+ return resize_normalize(
80
+ images, (size["height"], size["width"]), mean, std,
81
+ crop_size=(crop["height"], crop["width"]), resize_mode="square", **common)
82
+ return resize_normalize(images, (size["height"], size["width"]), mean, std, **common)
83
+
84
+
85
+ def vision_features(model, pixel_values):
86
+ tower = getattr(model, "vision_model", model)
87
+ out = tower(pixel_values=pixel_values.to(model.dtype))
88
+ for attr in ("pooler_output", "last_hidden_state"):
89
+ value = getattr(out, attr, None)
90
+ if value is not None and torch.is_tensor(value):
91
+ return value
92
+ return out[0]
93
+
94
+
95
+ def main():
96
+ from transformers import AutoImageProcessor, AutoModel # lazy: avoids importing the kernels lib first
97
+
98
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
99
+ images = [
100
+ torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
101
+ for h, w in [(640, 480), (800, 600), (512, 512), (384, 1024)]
102
+ ]
103
+ print(f"{'model':46s} {'verdict':10s} pixel max|Δ| feature max|Δ| (rel)")
104
+ for model_id, _ in MODELS:
105
+ try:
106
+ processor = AutoImageProcessor.from_pretrained(model_id)
107
+ reason = unsupported_reason(processor)
108
+ if reason is not None:
109
+ print(f"{model_id:46s} SKIP: {reason}")
110
+ continue
111
+ model = AutoModel.from_pretrained(model_id).to(device).eval()
112
+ reference_pv = processor(images, return_tensors="pt", device=device)["pixel_values"].to(device)
113
+ kernel_pv = preprocess_with_kernel(processor, images)
114
+ pixel_delta = (kernel_pv - reference_pv).abs().max().item()
115
+ with torch.no_grad():
116
+ base = vision_features(model, reference_pv)
117
+ feat_delta = (vision_features(model, kernel_pv) - base).abs().max().item()
118
+ rel = feat_delta / base.abs().max().item()
119
+ print(f"{model_id:46s} OK {pixel_delta:.2e} {feat_delta:.2e} ({rel:.1%})")
120
+ del model
121
+ torch.cuda.empty_cache()
122
+ except Exception as e:
123
+ print(f"{model_id:46s} ERROR: {type(e).__name__}: {str(e)[:55]}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
build.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "kernel_image_resize"
3
+ universal = true
4
+ version = 1
5
+
6
+ [general.hub]
7
+ repo-id = "Molbap/kernel_image_resize"
build/torch-universal/kernel_image_resize/__init__.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Resize + rescale + normalize for transformers fast image processors, as a Triton kernel.
2
+
3
+ resize -> fold(rescale, normalize) in one GPU pipeline: CHW uint8 images in,
4
+ (N, C, out_h, out_w) normalized float out, no full-resolution float intermediate.
5
+
6
+ - resize_normalize — stacked (N, C, H, W) tensor or a list of CHW images.
7
+ - resize_normalize_ragged — same kernels; takes a list of different-H/W CHW tensors.
8
+
9
+ backend="separable" (default): two-pass uint8, taps+taps. backend="fused": single 2D
10
+ launch, taps*taps. Both parity <=1e-4 vs torchvision-float.
11
+
12
+ from kernels import get_kernel
13
+ kir = get_kernel("Molbap/kernel_image_resize")
14
+ pixel_values = kir.resize_normalize(
15
+ images, size=384, image_mean=[...], image_std=[...], resample="bicubic", antialias=True,
16
+ )
17
+ """
18
+
19
+ from ._fused import fused_resize_normalize
20
+ from ._pack import PIL_RESAMPLE_TO_INTERP, as_image_list
21
+ from ._separable import separable_resize_crop_normalize, separable_resize_normalize
22
+
23
+
24
+ def _normalize_size(size) -> tuple[int, int]:
25
+ if isinstance(size, int):
26
+ return size, size
27
+ if isinstance(size, dict):
28
+ if "height" in size and "width" in size:
29
+ return int(size["height"]), int(size["width"])
30
+ raise ValueError(f"size dict must hold 'height'/'width' for a fixed resize, got {size}")
31
+ out_h, out_w = size
32
+ return int(out_h), int(out_w)
33
+
34
+
35
+ def _normalize_resample(resample) -> str:
36
+ if isinstance(resample, str):
37
+ if resample not in ("bilinear", "bicubic"):
38
+ raise ValueError(f"resample must be 'bilinear' or 'bicubic', got {resample!r}")
39
+ return resample
40
+ interp = PIL_RESAMPLE_TO_INTERP.get(int(resample))
41
+ if interp is None:
42
+ raise ValueError(f"unsupported PIL resample code {resample}")
43
+ return interp
44
+
45
+
46
+ def resize_normalize(
47
+ images,
48
+ size,
49
+ image_mean,
50
+ image_std,
51
+ rescale_factor: float = 1.0 / 255.0,
52
+ resample="bilinear",
53
+ antialias: bool = False,
54
+ backend: str = "separable",
55
+ block: int = 256,
56
+ crop_size=None,
57
+ resize_mode: str = "square",
58
+ ):
59
+ """Resize, optionally center-crop, rescale and normalize — one GPU pipeline.
60
+
61
+ Args:
62
+ images: a stacked `(N, C, H, W)` uint8/float tensor, or a list of CHW tensors (ragged).
63
+ size: resize target. With no crop: int (square), `(height, width)`, or `{"height","width"}`.
64
+ With `resize_mode="shortest_edge"`: an int, the short side after aspect-preserving resize.
65
+ image_mean, image_std: per-channel normalization stats (length C).
66
+ rescale_factor: folded into mean/std so the kernel does `(x*rescale - mean)/std`.
67
+ resample: "bilinear" / "bicubic", or a PIL resample int (0/2 -> bilinear, 3 -> bicubic).
68
+ antialias: match the ViT/CLIP/SigLIP default (`True` for those processors).
69
+ backend: "separable" (default) or "fused" (2D reference, no crop support).
70
+ crop_size: `None` (no crop), int (square), or `(crop_h, crop_w)`. Center crop after resize.
71
+ resize_mode: "square" (resize to `size`) or "shortest_edge" (aspect-preserving, needs a crop).
72
+ """
73
+ interp = _normalize_resample(resample)
74
+ image_list = as_image_list(images)
75
+
76
+ if crop_size is not None or resize_mode == "shortest_edge":
77
+ crop_h, crop_w = _normalize_size(crop_size if crop_size is not None else size)
78
+ resize_arg = int(size) if resize_mode == "shortest_edge" else _normalize_size(size)
79
+ return separable_resize_crop_normalize(
80
+ image_list, resize_arg, (crop_h, crop_w), image_mean, image_std, rescale_factor,
81
+ interp, antialias, resize_mode, block,
82
+ )
83
+
84
+ out_h, out_w = _normalize_size(size)
85
+ if backend == "fused":
86
+ return fused_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
87
+ if backend == "separable":
88
+ return separable_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
89
+ raise ValueError(f"backend must be 'fused' or 'separable', got {backend!r}")
90
+
91
+
92
+ def resize_normalize_ragged(
93
+ images,
94
+ size,
95
+ image_mean,
96
+ image_std,
97
+ rescale_factor: float = 1.0 / 255.0,
98
+ resample="bilinear",
99
+ antialias: bool = False,
100
+ backend: str = "separable",
101
+ block: int = 256,
102
+ ):
103
+ """Variant taking a list of different-H/W CHW tensors. Same kernels as `resize_normalize`."""
104
+ if isinstance(images, list):
105
+ image_list = images
106
+ else:
107
+ raise ValueError("resize_normalize_ragged expects a list of CHW tensors; use resize_normalize for a stacked tensor")
108
+ return resize_normalize(
109
+ image_list, size, image_mean, image_std, rescale_factor, resample, antialias, backend, block
110
+ )
111
+
112
+
113
+ __all__ = ["resize_normalize", "resize_normalize_ragged"]
build/torch-universal/kernel_image_resize/_fused.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fused 2D resize+rescale+normalize over a ragged batch, single launch.
2
+
3
+ One program owns one image and a BLOCK of its output pixels, gathers a
4
+ MAX_TAPS_H × MAX_TAPS_W window, applies the separable weights as a 2D product, then folds
5
+ rescale+normalize. taps×taps loads per output pixel.
6
+
7
+ Resampling-weight formula (PyTorch aten UpSampleKernel):
8
+ scale = in / out
9
+ support = interp_half * (scale if antialias and scale > 1 else 1) # interp_half: 1 linear, 2 cubic
10
+ center = scale * (i + 0.5)
11
+ weight = filter((tap - center + 0.5) / eff), renormalized over the realized window
12
+ """
13
+
14
+ import triton
15
+ import triton.language as tl
16
+
17
+ from ._pack import fold_mean_std, max_taps, pack_images
18
+
19
+
20
+ @triton.jit
21
+ def _resample_weight(arg, cubic_a, CUBIC: tl.constexpr):
22
+ """Interpolation filter at `arg` (coordinate distance already divided by support)."""
23
+ ax = tl.abs(arg)
24
+ if CUBIC: # Keys cubic convolution kernel, support 2
25
+ ax2 = ax * ax
26
+ ax3 = ax2 * ax
27
+ inner = (cubic_a + 2.0) * ax3 - (cubic_a + 3.0) * ax2 + 1.0 # |x| <= 1
28
+ outer = cubic_a * ax3 - 5.0 * cubic_a * ax2 + 8.0 * cubic_a * ax - 4.0 * cubic_a # 1 < |x| < 2
29
+ return tl.where(ax <= 1.0, inner, tl.where(ax < 2.0, outer, 0.0))
30
+ return tl.maximum(1.0 - ax, 0.0) # triangle (bilinear), support 1
31
+
32
+
33
+ @triton.jit
34
+ def _resize_normalize_kernel(
35
+ in_ptr, out_ptr, offsets_ptr, heights_ptr, widths_ptr, mean_ptr, std_ptr,
36
+ out_h, out_w, cubic_a,
37
+ C: tl.constexpr, BLOCK: tl.constexpr,
38
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
39
+ MAX_TAPS_H: tl.constexpr, MAX_TAPS_W: tl.constexpr,
40
+ ):
41
+ n = tl.program_id(0)
42
+ blk = tl.program_id(1)
43
+ H = tl.load(heights_ptr + n)
44
+ W = tl.load(widths_ptr + n)
45
+ off = tl.load(offsets_ptr + n)
46
+ Hf = H.to(tl.float32)
47
+ Wf = W.to(tl.float32)
48
+
49
+ npix = out_h * out_w
50
+ pos = blk * BLOCK + tl.arange(0, BLOCK)
51
+ mask = pos < npix
52
+ oy = pos // out_w
53
+ ox = pos % out_w
54
+
55
+ interp_half = 2.0 if CUBIC else 1.0
56
+ scale_h = Hf / out_h
57
+ scale_w = Wf / out_w
58
+ eff_h = tl.maximum(scale_h, 1.0) if ANTIALIAS else 1.0
59
+ eff_w = tl.maximum(scale_w, 1.0) if ANTIALIAS else 1.0
60
+ support_h = interp_half * eff_h
61
+ support_w = interp_half * eff_w
62
+ inv_h = 1.0 / eff_h
63
+ inv_w = 1.0 / eff_w
64
+
65
+ center_y = scale_h * (oy.to(tl.float32) + 0.5)
66
+ center_x = scale_w * (ox.to(tl.float32) + 0.5)
67
+ ystart = tl.floor(center_y - support_h + 0.5)
68
+ xstart = tl.floor(center_x - support_w + 0.5)
69
+
70
+ sum_wy = tl.zeros([BLOCK], dtype=tl.float32)
71
+ for ty in tl.static_range(MAX_TAPS_H):
72
+ yy = ystart + ty
73
+ wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
74
+ if ANTIALIAS:
75
+ wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
76
+ sum_wy += wy
77
+ sum_wx = tl.zeros([BLOCK], dtype=tl.float32)
78
+ for tx in tl.static_range(MAX_TAPS_W):
79
+ xx = xstart + tx
80
+ wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
81
+ if ANTIALIAS:
82
+ wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
83
+ sum_wx += wx
84
+ denom = sum_wy * sum_wx
85
+
86
+ plane = (H * W).to(tl.int64)
87
+ Wl = W.to(tl.int64)
88
+ for c in tl.static_range(C):
89
+ base = off + c * plane
90
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
91
+ for ty in tl.static_range(MAX_TAPS_H):
92
+ yy = ystart + ty
93
+ wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
94
+ if ANTIALIAS:
95
+ wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
96
+ yidx = tl.minimum(tl.maximum(yy.to(tl.int32), 0), H - 1).to(tl.int64)
97
+ row = base + yidx * Wl
98
+ for tx in tl.static_range(MAX_TAPS_W):
99
+ xx = xstart + tx
100
+ wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
101
+ if ANTIALIAS:
102
+ wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
103
+ xidx = tl.minimum(tl.maximum(xx.to(tl.int32), 0), W - 1).to(tl.int64)
104
+ pix = tl.load(in_ptr + row + xidx, mask=mask, other=0.0)
105
+ acc += wy * wx * pix
106
+ acc = acc / denom
107
+ m = tl.load(mean_ptr + c)
108
+ s = tl.load(std_ptr + c)
109
+ acc = (acc - m) / s
110
+ oidx = ((n * C + c) * out_h + oy) * out_w + ox
111
+ tl.store(out_ptr + oidx, acc, mask=mask)
112
+
113
+
114
+ def fused_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
115
+ """Single fused launch over a ragged packed buffer -> (N, C, out_h, out_w) normalized float."""
116
+ import torch
117
+
118
+ images = list(images)
119
+ device = images[0].device
120
+ n = len(images)
121
+ cubic_a = -0.5 if antialias else -0.75 # PIL coeff under antialias, Keys coeff otherwise
122
+ max_taps_h = max_taps(images, out_h, 1, interp, antialias)
123
+ max_taps_w = max_taps(images, out_w, 2, interp, antialias)
124
+ mean_t, std_t = fold_mean_std(mean, std, rescale, device)
125
+
126
+ in_buf, offsets_t, heights_t, widths_t, c = pack_images(images)
127
+ out = torch.empty((n, c, out_h, out_w), device=device, dtype=torch.float32)
128
+ grid = (n, triton.cdiv(out_h * out_w, block))
129
+ _resize_normalize_kernel[grid](
130
+ in_buf, out, offsets_t, heights_t, widths_t, mean_t, std_t,
131
+ out_h, out_w, cubic_a, C=c, BLOCK=block,
132
+ CUBIC=(interp == "bicubic"), ANTIALIAS=antialias, MAX_TAPS_H=max_taps_h, MAX_TAPS_W=max_taps_w,
133
+ )
134
+ return out
build/torch-universal/kernel_image_resize/_pack.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ragged packing + resampling helpers shared by the fused and separable backends."""
2
+
3
+ import math
4
+
5
+ import torch
6
+
7
+
8
+ PIL_RESAMPLE_TO_INTERP = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
9
+
10
+
11
+ def pack_images(
12
+ images: list[torch.Tensor], dtype: torch.dtype = torch.float32
13
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
14
+ """Concatenate a ragged list of CHW images into one flat buffer of `dtype`.
15
+
16
+ Returns (in_buf, offsets, heights, widths, channels); offsets[n] is the element index
17
+ where image n starts.
18
+ """
19
+ device = images[0].device
20
+ channels = images[0].shape[0]
21
+ flats, offsets, heights, widths, cur = [], [], [], [], 0
22
+ for img in images:
23
+ ic, ih, iw = img.shape
24
+ if ic != channels:
25
+ raise ValueError(f"all images must share channel count {channels}, got {ic}")
26
+ flats.append(img.reshape(-1).to(dtype))
27
+ offsets.append(cur)
28
+ heights.append(ih)
29
+ widths.append(iw)
30
+ cur += ic * ih * iw
31
+ in_buf = torch.cat(flats)
32
+ offsets_t = torch.tensor(offsets, device=device, dtype=torch.int64)
33
+ heights_t = torch.tensor(heights, device=device, dtype=torch.int32)
34
+ widths_t = torch.tensor(widths, device=device, dtype=torch.int32)
35
+ return in_buf, offsets_t, heights_t, widths_t, channels
36
+
37
+
38
+ def fold_mean_std(mean, std, rescale: float, device) -> tuple[torch.Tensor, torch.Tensor]:
39
+ """Fold rescale into mean/std so the kernel does (x - m)/s == (x*rescale - mean)/std."""
40
+ mean_t = (torch.tensor(mean, device=device, dtype=torch.float32) / rescale).contiguous()
41
+ std_t = (torch.tensor(std, device=device, dtype=torch.float32) / rescale).contiguous()
42
+ return mean_t, std_t
43
+
44
+
45
+ def max_taps(images: list[torch.Tensor], out_size: int, axis_dim: int, interp: str, antialias: bool) -> int:
46
+ """Batch-wide worst-case tap count for one axis = ceil(support) * 2 + 1."""
47
+ interp_half = 2.0 if interp == "bicubic" else 1.0
48
+ worst = 0
49
+ for img in images:
50
+ scale = img.shape[axis_dim] / out_size
51
+ eff = max(scale, 1.0) if antialias else 1.0
52
+ worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
53
+ return worst
54
+
55
+
56
+ def as_image_list(images) -> list[torch.Tensor]:
57
+ """Accept a stacked (N, C, H, W) tensor or a list of CHW tensors; always return a list."""
58
+ if isinstance(images, torch.Tensor):
59
+ if images.dim() != 4:
60
+ raise ValueError(f"stacked input must be (N, C, H, W), got shape {tuple(images.shape)}")
61
+ return list(images)
62
+ return list(images)
build/torch-universal/kernel_image_resize/_separable.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Separable resize + center-crop + normalize over a ragged uint8 batch.
2
+
3
+ WHAT "RESIZE" DOES, CONCRETELY
4
+ Every output pixel is a weighted average of a small window of input pixels. When you shrink
5
+ an image a lot (with antialiasing) that window gets wide — e.g. 13 input pixels across and
6
+ 13 down, so 13x13 = 169 input pixels feed one output pixel.
7
+
8
+ FUSED vs SEPARABLE (the two backends in this package)
9
+ - FUSED (see _fused.py): for each output pixel, read the whole 2D window directly -> 169 reads.
10
+ - SEPARABLE (this file): do the resize as two 1D steps instead of one 2D step:
11
+ step 1 (horizontal): resize only the WIDTH -> an intermediate image
12
+ step 2 (vertical): resize only the HEIGHT -> the final image
13
+ Each step's window is 1D, so 13 + 13 = 26 reads per output pixel instead of 169. Same math,
14
+ far fewer reads. This is what PIL and torchvision do.
15
+
16
+ CENTER CROP (folded in, no extra pass)
17
+ Processors like CLIP / DINOv2 resize to a "resize size" and then keep only the centered
18
+ crop. We do not materialize the full resized image and slice it; instead each output pixel
19
+ of the CROP maps to a resize-image coordinate by adding the crop offset, and that maps back
20
+ to the input. So:
21
+ resize is described by (resize_height, resize_width) -- per image
22
+ crop is described by (crop_top, crop_left) -- per image, the centered offset
23
+ output size is (crop_height, crop_width) -- the same for every image
24
+ When there is no crop, resize size == crop size and the offsets are 0 (the plain resize).
25
+ The resize SCALE uses the resize size; only the output coordinate is shifted by the crop.
26
+
27
+ uint8 input + float intermediate; each 1D step renormalizes its own weights (matches
28
+ torchvision). Output is parity-close to torchvision, not bit-identical (torchvision keeps a
29
+ fixed-point uint8 intermediate; ours is more accurate float).
30
+ """
31
+
32
+ import triton
33
+ import triton.language as tl
34
+
35
+ from ._fused import _resample_weight
36
+ from ._pack import fold_mean_std, pack_images
37
+
38
+
39
+ @triton.jit
40
+ def _horizontal_resize_kernel(
41
+ input_pixels, # flat uint8 buffer, all images packed back to back
42
+ intermediate, # flat float32 output: width resized + col-cropped, height untouched
43
+ input_offsets, # input_offsets[image] = where that image starts in input_pixels
44
+ intermediate_offsets, # same idea for the intermediate buffer
45
+ heights, widths, # per-image input height / width
46
+ resize_widths, # per-image width to resize to (before cropping)
47
+ crop_lefts, # per-image left offset of the centered crop
48
+ crop_w, # output (crop) width, same for every image
49
+ cubic_coeff,
50
+ CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
51
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
52
+ MAX_TAPS_COL: tl.constexpr,
53
+ ):
54
+ """Resize width to resize_width, keep only the cropped columns: uint8 (C,H,W) -> float (C,H,crop_w)."""
55
+ image_index = tl.program_id(0)
56
+ block_index = tl.program_id(1)
57
+ in_height = tl.load(heights + image_index)
58
+ in_width = tl.load(widths + image_index)
59
+ resize_width = tl.load(resize_widths + image_index)
60
+ crop_left = tl.load(crop_lefts + image_index)
61
+ input_start = tl.load(input_offsets + image_index)
62
+ intermediate_start = tl.load(intermediate_offsets + image_index)
63
+ in_width_f = in_width.to(tl.float32)
64
+
65
+ num_pixels = in_height * crop_w # every input row x every cropped output column
66
+ flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
67
+ active = flat_index < num_pixels
68
+ input_row = flat_index // crop_w
69
+ out_col = flat_index % crop_w
70
+ resize_col = out_col + crop_left # column in the (uncropped) resized image
71
+
72
+ filter_half = 2.0 if CUBIC else 1.0
73
+ col_scale = in_width_f / resize_width.to(tl.float32)
74
+ col_filter_scale = tl.maximum(col_scale, 1.0) if ANTIALIAS else 1.0
75
+ col_support = filter_half * col_filter_scale
76
+ col_inv_scale = 1.0 / col_filter_scale
77
+ src_center_col = col_scale * (resize_col.to(tl.float32) + 0.5)
78
+ first_tap_col = tl.floor(src_center_col - col_support + 0.5)
79
+
80
+ col_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
81
+ for tap in tl.static_range(MAX_TAPS_COL):
82
+ tap_col = first_tap_col + tap
83
+ weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
84
+ if ANTIALIAS:
85
+ weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
86
+ col_weight_sum += weight
87
+
88
+ input_plane = (in_height * in_width).to(tl.int64)
89
+ intermediate_plane = (in_height * crop_w).to(tl.int64)
90
+ in_width_i64 = in_width.to(tl.int64)
91
+ crop_w_i64 = crop_w.to(tl.int64)
92
+ input_row_i64 = input_row.to(tl.int64)
93
+ for channel in tl.static_range(CHANNELS):
94
+ input_row_base = input_start + channel * input_plane + input_row_i64 * in_width_i64
95
+ accumulator = tl.zeros([BLOCK], dtype=tl.float32)
96
+ for tap in tl.static_range(MAX_TAPS_COL):
97
+ tap_col = first_tap_col + tap
98
+ weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
99
+ if ANTIALIAS:
100
+ weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
101
+ clamped_tap_col = tl.minimum(tl.maximum(tap_col.to(tl.int32), 0), in_width - 1).to(tl.int64)
102
+ pixel = tl.load(input_pixels + input_row_base + clamped_tap_col, mask=active, other=0).to(tl.float32)
103
+ accumulator += weight * pixel
104
+ accumulator = accumulator / col_weight_sum
105
+ write_index = intermediate_start + channel * intermediate_plane + input_row_i64 * crop_w_i64 + out_col
106
+ tl.store(intermediate + write_index, accumulator, mask=active)
107
+
108
+
109
+ @triton.jit
110
+ def _vertical_resize_normalize_kernel(
111
+ intermediate, # float32 from the horizontal step: (C, H, crop_w) per image
112
+ output, # final (N, C, crop_h, crop_w) float32
113
+ intermediate_offsets,
114
+ heights, # per-image input height (the intermediate still has H rows)
115
+ resize_heights, # per-image height to resize to (before cropping)
116
+ crop_tops, # per-image top offset of the centered crop
117
+ means, stds, # per-channel normalization, rescale already folded in
118
+ crop_h, crop_w,
119
+ cubic_coeff,
120
+ CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
121
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
122
+ MAX_TAPS_ROW: tl.constexpr,
123
+ ):
124
+ """Resize height to resize_height, keep cropped rows, normalize: float (C,H,crop_w) -> (C,crop_h,crop_w)."""
125
+ image_index = tl.program_id(0)
126
+ block_index = tl.program_id(1)
127
+ in_height = tl.load(heights + image_index)
128
+ resize_height = tl.load(resize_heights + image_index)
129
+ crop_top = tl.load(crop_tops + image_index)
130
+ intermediate_start = tl.load(intermediate_offsets + image_index)
131
+ in_height_f = in_height.to(tl.float32)
132
+
133
+ num_pixels = crop_h * crop_w
134
+ flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
135
+ active = flat_index < num_pixels
136
+ out_row = flat_index // crop_w
137
+ out_col = flat_index % crop_w
138
+ resize_row = out_row + crop_top # row in the (uncropped) resized image
139
+
140
+ filter_half = 2.0 if CUBIC else 1.0
141
+ row_scale = in_height_f / resize_height.to(tl.float32)
142
+ row_filter_scale = tl.maximum(row_scale, 1.0) if ANTIALIAS else 1.0
143
+ row_support = filter_half * row_filter_scale
144
+ row_inv_scale = 1.0 / row_filter_scale
145
+ src_center_row = row_scale * (resize_row.to(tl.float32) + 0.5)
146
+ first_tap_row = tl.floor(src_center_row - row_support + 0.5)
147
+
148
+ row_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
149
+ for tap in tl.static_range(MAX_TAPS_ROW):
150
+ tap_row = first_tap_row + tap
151
+ weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
152
+ if ANTIALIAS:
153
+ weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
154
+ row_weight_sum += weight
155
+
156
+ intermediate_plane = (in_height * crop_w).to(tl.int64)
157
+ crop_w_i64 = crop_w.to(tl.int64)
158
+ out_col_i64 = out_col.to(tl.int64)
159
+ for channel in tl.static_range(CHANNELS):
160
+ channel_base = intermediate_start + channel * intermediate_plane
161
+ accumulator = tl.zeros([BLOCK], dtype=tl.float32)
162
+ for tap in tl.static_range(MAX_TAPS_ROW):
163
+ tap_row = first_tap_row + tap
164
+ weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
165
+ if ANTIALIAS:
166
+ weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
167
+ clamped_tap_row = tl.minimum(tl.maximum(tap_row.to(tl.int32), 0), in_height - 1).to(tl.int64)
168
+ pixel = tl.load(intermediate + channel_base + clamped_tap_row * crop_w_i64 + out_col_i64, mask=active, other=0.0)
169
+ accumulator += weight * pixel
170
+ accumulator = accumulator / row_weight_sum
171
+ mean = tl.load(means + channel)
172
+ std = tl.load(stds + channel)
173
+ accumulator = (accumulator - mean) / std
174
+ write_index = ((image_index * CHANNELS + channel) * crop_h + out_row) * crop_w + out_col
175
+ tl.store(output + write_index, accumulator, mask=active)
176
+
177
+
178
+ def _axis_max_taps(in_sizes, resize_sizes, interp, antialias):
179
+ """Widest 1D window over the batch for one axis = ceil(support) * 2 + 1, support uses in/resize."""
180
+ import math
181
+
182
+ interp_half = 2.0 if interp == "bicubic" else 1.0
183
+ worst = 0
184
+ for in_size, resize_size in zip(in_sizes, resize_sizes):
185
+ scale = in_size / resize_size
186
+ eff = max(scale, 1.0) if antialias else 1.0
187
+ worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
188
+ return worst
189
+
190
+
191
+ def _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
192
+ mean, std, rescale, interp, antialias, block):
193
+ """Core driver: resize each image to its (resize_h, resize_w), keep the centered crop, normalize."""
194
+ import torch
195
+
196
+ device = images[0].device
197
+ num_images = len(images)
198
+ cubic_coeff = -0.5 if antialias else -0.75
199
+ in_heights = [int(img.shape[1]) for img in images]
200
+ in_widths = [int(img.shape[2]) for img in images]
201
+ max_taps_row = _axis_max_taps(in_heights, resize_heights, interp, antialias)
202
+ max_taps_col = _axis_max_taps(in_widths, resize_widths, interp, antialias)
203
+ means, stds = fold_mean_std(mean, std, rescale, device)
204
+
205
+ input_pixels, input_offsets, heights, widths, channels = pack_images(images, dtype=torch.uint8)
206
+
207
+ intermediate_offsets_list, cursor, tallest = [], 0, 0
208
+ for height in in_heights:
209
+ intermediate_offsets_list.append(cursor)
210
+ cursor += channels * height * crop_w
211
+ tallest = max(tallest, height)
212
+ intermediate = torch.empty(cursor, device=device, dtype=torch.float32)
213
+
214
+ intermediate_offsets = torch.tensor(intermediate_offsets_list, device=device, dtype=torch.int64)
215
+ resize_heights_t = torch.tensor(resize_heights, device=device, dtype=torch.int32)
216
+ resize_widths_t = torch.tensor(resize_widths, device=device, dtype=torch.int32)
217
+ crop_tops_t = torch.tensor(crop_tops, device=device, dtype=torch.int32)
218
+ crop_lefts_t = torch.tensor(crop_lefts, device=device, dtype=torch.int32)
219
+
220
+ horizontal_grid = (num_images, triton.cdiv(tallest * crop_w, block))
221
+ _horizontal_resize_kernel[horizontal_grid](
222
+ input_pixels, intermediate, input_offsets, intermediate_offsets, heights, widths,
223
+ resize_widths_t, crop_lefts_t, crop_w, cubic_coeff,
224
+ CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
225
+ MAX_TAPS_COL=max_taps_col,
226
+ )
227
+
228
+ output = torch.empty((num_images, channels, crop_h, crop_w), device=device, dtype=torch.float32)
229
+ vertical_grid = (num_images, triton.cdiv(crop_h * crop_w, block))
230
+ _vertical_resize_normalize_kernel[vertical_grid](
231
+ intermediate, output, intermediate_offsets, heights, resize_heights_t, crop_tops_t, means, stds,
232
+ crop_h, crop_w, cubic_coeff,
233
+ CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
234
+ MAX_TAPS_ROW=max_taps_row,
235
+ )
236
+ return output
237
+
238
+
239
+ def _aspect_preserving_size(in_h, in_w, shortest_edge):
240
+ """transformers shortest-edge rule: short side -> shortest_edge, long side truncated (int(), not round)."""
241
+ if in_h <= in_w:
242
+ return shortest_edge, int(in_w * shortest_edge / in_h)
243
+ return int(in_h * shortest_edge / in_w), shortest_edge
244
+
245
+
246
+ def separable_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
247
+ """Resize to (out_h, out_w) and normalize (no crop)."""
248
+ images = list(images)
249
+ n = len(images)
250
+ return _run_separable(images, [out_h] * n, [out_w] * n, [0] * n, [0] * n, out_h, out_w,
251
+ mean, std, rescale, interp, antialias, block)
252
+
253
+
254
+ def separable_resize_crop_normalize(images, resize_size, crop_size, mean, std, rescale, interp, antialias,
255
+ resize_mode="square", block: int = 256):
256
+ """Resize then center-crop then normalize.
257
+
258
+ resize_mode="square": resize_size is (resize_h, resize_w) applied to every image.
259
+ resize_mode="shortest_edge": resize_size is an int; each image is resized aspect-preserving
260
+ so its short side equals it, then center-cropped to crop_size.
261
+ """
262
+ images = list(images)
263
+ crop_h, crop_w = crop_size
264
+ resize_heights, resize_widths = [], []
265
+ for img in images:
266
+ in_h, in_w = int(img.shape[1]), int(img.shape[2])
267
+ if resize_mode == "shortest_edge":
268
+ rh, rw = _aspect_preserving_size(in_h, in_w, int(resize_size))
269
+ elif resize_mode == "square":
270
+ rh, rw = int(resize_size[0]), int(resize_size[1])
271
+ else:
272
+ raise ValueError(f"resize_mode must be 'square' or 'shortest_edge', got {resize_mode!r}")
273
+ if rh < crop_h or rw < crop_w:
274
+ raise ValueError(f"resize size ({rh},{rw}) smaller than crop ({crop_h},{crop_w})")
275
+ resize_heights.append(rh)
276
+ resize_widths.append(rw)
277
+ crop_tops = [(rh - crop_h) // 2 for rh in resize_heights]
278
+ crop_lefts = [(rw - crop_w) // 2 for rw in resize_widths]
279
+ return _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
280
+ mean, std, rescale, interp, antialias, block)
example.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch",
5
+ # "triton",
6
+ # "kernels",
7
+ # ]
8
+ # ///
9
+ """Minimal smoke test of the published kernel via get_kernel (run on a CUDA box)."""
10
+
11
+ import torch
12
+ from kernels import get_kernel
13
+
14
+
15
+ kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ images = [
19
+ torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
20
+ for h, w in [(640, 480), (800, 600), (384, 1024)]
21
+ ]
22
+
23
+ pixel_values = kir.resize_normalize(
24
+ images,
25
+ size=384,
26
+ image_mean=[0.5, 0.5, 0.5],
27
+ image_std=[0.5, 0.5, 0.5],
28
+ rescale_factor=1 / 255,
29
+ resample="bicubic",
30
+ antialias=True,
31
+ )
32
+ print(f"{len(images)} ragged images -> {tuple(pixel_values.shape)} {pixel_values.dtype}")
example_transformers.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = ["torch", "triton", "kernels", "transformers", "torchvision"]
4
+ # ///
5
+ """Drop-in: use the kernel as the resize+normalize stage of a transformers fast processor.
6
+
7
+ There is no `use_kernels=True` hook for image processors (that machinery swaps nn.Module
8
+ layer forwards inside the model, not processor code). So the usable path is to read the
9
+ processor's config and call the kernel directly. `preprocess_with_kernel` below is the whole
10
+ adapter — copy it into your code.
11
+
12
+ Run on a CUDA box:
13
+ python example_transformers.py
14
+ """
15
+
16
+ import torch
17
+ from kernels import get_kernel
18
+ from transformers import AutoImageProcessor, AutoModel
19
+
20
+
21
+ kernel_image_resize = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
22
+
23
+ _PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
24
+
25
+
26
+ def preprocess_with_kernel(processor, images):
27
+ """Run the kernel using `processor`'s own config; returns pixel_values like processor(images).
28
+
29
+ Handles fixed-size resize, square-resize + center-crop, and shortest-edge resize + center-crop
30
+ (CLIP / DINOv2). Does not handle padding processors.
31
+ """
32
+ size = processor.size
33
+ if getattr(processor, "do_pad", False):
34
+ raise ValueError("kernel does not pad; this processor needs a pad step")
35
+ if not getattr(processor, "do_normalize", True):
36
+ raise ValueError("processor does not normalize (rescale only); kernel always normalizes")
37
+ if getattr(processor, "do_flip_channel_order", False):
38
+ raise ValueError("processor flips channels to BGR; kernel keeps RGB")
39
+ resample = _PIL_RESAMPLE[int(processor.resample)]
40
+ antialias = bool(getattr(processor, "antialias", True))
41
+ rescale = float(processor.rescale_factor) if getattr(processor, "do_rescale", True) else 1.0
42
+ mean, std = processor.image_mean, processor.image_std
43
+ crop = processor.crop_size if getattr(processor, "do_center_crop", False) else None
44
+ common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)
45
+
46
+ if "shortest_edge" in size:
47
+ if crop is None:
48
+ raise ValueError("shortest-edge resize without a crop gives variable-size output")
49
+ return kernel_image_resize.resize_normalize(
50
+ images, size["shortest_edge"], mean, std,
51
+ crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common,
52
+ )
53
+ if crop is not None and (crop["height"] != size["height"] or crop["width"] != size["width"]):
54
+ return kernel_image_resize.resize_normalize(
55
+ images, (size["height"], size["width"]), mean, std,
56
+ crop_size=(crop["height"], crop["width"]), resize_mode="square", **common,
57
+ )
58
+ return kernel_image_resize.resize_normalize(images, (size["height"], size["width"]), mean, std, **common)
59
+
60
+
61
+ def main():
62
+ device = "cuda" if torch.cuda.is_available() else "cpu"
63
+ model_id = "google/siglip2-base-patch16-224"
64
+ processor = AutoImageProcessor.from_pretrained(model_id, backend="torchvision")
65
+ model = AutoModel.from_pretrained(model_id).to(device).eval()
66
+
67
+ images = [
68
+ torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
69
+ for h, w in [(640, 480), (800, 600), (384, 1024)]
70
+ ]
71
+
72
+ pixel_values = preprocess_with_kernel(processor, images)
73
+ print(f"{len(images)} ragged images -> pixel_values {tuple(pixel_values.shape)} {pixel_values.dtype}")
74
+
75
+ with torch.no_grad():
76
+ features = model.vision_model(pixel_values=pixel_values.to(model.dtype)).pooler_output
77
+ print(f"vision features: {tuple(features.shape)}")
78
+
79
+ # parity vs the real processor (float-vs-uint8 resize -> small, expected gap)
80
+ reference = processor(images, return_tensors="pt", device=device)["pixel_values"].to(device)
81
+ print(f"max|Δ| pixel_values vs processor: {(pixel_values - reference).abs().max().item():.2e}")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
publish.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Publish the universal kernel to the Hub as a `kernel` repo type.
3
+ #
4
+ # IMPORTANT: kernel repos are repo_type="kernel" (served under /api/kernels/), NOT model repos.
5
+ # get_kernel() queries repo_type="kernel", so uploading with --repo-type model gives a 404.
6
+ #
7
+ # A universal kernel needs no compilation: the build is just a copy of the source package into the
8
+ # variant directory get_kernel resolves (build/torch-universal/<name>/).
9
+ set -euo pipefail
10
+
11
+ REPO_ID="${1:-Molbap/kernel_image_resize}"
12
+ NAME="kernel_image_resize"
13
+ HERE="$(cd "$(dirname "$0")" && pwd)"
14
+
15
+ rm -rf "$HERE/build"
16
+ mkdir -p "$HERE/build/torch-universal"
17
+ cp -r "$HERE/torch-ext/$NAME" "$HERE/build/torch-universal/$NAME"
18
+ find "$HERE/build" -name __pycache__ -type d -exec rm -rf {} + 2>/dev/null || true
19
+
20
+ echo "built build/torch-universal/$NAME"
21
+
22
+ # Create the kernel repo and upload. Uses the Python API because it reliably accepts
23
+ # repo_type="kernel" (the hf CLI's repo-type choices can be stricter).
24
+ python - "$REPO_ID" "$HERE" <<'PY'
25
+ import sys
26
+
27
+ # Some huggingface_hub versions know the `kernel` repo type (constant + create_repo) but still
28
+ # validate upload_folder against the old REPO_TYPES list. Register it so the upload validator passes.
29
+ import huggingface_hub.constants as hfc
30
+ if "kernel" not in hfc.REPO_TYPES:
31
+ hfc.REPO_TYPES = list(hfc.REPO_TYPES) + ["kernel"]
32
+
33
+ from huggingface_hub import create_repo, upload_folder
34
+
35
+ repo_id, folder = sys.argv[1], sys.argv[2]
36
+ create_repo(repo_id, repo_type="kernel", exist_ok=True)
37
+ upload_folder(
38
+ repo_id=repo_id,
39
+ repo_type="kernel",
40
+ folder_path=folder,
41
+ ignore_patterns=["__pycache__/*", "*.pyc", "result", ".git/*", "build/torch-universal/*/__pycache__/*"],
42
+ )
43
+ print(f"uploaded {folder} -> {repo_id} (repo_type=kernel)")
44
+ PY
resultcompat ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model verdict pixel max|Δ| feature max|Δ| (rel)
2
+ openai/clip-vit-base-patch32 OK 7.53e-03 1.53e-02 (0.2%)
3
+ google/vit-base-patch16-224 OK 3.93e-03 6.87e-03 (0.8%)
4
+ apple/mobilevit-small SKIP: no normalize (rescale only)
5
+ facebook/dinov2-small OK 1.41e-01 1.91e-02 (0.2%)
6
+ google/siglip-so400m-patch14-384 OK 1.58e-01 9.28e-03 (0.1%)
7
+ facebook/dinov3-vitb16-pretrain-lvd1689m OK 4.99e-05 8.08e-05 (0.0%)
8
+ microsoft/swinv2-tiny-patch4-window16-256 OK 8.75e-03 1.27e-02 (0.6%)
9
+ google/siglip2-base-patch16-224 OK 3.93e-03 1.31e-02 (0.2%)
10
+ microsoft/resnet-50 SKIP: shortest_edge without crop (variable output)
11
+ nvidia/segformer-b0-finetuned-ade-512-512 OK 8.75e-03 4.33e-02 (0.4%)
12
+ facebook/convnextv2-tiny-22k-384 SKIP: shortest_edge without crop (variable output)
13
+ google/mobilenet_v2_1.0_224 OK 3.92e-03 3.90e-02 (0.7%)
14
+ facebook/convnext-tiny-224 SKIP: shortest_edge without crop (variable output)
15
+ google/efficientnet-b0 SKIP: resample 0
16
+ microsoft/beit-base-patch16-224-pt22k-ft22k OK 3.93e-03 3.19e-02 (0.9%)
tests/test_resize_normalize.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs.
2
+
3
+ Run locally from the repo root with the package on the path:
4
+ PYTHONPATH=torch-ext pytest tests/ -q
5
+ CUDA is required (Triton); tests skip on CPU.
6
+ """
7
+
8
+ import pytest
9
+ import torch
10
+ import torchvision.transforms.v2.functional as tvF
11
+ from torchvision.transforms import InterpolationMode
12
+
13
+ from kernel_image_resize import resize_normalize
14
+
15
+
16
+ _TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}
17
+
18
+ MEAN = [0.48145466, 0.4578275, 0.40821073]
19
+ STD = [0.26862954, 0.26130258, 0.27577711]
20
+ RESCALE = 1.0 / 255.0
21
+
22
+
23
+ def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
24
+ g = torch.Generator(device="cpu").manual_seed(seed)
25
+ images = []
26
+ for _ in range(n):
27
+ h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
28
+ w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
29
+ images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
30
+ return images
31
+
32
+
33
+ def _torchvision_reference(images, out_h, out_w, interp, antialias):
34
+ mode = _TV_INTERP[interp]
35
+ mean = torch.tensor(MEAN, device=images[0].device).view(3, 1, 1)
36
+ std = torch.tensor(STD, device=images[0].device).view(3, 1, 1)
37
+ outs = []
38
+ for img in images:
39
+ r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
40
+ outs.append((r * RESCALE - mean) / std)
41
+ return torch.stack(outs)
42
+
43
+
44
+ @pytest.mark.kernels_ci
45
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
46
+ @pytest.mark.parametrize("backend", ["fused", "separable"])
47
+ @pytest.mark.parametrize("interp,antialias", [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)])
48
+ def test_parity_vs_torchvision(backend, interp, antialias):
49
+ device = torch.device("cuda")
50
+ images = _ragged_images(8, device)
51
+ out_h = out_w = 384
52
+ got = resize_normalize(
53
+ images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
54
+ )
55
+ ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
56
+ max_abs = (got - ref).abs().max().item()
57
+ assert max_abs < 3e-3, f"{backend}/{interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
58
+
59
+
60
+ @pytest.mark.kernels_ci
61
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
62
+ def test_stacked_tensor_input():
63
+ device = torch.device("cuda")
64
+ images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8, device=device)
65
+ got = resize_normalize(images, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)
66
+ assert got.shape == (4, 3, 224, 224)
67
+
68
+
69
+ def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
70
+ mode = _TV_INTERP[interp]
71
+ mean = torch.tensor(MEAN, device=images[0].device).view(3, 1, 1)
72
+ std = torch.tensor(STD, device=images[0].device).view(3, 1, 1)
73
+ outs = []
74
+ for img in images:
75
+ in_h, in_w = img.shape[1], img.shape[2]
76
+ if in_h <= in_w:
77
+ rh, rw = shortest_edge, int(in_w * shortest_edge / in_h)
78
+ else:
79
+ rh, rw = int(in_h * shortest_edge / in_w), shortest_edge
80
+ r = tvF.resize(img.float(), [rh, rw], interpolation=mode, antialias=antialias)
81
+ r = tvF.center_crop(r, [crop, crop])
82
+ outs.append((r * RESCALE - mean) / std)
83
+ return torch.stack(outs)
84
+
85
+
86
+ @pytest.mark.kernels_ci
87
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
88
+ @pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
89
+ def test_shortest_edge_crop_parity(interp, antialias):
90
+ device = torch.device("cuda")
91
+ images = _ragged_images(8, device)
92
+ shortest_edge, crop = 256, 224
93
+ got = resize_normalize(
94
+ images, shortest_edge, MEAN, STD, RESCALE, resample=interp, antialias=antialias,
95
+ crop_size=(crop, crop), resize_mode="shortest_edge",
96
+ )
97
+ ref = _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias)
98
+ assert got.shape == (8, 3, crop, crop)
99
+ max_abs = (got - ref).abs().max().item()
100
+ assert max_abs < 3e-3, f"shortest_edge+crop {interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
101
+
102
+
103
+ @pytest.mark.kernels_ci
104
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
105
+ def test_fused_matches_separable():
106
+ device = torch.device("cuda")
107
+ images = _ragged_images(6, device)
108
+ common = dict(size=(256, 256), image_mean=MEAN, image_std=STD, rescale_factor=RESCALE, resample="bicubic", antialias=True)
109
+ fused = resize_normalize(images, backend="fused", **common)
110
+ separable = resize_normalize(images, backend="separable", **common)
111
+ max_abs = (fused - separable).abs().max().item()
112
+ assert max_abs < 3e-3, f"fused vs separable: max|Δ|={max_abs:.2e}"
torch-ext/kernel_image_resize/__init__.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Resize + rescale + normalize for transformers fast image processors, as a Triton kernel.
2
+
3
+ resize -> fold(rescale, normalize) in one GPU pipeline: CHW uint8 images in,
4
+ (N, C, out_h, out_w) normalized float out, no full-resolution float intermediate.
5
+
6
+ - resize_normalize — stacked (N, C, H, W) tensor or a list of CHW images.
7
+ - resize_normalize_ragged — same kernels; takes a list of different-H/W CHW tensors.
8
+
9
+ backend="separable" (default): two-pass uint8, taps+taps. backend="fused": single 2D
10
+ launch, taps*taps. Both parity <=1e-4 vs torchvision-float.
11
+
12
+ from kernels import get_kernel
13
+ kir = get_kernel("Molbap/kernel_image_resize")
14
+ pixel_values = kir.resize_normalize(
15
+ images, size=384, image_mean=[...], image_std=[...], resample="bicubic", antialias=True,
16
+ )
17
+ """
18
+
19
+ from ._fused import fused_resize_normalize
20
+ from ._pack import PIL_RESAMPLE_TO_INTERP, as_image_list
21
+ from ._separable import separable_resize_crop_normalize, separable_resize_normalize
22
+
23
+
24
+ def _normalize_size(size) -> tuple[int, int]:
25
+ if isinstance(size, int):
26
+ return size, size
27
+ if isinstance(size, dict):
28
+ if "height" in size and "width" in size:
29
+ return int(size["height"]), int(size["width"])
30
+ raise ValueError(f"size dict must hold 'height'/'width' for a fixed resize, got {size}")
31
+ out_h, out_w = size
32
+ return int(out_h), int(out_w)
33
+
34
+
35
+ def _normalize_resample(resample) -> str:
36
+ if isinstance(resample, str):
37
+ if resample not in ("bilinear", "bicubic"):
38
+ raise ValueError(f"resample must be 'bilinear' or 'bicubic', got {resample!r}")
39
+ return resample
40
+ interp = PIL_RESAMPLE_TO_INTERP.get(int(resample))
41
+ if interp is None:
42
+ raise ValueError(f"unsupported PIL resample code {resample}")
43
+ return interp
44
+
45
+
46
+ def resize_normalize(
47
+ images,
48
+ size,
49
+ image_mean,
50
+ image_std,
51
+ rescale_factor: float = 1.0 / 255.0,
52
+ resample="bilinear",
53
+ antialias: bool = False,
54
+ backend: str = "separable",
55
+ block: int = 256,
56
+ crop_size=None,
57
+ resize_mode: str = "square",
58
+ ):
59
+ """Resize, optionally center-crop, rescale and normalize — one GPU pipeline.
60
+
61
+ Args:
62
+ images: a stacked `(N, C, H, W)` uint8/float tensor, or a list of CHW tensors (ragged).
63
+ size: resize target. With no crop: int (square), `(height, width)`, or `{"height","width"}`.
64
+ With `resize_mode="shortest_edge"`: an int, the short side after aspect-preserving resize.
65
+ image_mean, image_std: per-channel normalization stats (length C).
66
+ rescale_factor: folded into mean/std so the kernel does `(x*rescale - mean)/std`.
67
+ resample: "bilinear" / "bicubic", or a PIL resample int (0/2 -> bilinear, 3 -> bicubic).
68
+ antialias: match the ViT/CLIP/SigLIP default (`True` for those processors).
69
+ backend: "separable" (default) or "fused" (2D reference, no crop support).
70
+ crop_size: `None` (no crop), int (square), or `(crop_h, crop_w)`. Center crop after resize.
71
+ resize_mode: "square" (resize to `size`) or "shortest_edge" (aspect-preserving, needs a crop).
72
+ """
73
+ interp = _normalize_resample(resample)
74
+ image_list = as_image_list(images)
75
+
76
+ if crop_size is not None or resize_mode == "shortest_edge":
77
+ crop_h, crop_w = _normalize_size(crop_size if crop_size is not None else size)
78
+ resize_arg = int(size) if resize_mode == "shortest_edge" else _normalize_size(size)
79
+ return separable_resize_crop_normalize(
80
+ image_list, resize_arg, (crop_h, crop_w), image_mean, image_std, rescale_factor,
81
+ interp, antialias, resize_mode, block,
82
+ )
83
+
84
+ out_h, out_w = _normalize_size(size)
85
+ if backend == "fused":
86
+ return fused_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
87
+ if backend == "separable":
88
+ return separable_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
89
+ raise ValueError(f"backend must be 'fused' or 'separable', got {backend!r}")
90
+
91
+
92
+ def resize_normalize_ragged(
93
+ images,
94
+ size,
95
+ image_mean,
96
+ image_std,
97
+ rescale_factor: float = 1.0 / 255.0,
98
+ resample="bilinear",
99
+ antialias: bool = False,
100
+ backend: str = "separable",
101
+ block: int = 256,
102
+ ):
103
+ """Variant taking a list of different-H/W CHW tensors. Same kernels as `resize_normalize`."""
104
+ if isinstance(images, list):
105
+ image_list = images
106
+ else:
107
+ raise ValueError("resize_normalize_ragged expects a list of CHW tensors; use resize_normalize for a stacked tensor")
108
+ return resize_normalize(
109
+ image_list, size, image_mean, image_std, rescale_factor, resample, antialias, backend, block
110
+ )
111
+
112
+
113
+ __all__ = ["resize_normalize", "resize_normalize_ragged"]
torch-ext/kernel_image_resize/_fused.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fused 2D resize+rescale+normalize over a ragged batch, single launch.
2
+
3
+ One program owns one image and a BLOCK of its output pixels, gathers a
4
+ MAX_TAPS_H × MAX_TAPS_W window, applies the separable weights as a 2D product, then folds
5
+ rescale+normalize. taps×taps loads per output pixel.
6
+
7
+ Resampling-weight formula (PyTorch aten UpSampleKernel):
8
+ scale = in / out
9
+ support = interp_half * (scale if antialias and scale > 1 else 1) # interp_half: 1 linear, 2 cubic
10
+ center = scale * (i + 0.5)
11
+ weight = filter((tap - center + 0.5) / eff), renormalized over the realized window
12
+ """
13
+
14
+ import triton
15
+ import triton.language as tl
16
+
17
+ from ._pack import fold_mean_std, max_taps, pack_images
18
+
19
+
20
+ @triton.jit
21
+ def _resample_weight(arg, cubic_a, CUBIC: tl.constexpr):
22
+ """Interpolation filter at `arg` (coordinate distance already divided by support)."""
23
+ ax = tl.abs(arg)
24
+ if CUBIC: # Keys cubic convolution kernel, support 2
25
+ ax2 = ax * ax
26
+ ax3 = ax2 * ax
27
+ inner = (cubic_a + 2.0) * ax3 - (cubic_a + 3.0) * ax2 + 1.0 # |x| <= 1
28
+ outer = cubic_a * ax3 - 5.0 * cubic_a * ax2 + 8.0 * cubic_a * ax - 4.0 * cubic_a # 1 < |x| < 2
29
+ return tl.where(ax <= 1.0, inner, tl.where(ax < 2.0, outer, 0.0))
30
+ return tl.maximum(1.0 - ax, 0.0) # triangle (bilinear), support 1
31
+
32
+
33
+ @triton.jit
34
+ def _resize_normalize_kernel(
35
+ in_ptr, out_ptr, offsets_ptr, heights_ptr, widths_ptr, mean_ptr, std_ptr,
36
+ out_h, out_w, cubic_a,
37
+ C: tl.constexpr, BLOCK: tl.constexpr,
38
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
39
+ MAX_TAPS_H: tl.constexpr, MAX_TAPS_W: tl.constexpr,
40
+ ):
41
+ n = tl.program_id(0)
42
+ blk = tl.program_id(1)
43
+ H = tl.load(heights_ptr + n)
44
+ W = tl.load(widths_ptr + n)
45
+ off = tl.load(offsets_ptr + n)
46
+ Hf = H.to(tl.float32)
47
+ Wf = W.to(tl.float32)
48
+
49
+ npix = out_h * out_w
50
+ pos = blk * BLOCK + tl.arange(0, BLOCK)
51
+ mask = pos < npix
52
+ oy = pos // out_w
53
+ ox = pos % out_w
54
+
55
+ interp_half = 2.0 if CUBIC else 1.0
56
+ scale_h = Hf / out_h
57
+ scale_w = Wf / out_w
58
+ eff_h = tl.maximum(scale_h, 1.0) if ANTIALIAS else 1.0
59
+ eff_w = tl.maximum(scale_w, 1.0) if ANTIALIAS else 1.0
60
+ support_h = interp_half * eff_h
61
+ support_w = interp_half * eff_w
62
+ inv_h = 1.0 / eff_h
63
+ inv_w = 1.0 / eff_w
64
+
65
+ center_y = scale_h * (oy.to(tl.float32) + 0.5)
66
+ center_x = scale_w * (ox.to(tl.float32) + 0.5)
67
+ ystart = tl.floor(center_y - support_h + 0.5)
68
+ xstart = tl.floor(center_x - support_w + 0.5)
69
+
70
+ sum_wy = tl.zeros([BLOCK], dtype=tl.float32)
71
+ for ty in tl.static_range(MAX_TAPS_H):
72
+ yy = ystart + ty
73
+ wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
74
+ if ANTIALIAS:
75
+ wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
76
+ sum_wy += wy
77
+ sum_wx = tl.zeros([BLOCK], dtype=tl.float32)
78
+ for tx in tl.static_range(MAX_TAPS_W):
79
+ xx = xstart + tx
80
+ wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
81
+ if ANTIALIAS:
82
+ wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
83
+ sum_wx += wx
84
+ denom = sum_wy * sum_wx
85
+
86
+ plane = (H * W).to(tl.int64)
87
+ Wl = W.to(tl.int64)
88
+ for c in tl.static_range(C):
89
+ base = off + c * plane
90
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
91
+ for ty in tl.static_range(MAX_TAPS_H):
92
+ yy = ystart + ty
93
+ wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
94
+ if ANTIALIAS:
95
+ wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
96
+ yidx = tl.minimum(tl.maximum(yy.to(tl.int32), 0), H - 1).to(tl.int64)
97
+ row = base + yidx * Wl
98
+ for tx in tl.static_range(MAX_TAPS_W):
99
+ xx = xstart + tx
100
+ wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
101
+ if ANTIALIAS:
102
+ wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
103
+ xidx = tl.minimum(tl.maximum(xx.to(tl.int32), 0), W - 1).to(tl.int64)
104
+ pix = tl.load(in_ptr + row + xidx, mask=mask, other=0.0)
105
+ acc += wy * wx * pix
106
+ acc = acc / denom
107
+ m = tl.load(mean_ptr + c)
108
+ s = tl.load(std_ptr + c)
109
+ acc = (acc - m) / s
110
+ oidx = ((n * C + c) * out_h + oy) * out_w + ox
111
+ tl.store(out_ptr + oidx, acc, mask=mask)
112
+
113
+
114
+ def fused_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
115
+ """Single fused launch over a ragged packed buffer -> (N, C, out_h, out_w) normalized float."""
116
+ import torch
117
+
118
+ images = list(images)
119
+ device = images[0].device
120
+ n = len(images)
121
+ cubic_a = -0.5 if antialias else -0.75 # PIL coeff under antialias, Keys coeff otherwise
122
+ max_taps_h = max_taps(images, out_h, 1, interp, antialias)
123
+ max_taps_w = max_taps(images, out_w, 2, interp, antialias)
124
+ mean_t, std_t = fold_mean_std(mean, std, rescale, device)
125
+
126
+ in_buf, offsets_t, heights_t, widths_t, c = pack_images(images)
127
+ out = torch.empty((n, c, out_h, out_w), device=device, dtype=torch.float32)
128
+ grid = (n, triton.cdiv(out_h * out_w, block))
129
+ _resize_normalize_kernel[grid](
130
+ in_buf, out, offsets_t, heights_t, widths_t, mean_t, std_t,
131
+ out_h, out_w, cubic_a, C=c, BLOCK=block,
132
+ CUBIC=(interp == "bicubic"), ANTIALIAS=antialias, MAX_TAPS_H=max_taps_h, MAX_TAPS_W=max_taps_w,
133
+ )
134
+ return out
torch-ext/kernel_image_resize/_pack.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Ragged packing + resampling helpers shared by the fused and separable backends."""
2
+
3
+ import math
4
+
5
+ import torch
6
+
7
+
8
+ PIL_RESAMPLE_TO_INTERP = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
9
+
10
+
11
+ def pack_images(
12
+ images: list[torch.Tensor], dtype: torch.dtype = torch.float32
13
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
14
+ """Concatenate a ragged list of CHW images into one flat buffer of `dtype`.
15
+
16
+ Returns (in_buf, offsets, heights, widths, channels); offsets[n] is the element index
17
+ where image n starts.
18
+ """
19
+ device = images[0].device
20
+ channels = images[0].shape[0]
21
+ flats, offsets, heights, widths, cur = [], [], [], [], 0
22
+ for img in images:
23
+ ic, ih, iw = img.shape
24
+ if ic != channels:
25
+ raise ValueError(f"all images must share channel count {channels}, got {ic}")
26
+ flats.append(img.reshape(-1).to(dtype))
27
+ offsets.append(cur)
28
+ heights.append(ih)
29
+ widths.append(iw)
30
+ cur += ic * ih * iw
31
+ in_buf = torch.cat(flats)
32
+ offsets_t = torch.tensor(offsets, device=device, dtype=torch.int64)
33
+ heights_t = torch.tensor(heights, device=device, dtype=torch.int32)
34
+ widths_t = torch.tensor(widths, device=device, dtype=torch.int32)
35
+ return in_buf, offsets_t, heights_t, widths_t, channels
36
+
37
+
38
+ def fold_mean_std(mean, std, rescale: float, device) -> tuple[torch.Tensor, torch.Tensor]:
39
+ """Fold rescale into mean/std so the kernel does (x - m)/s == (x*rescale - mean)/std."""
40
+ mean_t = (torch.tensor(mean, device=device, dtype=torch.float32) / rescale).contiguous()
41
+ std_t = (torch.tensor(std, device=device, dtype=torch.float32) / rescale).contiguous()
42
+ return mean_t, std_t
43
+
44
+
45
+ def max_taps(images: list[torch.Tensor], out_size: int, axis_dim: int, interp: str, antialias: bool) -> int:
46
+ """Batch-wide worst-case tap count for one axis = ceil(support) * 2 + 1."""
47
+ interp_half = 2.0 if interp == "bicubic" else 1.0
48
+ worst = 0
49
+ for img in images:
50
+ scale = img.shape[axis_dim] / out_size
51
+ eff = max(scale, 1.0) if antialias else 1.0
52
+ worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
53
+ return worst
54
+
55
+
56
+ def as_image_list(images) -> list[torch.Tensor]:
57
+ """Accept a stacked (N, C, H, W) tensor or a list of CHW tensors; always return a list."""
58
+ if isinstance(images, torch.Tensor):
59
+ if images.dim() != 4:
60
+ raise ValueError(f"stacked input must be (N, C, H, W), got shape {tuple(images.shape)}")
61
+ return list(images)
62
+ return list(images)
torch-ext/kernel_image_resize/_separable.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Separable resize + center-crop + normalize over a ragged uint8 batch.
2
+
3
+ WHAT "RESIZE" DOES, CONCRETELY
4
+ Every output pixel is a weighted average of a small window of input pixels. When you shrink
5
+ an image a lot (with antialiasing) that window gets wide — e.g. 13 input pixels across and
6
+ 13 down, so 13x13 = 169 input pixels feed one output pixel.
7
+
8
+ FUSED vs SEPARABLE (the two backends in this package)
9
+ - FUSED (see _fused.py): for each output pixel, read the whole 2D window directly -> 169 reads.
10
+ - SEPARABLE (this file): do the resize as two 1D steps instead of one 2D step:
11
+ step 1 (horizontal): resize only the WIDTH -> an intermediate image
12
+ step 2 (vertical): resize only the HEIGHT -> the final image
13
+ Each step's window is 1D, so 13 + 13 = 26 reads per output pixel instead of 169. Same math,
14
+ far fewer reads. This is what PIL and torchvision do.
15
+
16
+ CENTER CROP (folded in, no extra pass)
17
+ Processors like CLIP / DINOv2 resize to a "resize size" and then keep only the centered
18
+ crop. We do not materialize the full resized image and slice it; instead each output pixel
19
+ of the CROP maps to a resize-image coordinate by adding the crop offset, and that maps back
20
+ to the input. So:
21
+ resize is described by (resize_height, resize_width) -- per image
22
+ crop is described by (crop_top, crop_left) -- per image, the centered offset
23
+ output size is (crop_height, crop_width) -- the same for every image
24
+ When there is no crop, resize size == crop size and the offsets are 0 (the plain resize).
25
+ The resize SCALE uses the resize size; only the output coordinate is shifted by the crop.
26
+
27
+ uint8 input + float intermediate; each 1D step renormalizes its own weights (matches
28
+ torchvision). Output is parity-close to torchvision, not bit-identical (torchvision keeps a
29
+ fixed-point uint8 intermediate; ours is more accurate float).
30
+ """
31
+
32
+ import triton
33
+ import triton.language as tl
34
+
35
+ from ._fused import _resample_weight
36
+ from ._pack import fold_mean_std, pack_images
37
+
38
+
39
+ @triton.jit
40
+ def _horizontal_resize_kernel(
41
+ input_pixels, # flat uint8 buffer, all images packed back to back
42
+ intermediate, # flat float32 output: width resized + col-cropped, height untouched
43
+ input_offsets, # input_offsets[image] = where that image starts in input_pixels
44
+ intermediate_offsets, # same idea for the intermediate buffer
45
+ heights, widths, # per-image input height / width
46
+ resize_widths, # per-image width to resize to (before cropping)
47
+ crop_lefts, # per-image left offset of the centered crop
48
+ crop_w, # output (crop) width, same for every image
49
+ cubic_coeff,
50
+ CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
51
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
52
+ MAX_TAPS_COL: tl.constexpr,
53
+ ):
54
+ """Resize width to resize_width, keep only the cropped columns: uint8 (C,H,W) -> float (C,H,crop_w)."""
55
+ image_index = tl.program_id(0)
56
+ block_index = tl.program_id(1)
57
+ in_height = tl.load(heights + image_index)
58
+ in_width = tl.load(widths + image_index)
59
+ resize_width = tl.load(resize_widths + image_index)
60
+ crop_left = tl.load(crop_lefts + image_index)
61
+ input_start = tl.load(input_offsets + image_index)
62
+ intermediate_start = tl.load(intermediate_offsets + image_index)
63
+ in_width_f = in_width.to(tl.float32)
64
+
65
+ num_pixels = in_height * crop_w # every input row x every cropped output column
66
+ flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
67
+ active = flat_index < num_pixels
68
+ input_row = flat_index // crop_w
69
+ out_col = flat_index % crop_w
70
+ resize_col = out_col + crop_left # column in the (uncropped) resized image
71
+
72
+ filter_half = 2.0 if CUBIC else 1.0
73
+ col_scale = in_width_f / resize_width.to(tl.float32)
74
+ col_filter_scale = tl.maximum(col_scale, 1.0) if ANTIALIAS else 1.0
75
+ col_support = filter_half * col_filter_scale
76
+ col_inv_scale = 1.0 / col_filter_scale
77
+ src_center_col = col_scale * (resize_col.to(tl.float32) + 0.5)
78
+ first_tap_col = tl.floor(src_center_col - col_support + 0.5)
79
+
80
+ col_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
81
+ for tap in tl.static_range(MAX_TAPS_COL):
82
+ tap_col = first_tap_col + tap
83
+ weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
84
+ if ANTIALIAS:
85
+ weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
86
+ col_weight_sum += weight
87
+
88
+ input_plane = (in_height * in_width).to(tl.int64)
89
+ intermediate_plane = (in_height * crop_w).to(tl.int64)
90
+ in_width_i64 = in_width.to(tl.int64)
91
+ crop_w_i64 = crop_w.to(tl.int64)
92
+ input_row_i64 = input_row.to(tl.int64)
93
+ for channel in tl.static_range(CHANNELS):
94
+ input_row_base = input_start + channel * input_plane + input_row_i64 * in_width_i64
95
+ accumulator = tl.zeros([BLOCK], dtype=tl.float32)
96
+ for tap in tl.static_range(MAX_TAPS_COL):
97
+ tap_col = first_tap_col + tap
98
+ weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
99
+ if ANTIALIAS:
100
+ weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
101
+ clamped_tap_col = tl.minimum(tl.maximum(tap_col.to(tl.int32), 0), in_width - 1).to(tl.int64)
102
+ pixel = tl.load(input_pixels + input_row_base + clamped_tap_col, mask=active, other=0).to(tl.float32)
103
+ accumulator += weight * pixel
104
+ accumulator = accumulator / col_weight_sum
105
+ write_index = intermediate_start + channel * intermediate_plane + input_row_i64 * crop_w_i64 + out_col
106
+ tl.store(intermediate + write_index, accumulator, mask=active)
107
+
108
+
109
+ @triton.jit
110
+ def _vertical_resize_normalize_kernel(
111
+ intermediate, # float32 from the horizontal step: (C, H, crop_w) per image
112
+ output, # final (N, C, crop_h, crop_w) float32
113
+ intermediate_offsets,
114
+ heights, # per-image input height (the intermediate still has H rows)
115
+ resize_heights, # per-image height to resize to (before cropping)
116
+ crop_tops, # per-image top offset of the centered crop
117
+ means, stds, # per-channel normalization, rescale already folded in
118
+ crop_h, crop_w,
119
+ cubic_coeff,
120
+ CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
121
+ CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
122
+ MAX_TAPS_ROW: tl.constexpr,
123
+ ):
124
+ """Resize height to resize_height, keep cropped rows, normalize: float (C,H,crop_w) -> (C,crop_h,crop_w)."""
125
+ image_index = tl.program_id(0)
126
+ block_index = tl.program_id(1)
127
+ in_height = tl.load(heights + image_index)
128
+ resize_height = tl.load(resize_heights + image_index)
129
+ crop_top = tl.load(crop_tops + image_index)
130
+ intermediate_start = tl.load(intermediate_offsets + image_index)
131
+ in_height_f = in_height.to(tl.float32)
132
+
133
+ num_pixels = crop_h * crop_w
134
+ flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
135
+ active = flat_index < num_pixels
136
+ out_row = flat_index // crop_w
137
+ out_col = flat_index % crop_w
138
+ resize_row = out_row + crop_top # row in the (uncropped) resized image
139
+
140
+ filter_half = 2.0 if CUBIC else 1.0
141
+ row_scale = in_height_f / resize_height.to(tl.float32)
142
+ row_filter_scale = tl.maximum(row_scale, 1.0) if ANTIALIAS else 1.0
143
+ row_support = filter_half * row_filter_scale
144
+ row_inv_scale = 1.0 / row_filter_scale
145
+ src_center_row = row_scale * (resize_row.to(tl.float32) + 0.5)
146
+ first_tap_row = tl.floor(src_center_row - row_support + 0.5)
147
+
148
+ row_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
149
+ for tap in tl.static_range(MAX_TAPS_ROW):
150
+ tap_row = first_tap_row + tap
151
+ weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
152
+ if ANTIALIAS:
153
+ weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
154
+ row_weight_sum += weight
155
+
156
+ intermediate_plane = (in_height * crop_w).to(tl.int64)
157
+ crop_w_i64 = crop_w.to(tl.int64)
158
+ out_col_i64 = out_col.to(tl.int64)
159
+ for channel in tl.static_range(CHANNELS):
160
+ channel_base = intermediate_start + channel * intermediate_plane
161
+ accumulator = tl.zeros([BLOCK], dtype=tl.float32)
162
+ for tap in tl.static_range(MAX_TAPS_ROW):
163
+ tap_row = first_tap_row + tap
164
+ weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
165
+ if ANTIALIAS:
166
+ weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
167
+ clamped_tap_row = tl.minimum(tl.maximum(tap_row.to(tl.int32), 0), in_height - 1).to(tl.int64)
168
+ pixel = tl.load(intermediate + channel_base + clamped_tap_row * crop_w_i64 + out_col_i64, mask=active, other=0.0)
169
+ accumulator += weight * pixel
170
+ accumulator = accumulator / row_weight_sum
171
+ mean = tl.load(means + channel)
172
+ std = tl.load(stds + channel)
173
+ accumulator = (accumulator - mean) / std
174
+ write_index = ((image_index * CHANNELS + channel) * crop_h + out_row) * crop_w + out_col
175
+ tl.store(output + write_index, accumulator, mask=active)
176
+
177
+
178
+ def _axis_max_taps(in_sizes, resize_sizes, interp, antialias):
179
+ """Widest 1D window over the batch for one axis = ceil(support) * 2 + 1, support uses in/resize."""
180
+ import math
181
+
182
+ interp_half = 2.0 if interp == "bicubic" else 1.0
183
+ worst = 0
184
+ for in_size, resize_size in zip(in_sizes, resize_sizes):
185
+ scale = in_size / resize_size
186
+ eff = max(scale, 1.0) if antialias else 1.0
187
+ worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
188
+ return worst
189
+
190
+
191
+ def _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
192
+ mean, std, rescale, interp, antialias, block):
193
+ """Core driver: resize each image to its (resize_h, resize_w), keep the centered crop, normalize."""
194
+ import torch
195
+
196
+ device = images[0].device
197
+ num_images = len(images)
198
+ cubic_coeff = -0.5 if antialias else -0.75
199
+ in_heights = [int(img.shape[1]) for img in images]
200
+ in_widths = [int(img.shape[2]) for img in images]
201
+ max_taps_row = _axis_max_taps(in_heights, resize_heights, interp, antialias)
202
+ max_taps_col = _axis_max_taps(in_widths, resize_widths, interp, antialias)
203
+ means, stds = fold_mean_std(mean, std, rescale, device)
204
+
205
+ input_pixels, input_offsets, heights, widths, channels = pack_images(images, dtype=torch.uint8)
206
+
207
+ intermediate_offsets_list, cursor, tallest = [], 0, 0
208
+ for height in in_heights:
209
+ intermediate_offsets_list.append(cursor)
210
+ cursor += channels * height * crop_w
211
+ tallest = max(tallest, height)
212
+ intermediate = torch.empty(cursor, device=device, dtype=torch.float32)
213
+
214
+ intermediate_offsets = torch.tensor(intermediate_offsets_list, device=device, dtype=torch.int64)
215
+ resize_heights_t = torch.tensor(resize_heights, device=device, dtype=torch.int32)
216
+ resize_widths_t = torch.tensor(resize_widths, device=device, dtype=torch.int32)
217
+ crop_tops_t = torch.tensor(crop_tops, device=device, dtype=torch.int32)
218
+ crop_lefts_t = torch.tensor(crop_lefts, device=device, dtype=torch.int32)
219
+
220
+ horizontal_grid = (num_images, triton.cdiv(tallest * crop_w, block))
221
+ _horizontal_resize_kernel[horizontal_grid](
222
+ input_pixels, intermediate, input_offsets, intermediate_offsets, heights, widths,
223
+ resize_widths_t, crop_lefts_t, crop_w, cubic_coeff,
224
+ CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
225
+ MAX_TAPS_COL=max_taps_col,
226
+ )
227
+
228
+ output = torch.empty((num_images, channels, crop_h, crop_w), device=device, dtype=torch.float32)
229
+ vertical_grid = (num_images, triton.cdiv(crop_h * crop_w, block))
230
+ _vertical_resize_normalize_kernel[vertical_grid](
231
+ intermediate, output, intermediate_offsets, heights, resize_heights_t, crop_tops_t, means, stds,
232
+ crop_h, crop_w, cubic_coeff,
233
+ CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
234
+ MAX_TAPS_ROW=max_taps_row,
235
+ )
236
+ return output
237
+
238
+
239
+ def _aspect_preserving_size(in_h, in_w, shortest_edge):
240
+ """transformers shortest-edge rule: short side -> shortest_edge, long side truncated (int(), not round)."""
241
+ if in_h <= in_w:
242
+ return shortest_edge, int(in_w * shortest_edge / in_h)
243
+ return int(in_h * shortest_edge / in_w), shortest_edge
244
+
245
+
246
+ def separable_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
247
+ """Resize to (out_h, out_w) and normalize (no crop)."""
248
+ images = list(images)
249
+ n = len(images)
250
+ return _run_separable(images, [out_h] * n, [out_w] * n, [0] * n, [0] * n, out_h, out_w,
251
+ mean, std, rescale, interp, antialias, block)
252
+
253
+
254
+ def separable_resize_crop_normalize(images, resize_size, crop_size, mean, std, rescale, interp, antialias,
255
+ resize_mode="square", block: int = 256):
256
+ """Resize then center-crop then normalize.
257
+
258
+ resize_mode="square": resize_size is (resize_h, resize_w) applied to every image.
259
+ resize_mode="shortest_edge": resize_size is an int; each image is resized aspect-preserving
260
+ so its short side equals it, then center-cropped to crop_size.
261
+ """
262
+ images = list(images)
263
+ crop_h, crop_w = crop_size
264
+ resize_heights, resize_widths = [], []
265
+ for img in images:
266
+ in_h, in_w = int(img.shape[1]), int(img.shape[2])
267
+ if resize_mode == "shortest_edge":
268
+ rh, rw = _aspect_preserving_size(in_h, in_w, int(resize_size))
269
+ elif resize_mode == "square":
270
+ rh, rw = int(resize_size[0]), int(resize_size[1])
271
+ else:
272
+ raise ValueError(f"resize_mode must be 'square' or 'shortest_edge', got {resize_mode!r}")
273
+ if rh < crop_h or rw < crop_w:
274
+ raise ValueError(f"resize size ({rh},{rw}) smaller than crop ({crop_h},{crop_w})")
275
+ resize_heights.append(rh)
276
+ resize_widths.append(rw)
277
+ crop_tops = [(rh - crop_h) // 2 for rh in resize_heights]
278
+ crop_lefts = [(rw - crop_w) // 2 for rw in resize_widths]
279
+ return _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
280
+ mean, std, rescale, interp, antialias, block)