Upload folder using huggingface_hub
Browse files- KERNEL_OPS.md +326 -0
- README.md +108 -0
- benchmarks/benchmark.py +293 -0
- benchmarks/compat_check.py +127 -0
- build.toml +7 -0
- build/torch-universal/kernel_image_resize/__init__.py +113 -0
- build/torch-universal/kernel_image_resize/_fused.py +134 -0
- build/torch-universal/kernel_image_resize/_pack.py +62 -0
- build/torch-universal/kernel_image_resize/_separable.py +280 -0
- example.py +32 -0
- example_transformers.py +85 -0
- publish.sh +44 -0
- resultcompat +16 -0
- tests/test_resize_normalize.py +112 -0
- torch-ext/kernel_image_resize/__init__.py +113 -0
- torch-ext/kernel_image_resize/_fused.py +134 -0
- torch-ext/kernel_image_resize/_pack.py +62 -0
- torch-ext/kernel_image_resize/_separable.py +280 -0
KERNEL_OPS.md
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# kernel_image_resize — how every op works (study notes)
|
| 2 |
+
|
| 3 |
+
This explains the whole package end to end: the resampling math, the data layout, and
|
| 4 |
+
every kernel op, plus the benchmark findings. It is meant for pen & paper — there is a
|
| 5 |
+
fully worked numeric example you can reproduce by hand in the "Worked example" section.
|
| 6 |
+
|
| 7 |
+
The package does one thing: **resize + rescale + normalize**, the op sequence a
|
| 8 |
+
`transformers` fast image processor runs (`TorchvisionBackend`: resize, then
|
| 9 |
+
`(x*rescale - mean)/std`), as a single GPU pipeline. Input: raw CHW `uint8` images
|
| 10 |
+
(any size, ragged). Output: `(N, C, out_h, out_w)` normalized `float32`.
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## 1. The one idea behind resizing
|
| 15 |
+
|
| 16 |
+
Resizing does not "pick" pixels; each **output** pixel is a **weighted average of a small
|
| 17 |
+
window of input pixels**. Two things define that average:
|
| 18 |
+
|
| 19 |
+
1. **Where** in the input an output pixel lands (its center).
|
| 20 |
+
2. **Which** input pixels are in its window, and with **what weights**.
|
| 21 |
+
|
| 22 |
+
The weight of an input pixel falls off with distance from the center, following a filter
|
| 23 |
+
curve:
|
| 24 |
+
|
| 25 |
+
- **bilinear** → a triangle, width 1 on each side (so 2 input pixels per axis normally).
|
| 26 |
+
- **bicubic** → a cubic bump, width 2 on each side (so 4 input pixels per axis normally).
|
| 27 |
+
|
| 28 |
+
When you **shrink** an image, you must also **blur first** or you get aliasing. That is
|
| 29 |
+
what `antialias=True` does: it widens the window so each output pixel averages more input
|
| 30 |
+
pixels (a low-pass filter before throwing pixels away). Widening is proportional to the
|
| 31 |
+
shrink factor, so shrinking 3× turns a 4-tap bicubic into ~13 taps.
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 2. The resampling-weight formula (the heart of everything)
|
| 36 |
+
|
| 37 |
+
All kernels use the same formula, which matches PyTorch's aten `UpSampleKernel`
|
| 38 |
+
(`align_corners=False`, "half-pixel" convention). For one axis:
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
scale = in_size / out_size # > 1 means shrinking
|
| 42 |
+
interp_half = 1 (bilinear) or 2 (bicubic) # half-width of the filter
|
| 43 |
+
cubic_a = -0.75 (no antialias) or -0.5 (antialias) # the cubic curve's shape constant
|
| 44 |
+
|
| 45 |
+
# antialias only widens the window, and only when shrinking:
|
| 46 |
+
if antialias and scale > 1:
|
| 47 |
+
eff = scale # window widens by the shrink factor
|
| 48 |
+
else:
|
| 49 |
+
eff = 1 # plain 2-tap / 4-tap interpolation
|
| 50 |
+
support = interp_half * eff # half-window width, in INPUT pixels
|
| 51 |
+
inv = 1 / eff # squashes the filter curve to match the widened window
|
| 52 |
+
|
| 53 |
+
# for output index i:
|
| 54 |
+
center = scale * (i + 0.5) # input coordinate this output maps to
|
| 55 |
+
first_tap = floor(center - support + 0.5) # leftmost input pixel in the window
|
| 56 |
+
|
| 57 |
+
# for each tap t = 0, 1, 2, ... (up to MAX_TAPS):
|
| 58 |
+
tap_pos = first_tap + t # an input pixel index
|
| 59 |
+
arg = (tap_pos - center + 0.5) * inv # distance from center, squashed
|
| 60 |
+
weight = filter(arg) # triangle or cubic, see below
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
The filter (`_resample_weight` in `_fused.py`), with `x = |arg|`:
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
bilinear: max(1 - x, 0) # triangle, zero past 1
|
| 67 |
+
|
| 68 |
+
bicubic: x <= 1 : (a+2)x^3 - (a+3)x^2 + 1
|
| 69 |
+
1<x<2 : a x^3 - 5a x^2 + 8a x - 4a
|
| 70 |
+
else : 0
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
Two edge rules (both kernels do this identically):
|
| 74 |
+
|
| 75 |
+
- **non-antialias**: clamp the tap index into `[0, in_size-1]` → replicates the border
|
| 76 |
+
pixel. The filter weights of a standard 2/4-tap interpolation already sum to 1.
|
| 77 |
+
- **antialias**: instead set the weight to **0** for taps that fall off the image
|
| 78 |
+
(`tap_pos < 0` or `>= in_size`), then **renormalize** by dividing by the sum of the
|
| 79 |
+
realized weights. This keeps the average correct at the edges.
|
| 80 |
+
|
| 81 |
+
That renormalization is why every kernel computes a `weight_sum` and divides by it. For
|
| 82 |
+
the non-antialias case `weight_sum == 1`, so the division is a harmless no-op.
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## 3. Worked example (do this by hand)
|
| 87 |
+
|
| 88 |
+
**bilinear, no antialias, one axis, in_size=4, out_size=2.**
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
scale = 4/2 = 2, interp_half = 1, eff = 1, support = 1, inv = 1
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
Output pixel `i = 0`:
|
| 95 |
+
```
|
| 96 |
+
center = 2 * (0 + 0.5) = 1.0
|
| 97 |
+
first_tap = floor(1.0 - 1 + 0.5) = floor(0.5) = 0
|
| 98 |
+
t=0: tap_pos=0, arg=(0-1.0+0.5)= -0.5 -> weight = 1-0.5 = 0.5
|
| 99 |
+
t=1: tap_pos=1, arg=(1-1.0+0.5)= 0.5 -> weight = 1-0.5 = 0.5
|
| 100 |
+
t=2: tap_pos=2, arg=(2-1.0+0.5)= 1.5 -> weight = max(1-1.5,0) = 0
|
| 101 |
+
weight_sum = 1.0
|
| 102 |
+
output[0] = (0.5*in[0] + 0.5*in[1]) / 1.0 # halfway between in[0] and in[1]
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Output pixel `i = 1`:
|
| 106 |
+
```
|
| 107 |
+
center = 2 * 1.5 = 3.0
|
| 108 |
+
first_tap = floor(3.0 - 0.5) = 2
|
| 109 |
+
t=0: tap_pos=2, arg=-0.5 -> 0.5
|
| 110 |
+
t=1: tap_pos=3, arg= 0.5 -> 0.5
|
| 111 |
+
t=2: tap_pos=4, arg= 1.5 -> 0 (index 4 would clamp to 3, but weight is 0 anyway)
|
| 112 |
+
output[1] = 0.5*in[2] + 0.5*in[3]
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
This 1-D operation is exactly one pass of the separable kernel. The 2-D result is the same
|
| 116 |
+
formula applied on both axes (rows and columns).
|
| 117 |
+
|
| 118 |
+
---
|
| 119 |
+
|
| 120 |
+
## 4. Data layout (host side, `_pack.py`)
|
| 121 |
+
|
| 122 |
+
Ragged images (different H×W) cannot be stacked into one tensor, so they are flattened and
|
| 123 |
+
concatenated into one buffer, with side tables describing each image.
|
| 124 |
+
|
| 125 |
+
`pack_images(images, dtype)` →
|
| 126 |
+
```
|
| 127 |
+
input_pixels : 1-D buffer, all images flattened (C,H,W row-major) and concatenated
|
| 128 |
+
offsets[n] : element index where image n starts
|
| 129 |
+
heights[n], widths[n] : that image's H and W
|
| 130 |
+
channels : C (shared by all images)
|
| 131 |
+
```
|
| 132 |
+
Address of input pixel `(channel, row, col)` of image `n`:
|
| 133 |
+
```
|
| 134 |
+
input_pixels[ offsets[n] + channel*(H*W) + row*W + col ]
|
| 135 |
+
```
|
| 136 |
+
The separable path packs as `uint8` (1 byte/pixel, half the memory traffic of float).
|
| 137 |
+
|
| 138 |
+
`fold_mean_std(mean, std, rescale)` → folds the rescale factor into the normalization
|
| 139 |
+
constants so the kernel does a single `(x - m)/s`:
|
| 140 |
+
```
|
| 141 |
+
m = mean / rescale s = std / rescale
|
| 142 |
+
(x - m)/s == (x*rescale - mean)/std # identical to the processor's fused normalize
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
`max_taps(images, out_size, axis, interp, antialias)` → the **widest** window in the batch
|
| 146 |
+
= `ceil(support) * 2 + 1`. A Triton loop bound must be a compile-time constant, so every
|
| 147 |
+
program loops this fixed count; taps beyond a given pixel's real window get ~0 weight.
|
| 148 |
+
|
| 149 |
+
`as_image_list` → accepts a stacked `(N,C,H,W)` tensor or a list, always returns a list.
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
## 5. Fused kernel (`_fused.py`, `backend="fused"`)
|
| 154 |
+
|
| 155 |
+
One launch. **One program = one image + a BLOCK of its output pixels.** Each output pixel
|
| 156 |
+
reads the **full 2-D window** directly: `MAX_TAPS_H × MAX_TAPS_W` input pixels.
|
| 157 |
+
|
| 158 |
+
```
|
| 159 |
+
grid = (num_images, ceil(out_h*out_w / BLOCK))
|
| 160 |
+
|
| 161 |
+
per lane (one output pixel):
|
| 162 |
+
oy, ox = (flat_index // out_w, flat_index % out_w)
|
| 163 |
+
center_y, center_x, first_tap_y, first_tap_x # section 2, both axes
|
| 164 |
+
|
| 165 |
+
# weight_sum factorizes across axes (separable math, even though the LOADS are 2-D):
|
| 166 |
+
sum_wy = Σ_ty filter_y ; sum_wx = Σ_tx filter_x ; denom = sum_wy * sum_wx
|
| 167 |
+
|
| 168 |
+
for channel:
|
| 169 |
+
acc = 0
|
| 170 |
+
for ty in 0..MAX_TAPS_H: # <-- the 2-D window: TAPS_H * TAPS_W loads
|
| 171 |
+
for tx in 0..MAX_TAPS_W:
|
| 172 |
+
weight = filter_y(ty) * filter_x(tx)
|
| 173 |
+
pixel = input_pixels[channel, clamp(tap_y), clamp(tap_x)]
|
| 174 |
+
acc += weight * pixel
|
| 175 |
+
acc = acc / denom
|
| 176 |
+
out[image, channel, oy, ox] = (acc - mean[channel]) / std[channel]
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
Cost per output pixel: `TAPS_H * TAPS_W` loads (e.g. 13×13 = **169**). Correct and simple,
|
| 180 |
+
but the 2-D load count is what makes it slow — hence the separable version.
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
## 6. Separable kernel (`_separable.py`, `backend="separable"`, the default)
|
| 185 |
+
|
| 186 |
+
Same math, but the 2-D window is done as **two 1-D passes**, with a float intermediate
|
| 187 |
+
buffer in between. Loads per output pixel: `TAPS_W + TAPS_H` (e.g. 13+13 = **26**).
|
| 188 |
+
|
| 189 |
+
```
|
| 190 |
+
input_pixels (uint8, C×H×W) --pass1--> intermediate (float, C×H×out_w) --pass2--> output (float, C×out_h×out_w)
|
| 191 |
+
resize W resize H + normalize
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
The **intermediate** is the key object: same **height** as the input, but already the
|
| 195 |
+
**final width**. ("Tall and narrow.") It is also ragged in height, so it gets its own
|
| 196 |
+
offset table (built in `separable_resize_normalize`, same scheme as `pack_images`).
|
| 197 |
+
|
| 198 |
+
### Pass 1 — `_horizontal_resize_kernel` (resize width only)
|
| 199 |
+
|
| 200 |
+
```
|
| 201 |
+
grid = (num_images, ceil(H*out_w / BLOCK)) # work = every input row × every output col
|
| 202 |
+
per lane:
|
| 203 |
+
input_row = flat_index // out_w # row index, UNCHANGED by this pass
|
| 204 |
+
out_col = flat_index % out_w # output column being computed
|
| 205 |
+
|
| 206 |
+
center_x, first_tap_x, col_weight_sum # section 2, COLUMN axis only
|
| 207 |
+
|
| 208 |
+
for channel:
|
| 209 |
+
acc = 0
|
| 210 |
+
for tap in 0..MAX_TAPS_COL: # <-- 1-D: only TAPS_W loads
|
| 211 |
+
weight = filter_x(tap)
|
| 212 |
+
pixel = input_pixels[channel, input_row, clamp(tap_col)] # uint8 -> float
|
| 213 |
+
acc += weight * pixel
|
| 214 |
+
acc = acc / col_weight_sum
|
| 215 |
+
intermediate[channel, input_row, out_col] = acc # NO normalize yet
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
Reads original `uint8` bytes; writes `float32`. No normalization here.
|
| 219 |
+
|
| 220 |
+
### Pass 2 — `_vertical_resize_normalize_kernel` (resize height, then normalize)
|
| 221 |
+
|
| 222 |
+
```
|
| 223 |
+
grid = (num_images, ceil(out_h*out_w / BLOCK)) # work = every output pixel
|
| 224 |
+
per lane:
|
| 225 |
+
out_row = flat_index // out_w
|
| 226 |
+
out_col = flat_index % out_w
|
| 227 |
+
|
| 228 |
+
center_y, first_tap_y, row_weight_sum # section 2, ROW axis only
|
| 229 |
+
|
| 230 |
+
for channel:
|
| 231 |
+
acc = 0
|
| 232 |
+
for tap in 0..MAX_TAPS_ROW: # <-- 1-D: only TAPS_H loads
|
| 233 |
+
weight = filter_y(tap)
|
| 234 |
+
pixel = intermediate[channel, clamp(tap_row), out_col] # float
|
| 235 |
+
acc += weight * pixel
|
| 236 |
+
acc = acc / row_weight_sum
|
| 237 |
+
out[image, channel, out_row, out_col] = (acc - mean[channel]) / std[channel] # normalize here
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
Two launches (an implicit sync between them), so pass 2 always sees pass 1's finished
|
| 241 |
+
output.
|
| 242 |
+
|
| 243 |
+
### Why separable wins
|
| 244 |
+
|
| 245 |
+
`TAPS_W + TAPS_H` loads instead of `TAPS_W * TAPS_H`. For a 13×13 window that is 26 vs 169.
|
| 246 |
+
This is exactly the algorithm PIL and torchvision use. The catch: an extra full-size float
|
| 247 |
+
intermediate buffer (more memory traffic), but the read-count reduction dominates.
|
| 248 |
+
|
| 249 |
+
Parity note: the intermediate here is **float32**; torchvision keeps a **fixed-point
|
| 250 |
+
uint8** intermediate. So the separable output is parity-*close* to torchvision, not
|
| 251 |
+
bit-identical — and the float version is actually the more accurate one.
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
## 7. Public API (`__init__.py`)
|
| 256 |
+
|
| 257 |
+
```
|
| 258 |
+
resize_normalize(images, size, image_mean, image_std,
|
| 259 |
+
rescale_factor=1/255, resample="bilinear", antialias=False,
|
| 260 |
+
backend="separable", block=256)
|
| 261 |
+
```
|
| 262 |
+
- `images`: stacked `(N,C,H,W)` tensor or a list of CHW tensors.
|
| 263 |
+
- `size`: int (square), `(H,W)`, or `{"height","width"}`.
|
| 264 |
+
- `resample`: `"bilinear"`/`"bicubic"`, or a PIL resample int (0/2→bilinear, 3→bicubic).
|
| 265 |
+
- `backend`: `"separable"` (default, fastest) or `"fused"` (2-D reference).
|
| 266 |
+
- `resize_normalize_ragged`: same kernels, list-only.
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
## 8. Benchmark findings (A100, CUDA_VISIBLE_DEVICES=1)
|
| 271 |
+
|
| 272 |
+
### Standalone resize+normalize — SigLIP-so400m config, N=32 ragged 384–1024², out 384², bicubic+AA
|
| 273 |
+
```
|
| 274 |
+
torchvision eager loop : 2.91 ms (per-image float loop)
|
| 275 |
+
torchvision compiled : 5.70 ms (torch.compile dynamic, per-image; slower than eager)
|
| 276 |
+
torchvision compiled pkt: 2.55 ms (one graph over a padded stack; timing only)
|
| 277 |
+
fused triton (2D) : 11.49 ms (taps*taps; the slow reference)
|
| 278 |
+
separable triton (uint8): 1.29 ms (taps+taps) <-- fastest
|
| 279 |
+
real processor : 3.92 ms
|
| 280 |
+
```
|
| 281 |
+
**Separable is ~3× the real processor**, parity ≤1e-4 vs torchvision-float. The fused 2-D
|
| 282 |
+
loses for the algorithmic reason above (169 vs 26 loads). `torch.compile` does not help:
|
| 283 |
+
per-image it is *slower* (dispatch overhead over 32 ragged shapes); even as one packed
|
| 284 |
+
graph it only matches the eager loop, because inductor's interpolate is no faster than aten
|
| 285 |
+
resize.
|
| 286 |
+
|
| 287 |
+
### End-to-end inference — Siglip2-base-patch16-224, **bf16** forward
|
| 288 |
+
```
|
| 289 |
+
preprocess forward(fixed input) preprocess+forward
|
| 290 |
+
processor 3.99 12.86 14.44
|
| 291 |
+
separable 0.93 13.02 13.76 <-- ~5% faster e2e
|
| 292 |
+
fused 2.00 13.01 14.79
|
| 293 |
+
compiled 6.14 12.89 14.00
|
| 294 |
+
feature parity (separable/fused/compiled vs processor): 9.38e-2 = 1.2% of feature max
|
| 295 |
+
```
|
| 296 |
+
- `forward(fixed input)` is identical (~12.9 ms) for all → **no inference regression**; the
|
| 297 |
+
model does not care which preprocessor made the tensor.
|
| 298 |
+
- The 1.2% feature drift is the float-vs-uint8 resize difference, identical across all
|
| 299 |
+
float backends → not a bug. The float path is the more accurate one.
|
| 300 |
+
- End-to-end win is ~5% with a bf16 forward (was ~0.5% with fp32, where the forward was
|
| 301 |
+
~80 ms). **The win scales with how preprocessing-bound you are.**
|
| 302 |
+
|
| 303 |
+
### Data path from JPEG bytes — 552 KB/img
|
| 304 |
+
```
|
| 305 |
+
CPU decode + torchvision resize : 177.5 ms (status quo)
|
| 306 |
+
CPU decode + separable kernel : 176.4 ms (kernel saves ~1 ms; decode dominates)
|
| 307 |
+
GPU decode (nvJPEG) + kernel : 14.8 ms (fully on-GPU)
|
| 308 |
+
```
|
| 309 |
+
- ~175 ms of the 177 ms is **CPU JPEG decode + host→device copy**. Resize/normalize is ~1%.
|
| 310 |
+
- The 12× win (177→15) is **GPU decode (nvJPEG)**, i.e. `torchvision.io.decode_jpeg(device="cuda")`
|
| 311 |
+
— *not* the kernel. The kernel is the resize/normalize component of that GPU pipeline.
|
| 312 |
+
|
| 313 |
+
---
|
| 314 |
+
|
| 315 |
+
## 9. What is true / what to claim
|
| 316 |
+
|
| 317 |
+
- The kernel is **correct** (≤1e-4 vs torchvision-float, more accurate than the processor's
|
| 318 |
+
uint8 path) and feeds the model with **no inference regression**.
|
| 319 |
+
- It is **~3× the real processor at the resize/normalize stage** — a real, parity-clean win.
|
| 320 |
+
- It does **not** speed up preprocessing 12×. Decode dominates the data path; the GPU-decode
|
| 321 |
+
lever is nvJPEG, a torchvision feature, not this kernel.
|
| 322 |
+
- The kernel matters end-to-end only once you are **not decode-bound**: in a GPU-decode
|
| 323 |
+
pipeline it keeps resize/normalize minimal (~10% of that pipeline), and its standalone
|
| 324 |
+
preprocess win shows up when the forward is small (bf16, small model, large batch).
|
| 325 |
+
- Honest one-liner: *"GPU-native resize+normalize, 3× the fast processor at that stage,
|
| 326 |
+
drop-in for a GPU-decode pipeline."*
|
README.md
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- kernel
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# kernel_image_resize
|
| 7 |
+
|
| 8 |
+
A pure-Triton Hub kernel that fuses the **resize + rescale + normalize** preprocessing
|
| 9 |
+
pipeline run by ~150 `transformers` fast image processors (`TorchvisionBackend`: resize →
|
| 10 |
+
fold(rescale, normalize)) into a single GPU pass. It takes raw CHW `uint8` images and
|
| 11 |
+
returns the normalized `(N, C, out_h, out_w)` float tensor with no intermediate
|
| 12 |
+
full-resolution float buffer.
|
| 13 |
+
|
| 14 |
+
On a ragged SigLIP-so400m batch (A100, N=32, inputs 384–1024², out 384², bicubic+antialias)
|
| 15 |
+
the default backend runs in **1.29 ms/iter vs 3.90 ms for the fast processor (~3× faster)**
|
| 16 |
+
and 2.89 ms for torchvision's own per-image loop, at parity ≤1e-4 vs torchvision-float.
|
| 17 |
+
|
| 18 |
+
It ships as a `kernels` universal build variant (no compiled extension, just Triton), so it
|
| 19 |
+
loads on any CUDA PyTorch build via `get_kernel`.
|
| 20 |
+
|
| 21 |
+
## Usage
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
import torch
|
| 25 |
+
from kernels import get_kernel
|
| 26 |
+
|
| 27 |
+
kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
|
| 28 |
+
|
| 29 |
+
# a list of different-H×W uint8 CHW images (the ragged case torchvision loops over)
|
| 30 |
+
images = [torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device="cuda")
|
| 31 |
+
for h, w in [(640, 480), (800, 600), (384, 1024)]]
|
| 32 |
+
|
| 33 |
+
pixel_values = kir.resize_normalize(
|
| 34 |
+
images,
|
| 35 |
+
size=384, # int (square), (H, W), or {"height", "width"}
|
| 36 |
+
image_mean=[0.5, 0.5, 0.5],
|
| 37 |
+
image_std=[0.5, 0.5, 0.5],
|
| 38 |
+
rescale_factor=1 / 255,
|
| 39 |
+
resample="bicubic", # or "bilinear", or a PIL resample int
|
| 40 |
+
antialias=True, # match the ViT/CLIP/SigLIP default
|
| 41 |
+
)
|
| 42 |
+
# -> (3, 3, 384, 384) float32, ready for the model
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
`trust_remote_code=True` is required because this is a personal namespace (not the trusted
|
| 46 |
+
`kernels-community` org). `revision="main"` loads the current code; tag a `v1.0.0` release if
|
| 47 |
+
you want `version=1` loading instead.
|
| 48 |
+
|
| 49 |
+
`resize_normalize` accepts a stacked `(N, C, H, W)` tensor or a ragged list of CHW
|
| 50 |
+
tensors. `resize_normalize_ragged` is the same kernel, list-only.
|
| 51 |
+
|
| 52 |
+
## With a transformers processor
|
| 53 |
+
|
| 54 |
+
There is no `use_kernels=True` hook for image processors — that machinery swaps `nn.Module`
|
| 55 |
+
layer forwards inside the model, not processor code. Use the kernel directly with the
|
| 56 |
+
processor's config (see `example_transformers.py` for a runnable version):
|
| 57 |
+
|
| 58 |
+
```python
|
| 59 |
+
from kernels import get_kernel
|
| 60 |
+
kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
|
| 61 |
+
_PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
|
| 62 |
+
|
| 63 |
+
def preprocess_with_kernel(processor, images):
|
| 64 |
+
size = processor.size # must be fixed {"height", "width"}; no crop/pad
|
| 65 |
+
return kir.resize_normalize(
|
| 66 |
+
images, (size["height"], size["width"]),
|
| 67 |
+
processor.image_mean, processor.image_std,
|
| 68 |
+
rescale_factor=float(processor.rescale_factor),
|
| 69 |
+
resample=_PIL_RESAMPLE[int(processor.resample)],
|
| 70 |
+
antialias=bool(getattr(processor, "antialias", True)),
|
| 71 |
+
)
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Backends
|
| 75 |
+
|
| 76 |
+
- `backend="separable"` (default): two-pass `uint8` kernel doing `taps+taps` loads —
|
| 77 |
+
torchvision's own separable algorithm. Fastest (~3× the fast processor on the batch
|
| 78 |
+
above); parity ≤1e-4 vs torchvision-float. The float intermediate makes it more accurate
|
| 79 |
+
than, but not bit-identical to, torchvision's fixed-point `uint8` intermediate.
|
| 80 |
+
- `backend="fused"`: a single 2D launch, `taps×taps` loads per output pixel. Same parity,
|
| 81 |
+
kept as the reference path but ~9× slower than separable (the 2D float load count is the
|
| 82 |
+
reason a separable pass wins — see `benchmarks/benchmark.py`).
|
| 83 |
+
|
| 84 |
+
## Parity notes
|
| 85 |
+
|
| 86 |
+
The resampling weights match PyTorch aten `UpSampleKernel`. Antialiased bicubic uses the
|
| 87 |
+
PIL cubic coefficient `a=-0.5`; non-antialiased bicubic uses Keys `a=-0.75`. The
|
| 88 |
+
antialias renormalize-truncate window applies on every axis, including upsampling dims.
|
| 89 |
+
|
| 90 |
+
## Center crop / shortest-edge
|
| 91 |
+
|
| 92 |
+
Pass `crop_size` to resize then center-crop in one pass (the crop is folded into the
|
| 93 |
+
output-coordinate mapping, no extra buffer). `resize_mode="shortest_edge"` does
|
| 94 |
+
aspect-preserving resize (short side = `size`) then crop — the CLIP / DINOv2 pipeline.
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
# CLIP/DINOv2-style: resize shortest edge to 256, center-crop 224
|
| 98 |
+
pv = kir.resize_normalize(images, 256, mean, std, resample="bicubic", antialias=True,
|
| 99 |
+
crop_size=224, resize_mode="shortest_edge")
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
`example_transformers.py` derives all of this from a processor's config automatically.
|
| 103 |
+
|
| 104 |
+
## Scope
|
| 105 |
+
|
| 106 |
+
Resize (+ optional center crop) + rescale + normalize. It does **not** pad — padding
|
| 107 |
+
processors (many detection models) run a different pipeline. The `fused` backend is
|
| 108 |
+
resize-only; crop is handled by the `separable` backend.
|
benchmarks/benchmark.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark resize+normalize: separable / fused triton vs torchvision vs the real processor.
|
| 2 |
+
|
| 3 |
+
PYTHONPATH=../torch-ext python benchmark.py --processor google/siglip-so400m-patch14-384
|
| 4 |
+
PYTHONPATH=../torch-ext python benchmark.py --n 32 --out 384 384 --interp bicubic --antialias
|
| 5 |
+
|
| 6 |
+
Prints parity (vs torchvision-float) per backend, then ms/iter for each path. Needs CUDA.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
# Hide `kernels` from transformers: this worktree builds kernels.LayerRepository without a version,
|
| 14 |
+
# which newer `kernels` rejects at import. Preprocessing needs no hub layer kernels.
|
| 15 |
+
sys.modules["kernels"] = None
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torchvision.transforms.v2.functional as tvF
|
| 19 |
+
from torchvision.io import ImageReadMode, decode_jpeg, encode_jpeg
|
| 20 |
+
from torchvision.transforms import InterpolationMode
|
| 21 |
+
|
| 22 |
+
from kernel_image_resize import resize_normalize
|
| 23 |
+
from kernel_image_resize._pack import PIL_RESAMPLE_TO_INTERP, max_taps
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def make_ragged_images(n, device, min_res, max_res, seed=0):
|
| 30 |
+
g = torch.Generator(device="cpu").manual_seed(seed)
|
| 31 |
+
images = []
|
| 32 |
+
for _ in range(n):
|
| 33 |
+
h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
|
| 34 |
+
w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
|
| 35 |
+
images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
|
| 36 |
+
return images
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias):
|
| 40 |
+
mode = _TV_INTERP[interp]
|
| 41 |
+
mean_t = torch.tensor(mean, device=images[0].device).view(3, 1, 1)
|
| 42 |
+
std_t = torch.tensor(std, device=images[0].device).view(3, 1, 1)
|
| 43 |
+
outs = []
|
| 44 |
+
for img in images:
|
| 45 |
+
r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
|
| 46 |
+
outs.append((r * rescale - mean_t) / std_t)
|
| 47 |
+
return torch.stack(outs)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
|
| 51 |
+
"""torch.compile(dynamic=True) of a per-image float resize+normalize."""
|
| 52 |
+
import torch.nn.functional as F
|
| 53 |
+
|
| 54 |
+
mean_t = torch.tensor(mean, device=device).view(3, 1, 1)
|
| 55 |
+
std_t = torch.tensor(std, device=device).view(3, 1, 1)
|
| 56 |
+
mode = "bicubic" if interp == "bicubic" else "bilinear"
|
| 57 |
+
|
| 58 |
+
def _one(img):
|
| 59 |
+
r = F.interpolate(img.unsqueeze(0).float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
|
| 60 |
+
return (r.squeeze(0) * rescale - mean_t) / std_t
|
| 61 |
+
|
| 62 |
+
compiled = torch.compile(_one, dynamic=True)
|
| 63 |
+
|
| 64 |
+
def run(images):
|
| 65 |
+
return torch.stack([compiled(img) for img in images])
|
| 66 |
+
|
| 67 |
+
return run
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def pad_stack(images):
|
| 71 |
+
"""Pad ragged CHW images to the batch-max H/W and stack into (N, C, Hmax, Wmax)."""
|
| 72 |
+
c = images[0].shape[0]
|
| 73 |
+
max_h = max(img.shape[1] for img in images)
|
| 74 |
+
max_w = max(img.shape[2] for img in images)
|
| 75 |
+
out = torch.zeros(len(images), c, max_h, max_w, dtype=images[0].dtype, device=images[0].device)
|
| 76 |
+
for i, img in enumerate(images):
|
| 77 |
+
out[i, :, : img.shape[1], : img.shape[2]] = img
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device):
|
| 82 |
+
"""torch.compile of a single batched resize+normalize over a stacked (N, C, H, W) tensor."""
|
| 83 |
+
import torch.nn.functional as F
|
| 84 |
+
|
| 85 |
+
mean_t = torch.tensor(mean, device=device).view(1, 3, 1, 1)
|
| 86 |
+
std_t = torch.tensor(std, device=device).view(1, 3, 1, 1)
|
| 87 |
+
mode = "bicubic" if interp == "bicubic" else "bilinear"
|
| 88 |
+
|
| 89 |
+
def _batch(stacked):
|
| 90 |
+
r = F.interpolate(stacked.float(), size=(out_h, out_w), mode=mode, antialias=antialias, align_corners=False)
|
| 91 |
+
return (r * rescale - mean_t) / std_t
|
| 92 |
+
|
| 93 |
+
return torch.compile(_batch, dynamic=True)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def run_inference(model_id, images, block, iters, device):
|
| 97 |
+
"""End-to-end: preprocess (processor / separable / fused / compiled) -> vision features (bf16 forward).
|
| 98 |
+
Checks each kernel feeds the model with no feature drift and times the full pipeline."""
|
| 99 |
+
from transformers import AutoModel
|
| 100 |
+
|
| 101 |
+
proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(model_id)
|
| 102 |
+
model = AutoModel.from_pretrained(model_id).to(device=device, dtype=torch.bfloat16).eval()
|
| 103 |
+
vision = model.vision_model
|
| 104 |
+
kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
|
| 105 |
+
resample=interp, antialias=antialias, block=block)
|
| 106 |
+
|
| 107 |
+
@torch.no_grad()
|
| 108 |
+
def features(pixel_values):
|
| 109 |
+
out = vision(pixel_values=pixel_values.to(model.dtype))
|
| 110 |
+
pooled = getattr(out, "pooler_output", None)
|
| 111 |
+
return pooled if pooled is not None else out.last_hidden_state
|
| 112 |
+
|
| 113 |
+
compiled_one = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
|
| 114 |
+
methods = {
|
| 115 |
+
"processor": lambda: proc(images, return_tensors="pt", device=device)["pixel_values"],
|
| 116 |
+
"separable": lambda: resize_normalize(images, backend="separable", **kk),
|
| 117 |
+
"fused": lambda: resize_normalize(images, backend="fused", **kk),
|
| 118 |
+
"compiled": lambda: compiled_one(images),
|
| 119 |
+
}
|
| 120 |
+
methods["compiled"]() # warmup the compiled artifact
|
| 121 |
+
methods["compiled"]()
|
| 122 |
+
torch.cuda.synchronize()
|
| 123 |
+
|
| 124 |
+
print(f"\n[infer] {model_id} out={out_h}x{out_w} forward dtype=bfloat16")
|
| 125 |
+
base = features(methods["processor"]())
|
| 126 |
+
base_scale = base.abs().max().item()
|
| 127 |
+
for name in ("separable", "fused", "compiled"):
|
| 128 |
+
d = (features(methods[name]()) - base).abs().max().item()
|
| 129 |
+
print(f"[infer parity] features {name} vs processor: max|Δ| = {d:.2e} ({d / base_scale:.1%} of feature max)")
|
| 130 |
+
|
| 131 |
+
# forward is timed on a FIXED precomputed tensor, so it is method-independent by construction;
|
| 132 |
+
# if it varies across rows, the preprocessor's output (dtype/contiguity) is hurting the model.
|
| 133 |
+
print("[infer] ms/iter: preprocess forward(fixed input) preprocess+forward")
|
| 134 |
+
for name, preprocess in methods.items():
|
| 135 |
+
pixel_values = preprocess()
|
| 136 |
+
pre = _time(preprocess, iters, device)
|
| 137 |
+
fwd = _time(lambda pixel_values=pixel_values: features(pixel_values), iters, device)
|
| 138 |
+
e2e = _time(lambda preprocess=preprocess: features(preprocess()), iters, device)
|
| 139 |
+
print(f" {name:10s} {pre:8.3f} {fwd:8.3f} {e2e:8.3f}")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, block, iters, device):
|
| 143 |
+
"""Data-path table from JPEG bytes: CPU decode (libjpeg) vs GPU decode (nvJPEG) + the kernel.
|
| 144 |
+
|
| 145 |
+
decoders differ at the pixel level (nvJPEG vs libjpeg), so this measures wall-clock, not parity.
|
| 146 |
+
"""
|
| 147 |
+
jpeg = [encode_jpeg(img, quality=95) for img in images_cpu]
|
| 148 |
+
avg_kb = sum(b.numel() for b in jpeg) / len(jpeg) / 1024
|
| 149 |
+
kk = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
|
| 150 |
+
resample=interp, antialias=antialias, block=block)
|
| 151 |
+
|
| 152 |
+
def cpu_decode_kernel():
|
| 153 |
+
imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg]
|
| 154 |
+
return resize_normalize(imgs, backend="separable", **kk)
|
| 155 |
+
|
| 156 |
+
def gpu_decode_kernel():
|
| 157 |
+
imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device)
|
| 158 |
+
return resize_normalize(imgs, backend="separable", **kk)
|
| 159 |
+
|
| 160 |
+
def gpu_decode_torchvision():
|
| 161 |
+
imgs = decode_jpeg(jpeg, mode=ImageReadMode.RGB, device=device)
|
| 162 |
+
return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias)
|
| 163 |
+
|
| 164 |
+
def cpu_decode_torchvision():
|
| 165 |
+
imgs = [decode_jpeg(b, mode=ImageReadMode.RGB).to(device) for b in jpeg]
|
| 166 |
+
return torchvision_reference(imgs, out_h, out_w, mean, std, rescale, interp, antialias)
|
| 167 |
+
|
| 168 |
+
print(f"\n[decode] N={len(jpeg)} avg={avg_kb:.0f} KB/img out={out_h}x{out_w} (from JPEG bytes, ms/iter)")
|
| 169 |
+
print(f" CPU decode + torchvision resize : {_time(cpu_decode_torchvision, iters, device):8.3f} [status quo data path]")
|
| 170 |
+
print(f" CPU decode + separable kernel : {_time(cpu_decode_kernel, iters, device):8.3f}")
|
| 171 |
+
print(f" GPU decode (nvJPEG) + tv resize : {_time(gpu_decode_torchvision, iters, device):8.3f} [GPU pipeline, tv resize]")
|
| 172 |
+
print(f" GPU decode (nvJPEG) + kernel : {_time(gpu_decode_kernel, iters, device):8.3f} [GPU pipeline, kernel resize]")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def load_processor_config(name):
|
| 176 |
+
from transformers import AutoImageProcessor
|
| 177 |
+
|
| 178 |
+
proc = AutoImageProcessor.from_pretrained(name, backend="torchvision")
|
| 179 |
+
size = proc.size
|
| 180 |
+
if "height" not in size or "width" not in size:
|
| 181 |
+
raise ValueError(f"{name}: size={size} is not a fixed (height, width)")
|
| 182 |
+
out_h, out_w = size["height"], size["width"]
|
| 183 |
+
interp = PIL_RESAMPLE_TO_INTERP.get(int(proc.resample))
|
| 184 |
+
rescale = float(proc.rescale_factor) if getattr(proc, "do_rescale", True) else 1.0
|
| 185 |
+
antialias = bool(getattr(proc, "antialias", True))
|
| 186 |
+
return proc, (out_h, out_w, list(proc.image_mean), list(proc.image_std), rescale, interp, antialias)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _time(fn, iters, device):
|
| 190 |
+
for _ in range(3):
|
| 191 |
+
fn()
|
| 192 |
+
if device.type == "cuda":
|
| 193 |
+
torch.cuda.synchronize()
|
| 194 |
+
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
| 195 |
+
start.record()
|
| 196 |
+
for _ in range(iters):
|
| 197 |
+
fn()
|
| 198 |
+
end.record()
|
| 199 |
+
torch.cuda.synchronize()
|
| 200 |
+
return start.elapsed_time(end) / iters
|
| 201 |
+
t0 = time.perf_counter()
|
| 202 |
+
for _ in range(iters):
|
| 203 |
+
fn()
|
| 204 |
+
return (time.perf_counter() - t0) / iters * 1e3
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def main():
|
| 208 |
+
parser = argparse.ArgumentParser()
|
| 209 |
+
parser.add_argument("--processor", default=None)
|
| 210 |
+
parser.add_argument("--n", type=int, default=32)
|
| 211 |
+
parser.add_argument("--out", type=int, nargs=2, default=[384, 384], metavar=("H", "W"))
|
| 212 |
+
parser.add_argument("--interp", choices=["bilinear", "bicubic"], default="bicubic")
|
| 213 |
+
parser.add_argument("--antialias", action="store_true")
|
| 214 |
+
parser.add_argument("--min-res", type=int, default=384)
|
| 215 |
+
parser.add_argument("--max-res", type=int, default=1024)
|
| 216 |
+
parser.add_argument("--iters", type=int, default=50)
|
| 217 |
+
parser.add_argument("--block", type=int, default=256)
|
| 218 |
+
parser.add_argument("--tol", type=float, default=3e-3)
|
| 219 |
+
parser.add_argument("--infer", action="store_true", help="end-to-end Siglip2 inference comparison (bf16 forward)")
|
| 220 |
+
parser.add_argument("--model", default="google/siglip2-base-patch16-224", help="model for --infer")
|
| 221 |
+
parser.add_argument("--decode", action="store_true", help="JPEG-decode data-path table (CPU vs GPU/nvJPEG) and stop")
|
| 222 |
+
args = parser.parse_args()
|
| 223 |
+
|
| 224 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 225 |
+
if device.type != "cuda":
|
| 226 |
+
print("benchmark needs CUDA.")
|
| 227 |
+
return
|
| 228 |
+
|
| 229 |
+
proc = None
|
| 230 |
+
if args.processor:
|
| 231 |
+
proc, (out_h, out_w, mean, std, rescale, interp, antialias) = load_processor_config(args.processor)
|
| 232 |
+
print(f"processor={args.processor} -> out={out_h}x{out_w} interp={interp} antialias={antialias}")
|
| 233 |
+
else:
|
| 234 |
+
out_h, out_w = args.out
|
| 235 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 236 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 237 |
+
rescale = 1.0 / 255.0
|
| 238 |
+
interp, antialias = args.interp, args.antialias
|
| 239 |
+
|
| 240 |
+
images = make_ragged_images(args.n, device, args.min_res, args.max_res)
|
| 241 |
+
taps = (max_taps(images, out_h, 1, interp, antialias), max_taps(images, out_w, 2, interp, antialias))
|
| 242 |
+
print(f"N={args.n} in∈[{args.min_res},{args.max_res}]² ragged out={out_h}x{out_w} "
|
| 243 |
+
f"interp={interp} antialias={antialias} max_taps={taps} iters={args.iters}\n")
|
| 244 |
+
|
| 245 |
+
if args.decode:
|
| 246 |
+
images_cpu = make_ragged_images(args.n, torch.device("cpu"), args.min_res, args.max_res)
|
| 247 |
+
run_decode(images_cpu, out_h, out_w, mean, std, rescale, interp, antialias, args.block, args.iters, device)
|
| 248 |
+
return
|
| 249 |
+
|
| 250 |
+
ref = torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias)
|
| 251 |
+
common = dict(size=(out_h, out_w), image_mean=mean, image_std=std, rescale_factor=rescale,
|
| 252 |
+
resample=interp, antialias=antialias, block=args.block)
|
| 253 |
+
for backend in ("fused", "separable"):
|
| 254 |
+
got = resize_normalize(images, backend=backend, **common)
|
| 255 |
+
d = (got - ref).abs().max().item()
|
| 256 |
+
print(f"[parity] {backend:9s} vs torchvision(float): max|Δ| = {d:.2e} "
|
| 257 |
+
f"({'PASS' if d < args.tol else 'FAIL'} @ tol={args.tol})")
|
| 258 |
+
print()
|
| 259 |
+
|
| 260 |
+
compiled_run = build_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
|
| 261 |
+
packed = pad_stack(images)
|
| 262 |
+
packed_compiled_run = build_packed_compiled_reference(out_h, out_w, mean, std, rescale, interp, antialias, device)
|
| 263 |
+
t0 = time.perf_counter()
|
| 264 |
+
compiled_run(images)
|
| 265 |
+
compiled_run(images)
|
| 266 |
+
packed_compiled_run(packed)
|
| 267 |
+
packed_compiled_run(packed)
|
| 268 |
+
torch.cuda.synchronize()
|
| 269 |
+
t_warmup = (time.perf_counter() - t0) * 1e3
|
| 270 |
+
|
| 271 |
+
t_eager = _time(lambda: torchvision_reference(images, out_h, out_w, mean, std, rescale, interp, antialias), args.iters, device)
|
| 272 |
+
t_comp = _time(lambda: compiled_run(images), args.iters, device)
|
| 273 |
+
t_comp_packed = _time(lambda: packed_compiled_run(packed), args.iters, device)
|
| 274 |
+
t_fused = _time(lambda: resize_normalize(images, backend="fused", **common), args.iters, device)
|
| 275 |
+
t_sep = _time(lambda: resize_normalize(images, backend="separable", **common), args.iters, device)
|
| 276 |
+
print("Resize+normalize only (no decode/H2D), ms/iter:")
|
| 277 |
+
print(f" torchvision eager loop : {t_eager:8.3f} [per-image float loop]")
|
| 278 |
+
print(f" torchvision compiled : {t_comp:8.3f} [torch.compile dynamic per-image; warmup {t_warmup:.0f} ms excluded]")
|
| 279 |
+
print(f" torchvision compiled pkt: {t_comp_packed:8.3f} [one graph over padded (N,C,Hmax,Wmax) stack; timing only, padding alters output]")
|
| 280 |
+
print(f" fused triton (2D) : {t_fused:8.3f} [taps*taps]")
|
| 281 |
+
print(f" separable triton (uint8): {t_sep:8.3f} [taps+taps]")
|
| 282 |
+
|
| 283 |
+
if proc is not None:
|
| 284 |
+
t_pr = _time(lambda: proc(images, return_tensors="pt", device=device)["pixel_values"], args.iters, device)
|
| 285 |
+
print(f"\n {args.processor} : {t_pr:8.3f} ms/iter")
|
| 286 |
+
print(f" -> separable is {t_sep / t_pr:.2f}x the real processor")
|
| 287 |
+
|
| 288 |
+
if args.infer:
|
| 289 |
+
run_inference(args.model, images, args.block, args.iters, device)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
main()
|
benchmarks/compat_check.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility + forward-parity sweep over the most-downloaded image models (one per architecture).
|
| 2 |
+
|
| 3 |
+
For each model: load its processor + model, decide whether the kernel can stand in for the
|
| 4 |
+
processor's resize(+crop)+normalize, and for supported ones run processor-vs-kernel pixel_values
|
| 5 |
+
through the SAME vision tower and report pixel + feature parity. Unsupported models list the reason.
|
| 6 |
+
|
| 7 |
+
Run on the DGX (CUDA + working transformers):
|
| 8 |
+
PYTHONPATH=../torch-ext python compat_check.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
# This transformers worktree constructs kernels.LayerRepository without a version/revision, which
|
| 14 |
+
# newer `kernels` rejects at import. We do not need hub LAYER kernels for preprocessing, so hide
|
| 15 |
+
# `kernels` from transformers — it falls back to its no-hub-kernels stub path and imports cleanly.
|
| 16 |
+
sys.modules["kernels"] = None
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from kernel_image_resize import resize_normalize # local package, via PYTHONPATH=../torch-ext
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_PIL_RESAMPLE = {2: "bilinear", 3: "bicubic"}
|
| 24 |
+
|
| 25 |
+
# Top image models by HF downloads (June 2026), deduplicated to one repo per architecture family.
|
| 26 |
+
MODELS = [
|
| 27 |
+
("openai/clip-vit-base-patch32", 20528683), # clip
|
| 28 |
+
("google/vit-base-patch16-224", 4910416), # vit
|
| 29 |
+
("apple/mobilevit-small", 3488074), # mobilevit
|
| 30 |
+
("facebook/dinov2-small", 2602780), # dinov2
|
| 31 |
+
("google/siglip-so400m-patch14-384", 1379598), # siglip
|
| 32 |
+
("facebook/dinov3-vitb16-pretrain-lvd1689m", 467337), # dinov3
|
| 33 |
+
("microsoft/swinv2-tiny-patch4-window16-256", 385713), # swinv2
|
| 34 |
+
("google/siglip2-base-patch16-224", 336824), # siglip2
|
| 35 |
+
("microsoft/resnet-50", 307057), # resnet (convnext processor)
|
| 36 |
+
("nvidia/segformer-b0-finetuned-ade-512-512", 262459), # segformer
|
| 37 |
+
("facebook/convnextv2-tiny-22k-384", 48614), # convnextv2
|
| 38 |
+
("google/mobilenet_v2_1.0_224", 48342), # mobilenet
|
| 39 |
+
("facebook/convnext-tiny-224", 16984), # convnext
|
| 40 |
+
("google/efficientnet-b0", 8577), # efficientnet
|
| 41 |
+
("microsoft/beit-base-patch16-224-pt22k-ft22k", 7529), # beit
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def unsupported_reason(p):
|
| 46 |
+
"""Return None if the kernel can stand in for this processor, else a short reason."""
|
| 47 |
+
if not getattr(p, "do_resize", True):
|
| 48 |
+
return "no resize"
|
| 49 |
+
if not getattr(p, "do_normalize", True):
|
| 50 |
+
return "no normalize (rescale only)"
|
| 51 |
+
if getattr(p, "do_flip_channel_order", False):
|
| 52 |
+
return "channel flip (BGR)"
|
| 53 |
+
if getattr(p, "do_pad", False):
|
| 54 |
+
return "pad"
|
| 55 |
+
if int(getattr(p, "resample", 2)) not in _PIL_RESAMPLE:
|
| 56 |
+
return f"resample {p.resample}"
|
| 57 |
+
size = getattr(p, "size", {}) or {}
|
| 58 |
+
crop = p.crop_size if getattr(p, "do_center_crop", False) else None
|
| 59 |
+
if "shortest_edge" in size:
|
| 60 |
+
return None if crop else "shortest_edge without crop (variable output)"
|
| 61 |
+
if "height" in size and "width" in size:
|
| 62 |
+
return None
|
| 63 |
+
return f"size {size}"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def preprocess_with_kernel(p, images):
|
| 67 |
+
size = p.size
|
| 68 |
+
resample = _PIL_RESAMPLE[int(p.resample)]
|
| 69 |
+
antialias = bool(getattr(p, "antialias", True))
|
| 70 |
+
rescale = float(p.rescale_factor) if getattr(p, "do_rescale", True) else 1.0
|
| 71 |
+
mean, std = p.image_mean, p.image_std
|
| 72 |
+
crop = p.crop_size if getattr(p, "do_center_crop", False) else None
|
| 73 |
+
common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)
|
| 74 |
+
if "shortest_edge" in size:
|
| 75 |
+
return resize_normalize(
|
| 76 |
+
images, size["shortest_edge"], mean, std,
|
| 77 |
+
crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common)
|
| 78 |
+
if crop is not None and (crop["height"] != size["height"] or crop["width"] != size["width"]):
|
| 79 |
+
return resize_normalize(
|
| 80 |
+
images, (size["height"], size["width"]), mean, std,
|
| 81 |
+
crop_size=(crop["height"], crop["width"]), resize_mode="square", **common)
|
| 82 |
+
return resize_normalize(images, (size["height"], size["width"]), mean, std, **common)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def vision_features(model, pixel_values):
|
| 86 |
+
tower = getattr(model, "vision_model", model)
|
| 87 |
+
out = tower(pixel_values=pixel_values.to(model.dtype))
|
| 88 |
+
for attr in ("pooler_output", "last_hidden_state"):
|
| 89 |
+
value = getattr(out, attr, None)
|
| 90 |
+
if value is not None and torch.is_tensor(value):
|
| 91 |
+
return value
|
| 92 |
+
return out[0]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def main():
|
| 96 |
+
from transformers import AutoImageProcessor, AutoModel # lazy: avoids importing the kernels lib first
|
| 97 |
+
|
| 98 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 99 |
+
images = [
|
| 100 |
+
torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
|
| 101 |
+
for h, w in [(640, 480), (800, 600), (512, 512), (384, 1024)]
|
| 102 |
+
]
|
| 103 |
+
print(f"{'model':46s} {'verdict':10s} pixel max|Δ| feature max|Δ| (rel)")
|
| 104 |
+
for model_id, _ in MODELS:
|
| 105 |
+
try:
|
| 106 |
+
processor = AutoImageProcessor.from_pretrained(model_id)
|
| 107 |
+
reason = unsupported_reason(processor)
|
| 108 |
+
if reason is not None:
|
| 109 |
+
print(f"{model_id:46s} SKIP: {reason}")
|
| 110 |
+
continue
|
| 111 |
+
model = AutoModel.from_pretrained(model_id).to(device).eval()
|
| 112 |
+
reference_pv = processor(images, return_tensors="pt", device=device)["pixel_values"].to(device)
|
| 113 |
+
kernel_pv = preprocess_with_kernel(processor, images)
|
| 114 |
+
pixel_delta = (kernel_pv - reference_pv).abs().max().item()
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
base = vision_features(model, reference_pv)
|
| 117 |
+
feat_delta = (vision_features(model, kernel_pv) - base).abs().max().item()
|
| 118 |
+
rel = feat_delta / base.abs().max().item()
|
| 119 |
+
print(f"{model_id:46s} OK {pixel_delta:.2e} {feat_delta:.2e} ({rel:.1%})")
|
| 120 |
+
del model
|
| 121 |
+
torch.cuda.empty_cache()
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"{model_id:46s} ERROR: {type(e).__name__}: {str(e)[:55]}")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
main()
|
build.toml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "kernel_image_resize"
|
| 3 |
+
universal = true
|
| 4 |
+
version = 1
|
| 5 |
+
|
| 6 |
+
[general.hub]
|
| 7 |
+
repo-id = "Molbap/kernel_image_resize"
|
build/torch-universal/kernel_image_resize/__init__.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Resize + rescale + normalize for transformers fast image processors, as a Triton kernel.
|
| 2 |
+
|
| 3 |
+
resize -> fold(rescale, normalize) in one GPU pipeline: CHW uint8 images in,
|
| 4 |
+
(N, C, out_h, out_w) normalized float out, no full-resolution float intermediate.
|
| 5 |
+
|
| 6 |
+
- resize_normalize — stacked (N, C, H, W) tensor or a list of CHW images.
|
| 7 |
+
- resize_normalize_ragged — same kernels; takes a list of different-H/W CHW tensors.
|
| 8 |
+
|
| 9 |
+
backend="separable" (default): two-pass uint8, taps+taps. backend="fused": single 2D
|
| 10 |
+
launch, taps*taps. Both parity <=1e-4 vs torchvision-float.
|
| 11 |
+
|
| 12 |
+
from kernels import get_kernel
|
| 13 |
+
kir = get_kernel("Molbap/kernel_image_resize")
|
| 14 |
+
pixel_values = kir.resize_normalize(
|
| 15 |
+
images, size=384, image_mean=[...], image_std=[...], resample="bicubic", antialias=True,
|
| 16 |
+
)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from ._fused import fused_resize_normalize
|
| 20 |
+
from ._pack import PIL_RESAMPLE_TO_INTERP, as_image_list
|
| 21 |
+
from ._separable import separable_resize_crop_normalize, separable_resize_normalize
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _normalize_size(size) -> tuple[int, int]:
|
| 25 |
+
if isinstance(size, int):
|
| 26 |
+
return size, size
|
| 27 |
+
if isinstance(size, dict):
|
| 28 |
+
if "height" in size and "width" in size:
|
| 29 |
+
return int(size["height"]), int(size["width"])
|
| 30 |
+
raise ValueError(f"size dict must hold 'height'/'width' for a fixed resize, got {size}")
|
| 31 |
+
out_h, out_w = size
|
| 32 |
+
return int(out_h), int(out_w)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _normalize_resample(resample) -> str:
|
| 36 |
+
if isinstance(resample, str):
|
| 37 |
+
if resample not in ("bilinear", "bicubic"):
|
| 38 |
+
raise ValueError(f"resample must be 'bilinear' or 'bicubic', got {resample!r}")
|
| 39 |
+
return resample
|
| 40 |
+
interp = PIL_RESAMPLE_TO_INTERP.get(int(resample))
|
| 41 |
+
if interp is None:
|
| 42 |
+
raise ValueError(f"unsupported PIL resample code {resample}")
|
| 43 |
+
return interp
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def resize_normalize(
|
| 47 |
+
images,
|
| 48 |
+
size,
|
| 49 |
+
image_mean,
|
| 50 |
+
image_std,
|
| 51 |
+
rescale_factor: float = 1.0 / 255.0,
|
| 52 |
+
resample="bilinear",
|
| 53 |
+
antialias: bool = False,
|
| 54 |
+
backend: str = "separable",
|
| 55 |
+
block: int = 256,
|
| 56 |
+
crop_size=None,
|
| 57 |
+
resize_mode: str = "square",
|
| 58 |
+
):
|
| 59 |
+
"""Resize, optionally center-crop, rescale and normalize — one GPU pipeline.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
images: a stacked `(N, C, H, W)` uint8/float tensor, or a list of CHW tensors (ragged).
|
| 63 |
+
size: resize target. With no crop: int (square), `(height, width)`, or `{"height","width"}`.
|
| 64 |
+
With `resize_mode="shortest_edge"`: an int, the short side after aspect-preserving resize.
|
| 65 |
+
image_mean, image_std: per-channel normalization stats (length C).
|
| 66 |
+
rescale_factor: folded into mean/std so the kernel does `(x*rescale - mean)/std`.
|
| 67 |
+
resample: "bilinear" / "bicubic", or a PIL resample int (0/2 -> bilinear, 3 -> bicubic).
|
| 68 |
+
antialias: match the ViT/CLIP/SigLIP default (`True` for those processors).
|
| 69 |
+
backend: "separable" (default) or "fused" (2D reference, no crop support).
|
| 70 |
+
crop_size: `None` (no crop), int (square), or `(crop_h, crop_w)`. Center crop after resize.
|
| 71 |
+
resize_mode: "square" (resize to `size`) or "shortest_edge" (aspect-preserving, needs a crop).
|
| 72 |
+
"""
|
| 73 |
+
interp = _normalize_resample(resample)
|
| 74 |
+
image_list = as_image_list(images)
|
| 75 |
+
|
| 76 |
+
if crop_size is not None or resize_mode == "shortest_edge":
|
| 77 |
+
crop_h, crop_w = _normalize_size(crop_size if crop_size is not None else size)
|
| 78 |
+
resize_arg = int(size) if resize_mode == "shortest_edge" else _normalize_size(size)
|
| 79 |
+
return separable_resize_crop_normalize(
|
| 80 |
+
image_list, resize_arg, (crop_h, crop_w), image_mean, image_std, rescale_factor,
|
| 81 |
+
interp, antialias, resize_mode, block,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
out_h, out_w = _normalize_size(size)
|
| 85 |
+
if backend == "fused":
|
| 86 |
+
return fused_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
|
| 87 |
+
if backend == "separable":
|
| 88 |
+
return separable_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
|
| 89 |
+
raise ValueError(f"backend must be 'fused' or 'separable', got {backend!r}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def resize_normalize_ragged(
|
| 93 |
+
images,
|
| 94 |
+
size,
|
| 95 |
+
image_mean,
|
| 96 |
+
image_std,
|
| 97 |
+
rescale_factor: float = 1.0 / 255.0,
|
| 98 |
+
resample="bilinear",
|
| 99 |
+
antialias: bool = False,
|
| 100 |
+
backend: str = "separable",
|
| 101 |
+
block: int = 256,
|
| 102 |
+
):
|
| 103 |
+
"""Variant taking a list of different-H/W CHW tensors. Same kernels as `resize_normalize`."""
|
| 104 |
+
if isinstance(images, list):
|
| 105 |
+
image_list = images
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError("resize_normalize_ragged expects a list of CHW tensors; use resize_normalize for a stacked tensor")
|
| 108 |
+
return resize_normalize(
|
| 109 |
+
image_list, size, image_mean, image_std, rescale_factor, resample, antialias, backend, block
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
__all__ = ["resize_normalize", "resize_normalize_ragged"]
|
build/torch-universal/kernel_image_resize/_fused.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fused 2D resize+rescale+normalize over a ragged batch, single launch.
|
| 2 |
+
|
| 3 |
+
One program owns one image and a BLOCK of its output pixels, gathers a
|
| 4 |
+
MAX_TAPS_H × MAX_TAPS_W window, applies the separable weights as a 2D product, then folds
|
| 5 |
+
rescale+normalize. taps×taps loads per output pixel.
|
| 6 |
+
|
| 7 |
+
Resampling-weight formula (PyTorch aten UpSampleKernel):
|
| 8 |
+
scale = in / out
|
| 9 |
+
support = interp_half * (scale if antialias and scale > 1 else 1) # interp_half: 1 linear, 2 cubic
|
| 10 |
+
center = scale * (i + 0.5)
|
| 11 |
+
weight = filter((tap - center + 0.5) / eff), renormalized over the realized window
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import triton
|
| 15 |
+
import triton.language as tl
|
| 16 |
+
|
| 17 |
+
from ._pack import fold_mean_std, max_taps, pack_images
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@triton.jit
|
| 21 |
+
def _resample_weight(arg, cubic_a, CUBIC: tl.constexpr):
|
| 22 |
+
"""Interpolation filter at `arg` (coordinate distance already divided by support)."""
|
| 23 |
+
ax = tl.abs(arg)
|
| 24 |
+
if CUBIC: # Keys cubic convolution kernel, support 2
|
| 25 |
+
ax2 = ax * ax
|
| 26 |
+
ax3 = ax2 * ax
|
| 27 |
+
inner = (cubic_a + 2.0) * ax3 - (cubic_a + 3.0) * ax2 + 1.0 # |x| <= 1
|
| 28 |
+
outer = cubic_a * ax3 - 5.0 * cubic_a * ax2 + 8.0 * cubic_a * ax - 4.0 * cubic_a # 1 < |x| < 2
|
| 29 |
+
return tl.where(ax <= 1.0, inner, tl.where(ax < 2.0, outer, 0.0))
|
| 30 |
+
return tl.maximum(1.0 - ax, 0.0) # triangle (bilinear), support 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@triton.jit
|
| 34 |
+
def _resize_normalize_kernel(
|
| 35 |
+
in_ptr, out_ptr, offsets_ptr, heights_ptr, widths_ptr, mean_ptr, std_ptr,
|
| 36 |
+
out_h, out_w, cubic_a,
|
| 37 |
+
C: tl.constexpr, BLOCK: tl.constexpr,
|
| 38 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 39 |
+
MAX_TAPS_H: tl.constexpr, MAX_TAPS_W: tl.constexpr,
|
| 40 |
+
):
|
| 41 |
+
n = tl.program_id(0)
|
| 42 |
+
blk = tl.program_id(1)
|
| 43 |
+
H = tl.load(heights_ptr + n)
|
| 44 |
+
W = tl.load(widths_ptr + n)
|
| 45 |
+
off = tl.load(offsets_ptr + n)
|
| 46 |
+
Hf = H.to(tl.float32)
|
| 47 |
+
Wf = W.to(tl.float32)
|
| 48 |
+
|
| 49 |
+
npix = out_h * out_w
|
| 50 |
+
pos = blk * BLOCK + tl.arange(0, BLOCK)
|
| 51 |
+
mask = pos < npix
|
| 52 |
+
oy = pos // out_w
|
| 53 |
+
ox = pos % out_w
|
| 54 |
+
|
| 55 |
+
interp_half = 2.0 if CUBIC else 1.0
|
| 56 |
+
scale_h = Hf / out_h
|
| 57 |
+
scale_w = Wf / out_w
|
| 58 |
+
eff_h = tl.maximum(scale_h, 1.0) if ANTIALIAS else 1.0
|
| 59 |
+
eff_w = tl.maximum(scale_w, 1.0) if ANTIALIAS else 1.0
|
| 60 |
+
support_h = interp_half * eff_h
|
| 61 |
+
support_w = interp_half * eff_w
|
| 62 |
+
inv_h = 1.0 / eff_h
|
| 63 |
+
inv_w = 1.0 / eff_w
|
| 64 |
+
|
| 65 |
+
center_y = scale_h * (oy.to(tl.float32) + 0.5)
|
| 66 |
+
center_x = scale_w * (ox.to(tl.float32) + 0.5)
|
| 67 |
+
ystart = tl.floor(center_y - support_h + 0.5)
|
| 68 |
+
xstart = tl.floor(center_x - support_w + 0.5)
|
| 69 |
+
|
| 70 |
+
sum_wy = tl.zeros([BLOCK], dtype=tl.float32)
|
| 71 |
+
for ty in tl.static_range(MAX_TAPS_H):
|
| 72 |
+
yy = ystart + ty
|
| 73 |
+
wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
|
| 74 |
+
if ANTIALIAS:
|
| 75 |
+
wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
|
| 76 |
+
sum_wy += wy
|
| 77 |
+
sum_wx = tl.zeros([BLOCK], dtype=tl.float32)
|
| 78 |
+
for tx in tl.static_range(MAX_TAPS_W):
|
| 79 |
+
xx = xstart + tx
|
| 80 |
+
wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
|
| 81 |
+
if ANTIALIAS:
|
| 82 |
+
wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
|
| 83 |
+
sum_wx += wx
|
| 84 |
+
denom = sum_wy * sum_wx
|
| 85 |
+
|
| 86 |
+
plane = (H * W).to(tl.int64)
|
| 87 |
+
Wl = W.to(tl.int64)
|
| 88 |
+
for c in tl.static_range(C):
|
| 89 |
+
base = off + c * plane
|
| 90 |
+
acc = tl.zeros([BLOCK], dtype=tl.float32)
|
| 91 |
+
for ty in tl.static_range(MAX_TAPS_H):
|
| 92 |
+
yy = ystart + ty
|
| 93 |
+
wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
|
| 94 |
+
if ANTIALIAS:
|
| 95 |
+
wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
|
| 96 |
+
yidx = tl.minimum(tl.maximum(yy.to(tl.int32), 0), H - 1).to(tl.int64)
|
| 97 |
+
row = base + yidx * Wl
|
| 98 |
+
for tx in tl.static_range(MAX_TAPS_W):
|
| 99 |
+
xx = xstart + tx
|
| 100 |
+
wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
|
| 101 |
+
if ANTIALIAS:
|
| 102 |
+
wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
|
| 103 |
+
xidx = tl.minimum(tl.maximum(xx.to(tl.int32), 0), W - 1).to(tl.int64)
|
| 104 |
+
pix = tl.load(in_ptr + row + xidx, mask=mask, other=0.0)
|
| 105 |
+
acc += wy * wx * pix
|
| 106 |
+
acc = acc / denom
|
| 107 |
+
m = tl.load(mean_ptr + c)
|
| 108 |
+
s = tl.load(std_ptr + c)
|
| 109 |
+
acc = (acc - m) / s
|
| 110 |
+
oidx = ((n * C + c) * out_h + oy) * out_w + ox
|
| 111 |
+
tl.store(out_ptr + oidx, acc, mask=mask)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fused_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
|
| 115 |
+
"""Single fused launch over a ragged packed buffer -> (N, C, out_h, out_w) normalized float."""
|
| 116 |
+
import torch
|
| 117 |
+
|
| 118 |
+
images = list(images)
|
| 119 |
+
device = images[0].device
|
| 120 |
+
n = len(images)
|
| 121 |
+
cubic_a = -0.5 if antialias else -0.75 # PIL coeff under antialias, Keys coeff otherwise
|
| 122 |
+
max_taps_h = max_taps(images, out_h, 1, interp, antialias)
|
| 123 |
+
max_taps_w = max_taps(images, out_w, 2, interp, antialias)
|
| 124 |
+
mean_t, std_t = fold_mean_std(mean, std, rescale, device)
|
| 125 |
+
|
| 126 |
+
in_buf, offsets_t, heights_t, widths_t, c = pack_images(images)
|
| 127 |
+
out = torch.empty((n, c, out_h, out_w), device=device, dtype=torch.float32)
|
| 128 |
+
grid = (n, triton.cdiv(out_h * out_w, block))
|
| 129 |
+
_resize_normalize_kernel[grid](
|
| 130 |
+
in_buf, out, offsets_t, heights_t, widths_t, mean_t, std_t,
|
| 131 |
+
out_h, out_w, cubic_a, C=c, BLOCK=block,
|
| 132 |
+
CUBIC=(interp == "bicubic"), ANTIALIAS=antialias, MAX_TAPS_H=max_taps_h, MAX_TAPS_W=max_taps_w,
|
| 133 |
+
)
|
| 134 |
+
return out
|
build/torch-universal/kernel_image_resize/_pack.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ragged packing + resampling helpers shared by the fused and separable backends."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
PIL_RESAMPLE_TO_INTERP = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pack_images(
|
| 12 |
+
images: list[torch.Tensor], dtype: torch.dtype = torch.float32
|
| 13 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
| 14 |
+
"""Concatenate a ragged list of CHW images into one flat buffer of `dtype`.
|
| 15 |
+
|
| 16 |
+
Returns (in_buf, offsets, heights, widths, channels); offsets[n] is the element index
|
| 17 |
+
where image n starts.
|
| 18 |
+
"""
|
| 19 |
+
device = images[0].device
|
| 20 |
+
channels = images[0].shape[0]
|
| 21 |
+
flats, offsets, heights, widths, cur = [], [], [], [], 0
|
| 22 |
+
for img in images:
|
| 23 |
+
ic, ih, iw = img.shape
|
| 24 |
+
if ic != channels:
|
| 25 |
+
raise ValueError(f"all images must share channel count {channels}, got {ic}")
|
| 26 |
+
flats.append(img.reshape(-1).to(dtype))
|
| 27 |
+
offsets.append(cur)
|
| 28 |
+
heights.append(ih)
|
| 29 |
+
widths.append(iw)
|
| 30 |
+
cur += ic * ih * iw
|
| 31 |
+
in_buf = torch.cat(flats)
|
| 32 |
+
offsets_t = torch.tensor(offsets, device=device, dtype=torch.int64)
|
| 33 |
+
heights_t = torch.tensor(heights, device=device, dtype=torch.int32)
|
| 34 |
+
widths_t = torch.tensor(widths, device=device, dtype=torch.int32)
|
| 35 |
+
return in_buf, offsets_t, heights_t, widths_t, channels
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def fold_mean_std(mean, std, rescale: float, device) -> tuple[torch.Tensor, torch.Tensor]:
|
| 39 |
+
"""Fold rescale into mean/std so the kernel does (x - m)/s == (x*rescale - mean)/std."""
|
| 40 |
+
mean_t = (torch.tensor(mean, device=device, dtype=torch.float32) / rescale).contiguous()
|
| 41 |
+
std_t = (torch.tensor(std, device=device, dtype=torch.float32) / rescale).contiguous()
|
| 42 |
+
return mean_t, std_t
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def max_taps(images: list[torch.Tensor], out_size: int, axis_dim: int, interp: str, antialias: bool) -> int:
|
| 46 |
+
"""Batch-wide worst-case tap count for one axis = ceil(support) * 2 + 1."""
|
| 47 |
+
interp_half = 2.0 if interp == "bicubic" else 1.0
|
| 48 |
+
worst = 0
|
| 49 |
+
for img in images:
|
| 50 |
+
scale = img.shape[axis_dim] / out_size
|
| 51 |
+
eff = max(scale, 1.0) if antialias else 1.0
|
| 52 |
+
worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
|
| 53 |
+
return worst
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def as_image_list(images) -> list[torch.Tensor]:
|
| 57 |
+
"""Accept a stacked (N, C, H, W) tensor or a list of CHW tensors; always return a list."""
|
| 58 |
+
if isinstance(images, torch.Tensor):
|
| 59 |
+
if images.dim() != 4:
|
| 60 |
+
raise ValueError(f"stacked input must be (N, C, H, W), got shape {tuple(images.shape)}")
|
| 61 |
+
return list(images)
|
| 62 |
+
return list(images)
|
build/torch-universal/kernel_image_resize/_separable.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Separable resize + center-crop + normalize over a ragged uint8 batch.
|
| 2 |
+
|
| 3 |
+
WHAT "RESIZE" DOES, CONCRETELY
|
| 4 |
+
Every output pixel is a weighted average of a small window of input pixels. When you shrink
|
| 5 |
+
an image a lot (with antialiasing) that window gets wide — e.g. 13 input pixels across and
|
| 6 |
+
13 down, so 13x13 = 169 input pixels feed one output pixel.
|
| 7 |
+
|
| 8 |
+
FUSED vs SEPARABLE (the two backends in this package)
|
| 9 |
+
- FUSED (see _fused.py): for each output pixel, read the whole 2D window directly -> 169 reads.
|
| 10 |
+
- SEPARABLE (this file): do the resize as two 1D steps instead of one 2D step:
|
| 11 |
+
step 1 (horizontal): resize only the WIDTH -> an intermediate image
|
| 12 |
+
step 2 (vertical): resize only the HEIGHT -> the final image
|
| 13 |
+
Each step's window is 1D, so 13 + 13 = 26 reads per output pixel instead of 169. Same math,
|
| 14 |
+
far fewer reads. This is what PIL and torchvision do.
|
| 15 |
+
|
| 16 |
+
CENTER CROP (folded in, no extra pass)
|
| 17 |
+
Processors like CLIP / DINOv2 resize to a "resize size" and then keep only the centered
|
| 18 |
+
crop. We do not materialize the full resized image and slice it; instead each output pixel
|
| 19 |
+
of the CROP maps to a resize-image coordinate by adding the crop offset, and that maps back
|
| 20 |
+
to the input. So:
|
| 21 |
+
resize is described by (resize_height, resize_width) -- per image
|
| 22 |
+
crop is described by (crop_top, crop_left) -- per image, the centered offset
|
| 23 |
+
output size is (crop_height, crop_width) -- the same for every image
|
| 24 |
+
When there is no crop, resize size == crop size and the offsets are 0 (the plain resize).
|
| 25 |
+
The resize SCALE uses the resize size; only the output coordinate is shifted by the crop.
|
| 26 |
+
|
| 27 |
+
uint8 input + float intermediate; each 1D step renormalizes its own weights (matches
|
| 28 |
+
torchvision). Output is parity-close to torchvision, not bit-identical (torchvision keeps a
|
| 29 |
+
fixed-point uint8 intermediate; ours is more accurate float).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
import triton
|
| 33 |
+
import triton.language as tl
|
| 34 |
+
|
| 35 |
+
from ._fused import _resample_weight
|
| 36 |
+
from ._pack import fold_mean_std, pack_images
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@triton.jit
|
| 40 |
+
def _horizontal_resize_kernel(
|
| 41 |
+
input_pixels, # flat uint8 buffer, all images packed back to back
|
| 42 |
+
intermediate, # flat float32 output: width resized + col-cropped, height untouched
|
| 43 |
+
input_offsets, # input_offsets[image] = where that image starts in input_pixels
|
| 44 |
+
intermediate_offsets, # same idea for the intermediate buffer
|
| 45 |
+
heights, widths, # per-image input height / width
|
| 46 |
+
resize_widths, # per-image width to resize to (before cropping)
|
| 47 |
+
crop_lefts, # per-image left offset of the centered crop
|
| 48 |
+
crop_w, # output (crop) width, same for every image
|
| 49 |
+
cubic_coeff,
|
| 50 |
+
CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
|
| 51 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 52 |
+
MAX_TAPS_COL: tl.constexpr,
|
| 53 |
+
):
|
| 54 |
+
"""Resize width to resize_width, keep only the cropped columns: uint8 (C,H,W) -> float (C,H,crop_w)."""
|
| 55 |
+
image_index = tl.program_id(0)
|
| 56 |
+
block_index = tl.program_id(1)
|
| 57 |
+
in_height = tl.load(heights + image_index)
|
| 58 |
+
in_width = tl.load(widths + image_index)
|
| 59 |
+
resize_width = tl.load(resize_widths + image_index)
|
| 60 |
+
crop_left = tl.load(crop_lefts + image_index)
|
| 61 |
+
input_start = tl.load(input_offsets + image_index)
|
| 62 |
+
intermediate_start = tl.load(intermediate_offsets + image_index)
|
| 63 |
+
in_width_f = in_width.to(tl.float32)
|
| 64 |
+
|
| 65 |
+
num_pixels = in_height * crop_w # every input row x every cropped output column
|
| 66 |
+
flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
|
| 67 |
+
active = flat_index < num_pixels
|
| 68 |
+
input_row = flat_index // crop_w
|
| 69 |
+
out_col = flat_index % crop_w
|
| 70 |
+
resize_col = out_col + crop_left # column in the (uncropped) resized image
|
| 71 |
+
|
| 72 |
+
filter_half = 2.0 if CUBIC else 1.0
|
| 73 |
+
col_scale = in_width_f / resize_width.to(tl.float32)
|
| 74 |
+
col_filter_scale = tl.maximum(col_scale, 1.0) if ANTIALIAS else 1.0
|
| 75 |
+
col_support = filter_half * col_filter_scale
|
| 76 |
+
col_inv_scale = 1.0 / col_filter_scale
|
| 77 |
+
src_center_col = col_scale * (resize_col.to(tl.float32) + 0.5)
|
| 78 |
+
first_tap_col = tl.floor(src_center_col - col_support + 0.5)
|
| 79 |
+
|
| 80 |
+
col_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
|
| 81 |
+
for tap in tl.static_range(MAX_TAPS_COL):
|
| 82 |
+
tap_col = first_tap_col + tap
|
| 83 |
+
weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
|
| 84 |
+
if ANTIALIAS:
|
| 85 |
+
weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
|
| 86 |
+
col_weight_sum += weight
|
| 87 |
+
|
| 88 |
+
input_plane = (in_height * in_width).to(tl.int64)
|
| 89 |
+
intermediate_plane = (in_height * crop_w).to(tl.int64)
|
| 90 |
+
in_width_i64 = in_width.to(tl.int64)
|
| 91 |
+
crop_w_i64 = crop_w.to(tl.int64)
|
| 92 |
+
input_row_i64 = input_row.to(tl.int64)
|
| 93 |
+
for channel in tl.static_range(CHANNELS):
|
| 94 |
+
input_row_base = input_start + channel * input_plane + input_row_i64 * in_width_i64
|
| 95 |
+
accumulator = tl.zeros([BLOCK], dtype=tl.float32)
|
| 96 |
+
for tap in tl.static_range(MAX_TAPS_COL):
|
| 97 |
+
tap_col = first_tap_col + tap
|
| 98 |
+
weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
|
| 99 |
+
if ANTIALIAS:
|
| 100 |
+
weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
|
| 101 |
+
clamped_tap_col = tl.minimum(tl.maximum(tap_col.to(tl.int32), 0), in_width - 1).to(tl.int64)
|
| 102 |
+
pixel = tl.load(input_pixels + input_row_base + clamped_tap_col, mask=active, other=0).to(tl.float32)
|
| 103 |
+
accumulator += weight * pixel
|
| 104 |
+
accumulator = accumulator / col_weight_sum
|
| 105 |
+
write_index = intermediate_start + channel * intermediate_plane + input_row_i64 * crop_w_i64 + out_col
|
| 106 |
+
tl.store(intermediate + write_index, accumulator, mask=active)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@triton.jit
|
| 110 |
+
def _vertical_resize_normalize_kernel(
|
| 111 |
+
intermediate, # float32 from the horizontal step: (C, H, crop_w) per image
|
| 112 |
+
output, # final (N, C, crop_h, crop_w) float32
|
| 113 |
+
intermediate_offsets,
|
| 114 |
+
heights, # per-image input height (the intermediate still has H rows)
|
| 115 |
+
resize_heights, # per-image height to resize to (before cropping)
|
| 116 |
+
crop_tops, # per-image top offset of the centered crop
|
| 117 |
+
means, stds, # per-channel normalization, rescale already folded in
|
| 118 |
+
crop_h, crop_w,
|
| 119 |
+
cubic_coeff,
|
| 120 |
+
CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
|
| 121 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 122 |
+
MAX_TAPS_ROW: tl.constexpr,
|
| 123 |
+
):
|
| 124 |
+
"""Resize height to resize_height, keep cropped rows, normalize: float (C,H,crop_w) -> (C,crop_h,crop_w)."""
|
| 125 |
+
image_index = tl.program_id(0)
|
| 126 |
+
block_index = tl.program_id(1)
|
| 127 |
+
in_height = tl.load(heights + image_index)
|
| 128 |
+
resize_height = tl.load(resize_heights + image_index)
|
| 129 |
+
crop_top = tl.load(crop_tops + image_index)
|
| 130 |
+
intermediate_start = tl.load(intermediate_offsets + image_index)
|
| 131 |
+
in_height_f = in_height.to(tl.float32)
|
| 132 |
+
|
| 133 |
+
num_pixels = crop_h * crop_w
|
| 134 |
+
flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
|
| 135 |
+
active = flat_index < num_pixels
|
| 136 |
+
out_row = flat_index // crop_w
|
| 137 |
+
out_col = flat_index % crop_w
|
| 138 |
+
resize_row = out_row + crop_top # row in the (uncropped) resized image
|
| 139 |
+
|
| 140 |
+
filter_half = 2.0 if CUBIC else 1.0
|
| 141 |
+
row_scale = in_height_f / resize_height.to(tl.float32)
|
| 142 |
+
row_filter_scale = tl.maximum(row_scale, 1.0) if ANTIALIAS else 1.0
|
| 143 |
+
row_support = filter_half * row_filter_scale
|
| 144 |
+
row_inv_scale = 1.0 / row_filter_scale
|
| 145 |
+
src_center_row = row_scale * (resize_row.to(tl.float32) + 0.5)
|
| 146 |
+
first_tap_row = tl.floor(src_center_row - row_support + 0.5)
|
| 147 |
+
|
| 148 |
+
row_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
|
| 149 |
+
for tap in tl.static_range(MAX_TAPS_ROW):
|
| 150 |
+
tap_row = first_tap_row + tap
|
| 151 |
+
weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
|
| 152 |
+
if ANTIALIAS:
|
| 153 |
+
weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
|
| 154 |
+
row_weight_sum += weight
|
| 155 |
+
|
| 156 |
+
intermediate_plane = (in_height * crop_w).to(tl.int64)
|
| 157 |
+
crop_w_i64 = crop_w.to(tl.int64)
|
| 158 |
+
out_col_i64 = out_col.to(tl.int64)
|
| 159 |
+
for channel in tl.static_range(CHANNELS):
|
| 160 |
+
channel_base = intermediate_start + channel * intermediate_plane
|
| 161 |
+
accumulator = tl.zeros([BLOCK], dtype=tl.float32)
|
| 162 |
+
for tap in tl.static_range(MAX_TAPS_ROW):
|
| 163 |
+
tap_row = first_tap_row + tap
|
| 164 |
+
weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
|
| 165 |
+
if ANTIALIAS:
|
| 166 |
+
weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
|
| 167 |
+
clamped_tap_row = tl.minimum(tl.maximum(tap_row.to(tl.int32), 0), in_height - 1).to(tl.int64)
|
| 168 |
+
pixel = tl.load(intermediate + channel_base + clamped_tap_row * crop_w_i64 + out_col_i64, mask=active, other=0.0)
|
| 169 |
+
accumulator += weight * pixel
|
| 170 |
+
accumulator = accumulator / row_weight_sum
|
| 171 |
+
mean = tl.load(means + channel)
|
| 172 |
+
std = tl.load(stds + channel)
|
| 173 |
+
accumulator = (accumulator - mean) / std
|
| 174 |
+
write_index = ((image_index * CHANNELS + channel) * crop_h + out_row) * crop_w + out_col
|
| 175 |
+
tl.store(output + write_index, accumulator, mask=active)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _axis_max_taps(in_sizes, resize_sizes, interp, antialias):
|
| 179 |
+
"""Widest 1D window over the batch for one axis = ceil(support) * 2 + 1, support uses in/resize."""
|
| 180 |
+
import math
|
| 181 |
+
|
| 182 |
+
interp_half = 2.0 if interp == "bicubic" else 1.0
|
| 183 |
+
worst = 0
|
| 184 |
+
for in_size, resize_size in zip(in_sizes, resize_sizes):
|
| 185 |
+
scale = in_size / resize_size
|
| 186 |
+
eff = max(scale, 1.0) if antialias else 1.0
|
| 187 |
+
worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
|
| 188 |
+
return worst
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
|
| 192 |
+
mean, std, rescale, interp, antialias, block):
|
| 193 |
+
"""Core driver: resize each image to its (resize_h, resize_w), keep the centered crop, normalize."""
|
| 194 |
+
import torch
|
| 195 |
+
|
| 196 |
+
device = images[0].device
|
| 197 |
+
num_images = len(images)
|
| 198 |
+
cubic_coeff = -0.5 if antialias else -0.75
|
| 199 |
+
in_heights = [int(img.shape[1]) for img in images]
|
| 200 |
+
in_widths = [int(img.shape[2]) for img in images]
|
| 201 |
+
max_taps_row = _axis_max_taps(in_heights, resize_heights, interp, antialias)
|
| 202 |
+
max_taps_col = _axis_max_taps(in_widths, resize_widths, interp, antialias)
|
| 203 |
+
means, stds = fold_mean_std(mean, std, rescale, device)
|
| 204 |
+
|
| 205 |
+
input_pixels, input_offsets, heights, widths, channels = pack_images(images, dtype=torch.uint8)
|
| 206 |
+
|
| 207 |
+
intermediate_offsets_list, cursor, tallest = [], 0, 0
|
| 208 |
+
for height in in_heights:
|
| 209 |
+
intermediate_offsets_list.append(cursor)
|
| 210 |
+
cursor += channels * height * crop_w
|
| 211 |
+
tallest = max(tallest, height)
|
| 212 |
+
intermediate = torch.empty(cursor, device=device, dtype=torch.float32)
|
| 213 |
+
|
| 214 |
+
intermediate_offsets = torch.tensor(intermediate_offsets_list, device=device, dtype=torch.int64)
|
| 215 |
+
resize_heights_t = torch.tensor(resize_heights, device=device, dtype=torch.int32)
|
| 216 |
+
resize_widths_t = torch.tensor(resize_widths, device=device, dtype=torch.int32)
|
| 217 |
+
crop_tops_t = torch.tensor(crop_tops, device=device, dtype=torch.int32)
|
| 218 |
+
crop_lefts_t = torch.tensor(crop_lefts, device=device, dtype=torch.int32)
|
| 219 |
+
|
| 220 |
+
horizontal_grid = (num_images, triton.cdiv(tallest * crop_w, block))
|
| 221 |
+
_horizontal_resize_kernel[horizontal_grid](
|
| 222 |
+
input_pixels, intermediate, input_offsets, intermediate_offsets, heights, widths,
|
| 223 |
+
resize_widths_t, crop_lefts_t, crop_w, cubic_coeff,
|
| 224 |
+
CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
|
| 225 |
+
MAX_TAPS_COL=max_taps_col,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
output = torch.empty((num_images, channels, crop_h, crop_w), device=device, dtype=torch.float32)
|
| 229 |
+
vertical_grid = (num_images, triton.cdiv(crop_h * crop_w, block))
|
| 230 |
+
_vertical_resize_normalize_kernel[vertical_grid](
|
| 231 |
+
intermediate, output, intermediate_offsets, heights, resize_heights_t, crop_tops_t, means, stds,
|
| 232 |
+
crop_h, crop_w, cubic_coeff,
|
| 233 |
+
CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
|
| 234 |
+
MAX_TAPS_ROW=max_taps_row,
|
| 235 |
+
)
|
| 236 |
+
return output
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _aspect_preserving_size(in_h, in_w, shortest_edge):
|
| 240 |
+
"""transformers shortest-edge rule: short side -> shortest_edge, long side truncated (int(), not round)."""
|
| 241 |
+
if in_h <= in_w:
|
| 242 |
+
return shortest_edge, int(in_w * shortest_edge / in_h)
|
| 243 |
+
return int(in_h * shortest_edge / in_w), shortest_edge
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def separable_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
|
| 247 |
+
"""Resize to (out_h, out_w) and normalize (no crop)."""
|
| 248 |
+
images = list(images)
|
| 249 |
+
n = len(images)
|
| 250 |
+
return _run_separable(images, [out_h] * n, [out_w] * n, [0] * n, [0] * n, out_h, out_w,
|
| 251 |
+
mean, std, rescale, interp, antialias, block)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def separable_resize_crop_normalize(images, resize_size, crop_size, mean, std, rescale, interp, antialias,
|
| 255 |
+
resize_mode="square", block: int = 256):
|
| 256 |
+
"""Resize then center-crop then normalize.
|
| 257 |
+
|
| 258 |
+
resize_mode="square": resize_size is (resize_h, resize_w) applied to every image.
|
| 259 |
+
resize_mode="shortest_edge": resize_size is an int; each image is resized aspect-preserving
|
| 260 |
+
so its short side equals it, then center-cropped to crop_size.
|
| 261 |
+
"""
|
| 262 |
+
images = list(images)
|
| 263 |
+
crop_h, crop_w = crop_size
|
| 264 |
+
resize_heights, resize_widths = [], []
|
| 265 |
+
for img in images:
|
| 266 |
+
in_h, in_w = int(img.shape[1]), int(img.shape[2])
|
| 267 |
+
if resize_mode == "shortest_edge":
|
| 268 |
+
rh, rw = _aspect_preserving_size(in_h, in_w, int(resize_size))
|
| 269 |
+
elif resize_mode == "square":
|
| 270 |
+
rh, rw = int(resize_size[0]), int(resize_size[1])
|
| 271 |
+
else:
|
| 272 |
+
raise ValueError(f"resize_mode must be 'square' or 'shortest_edge', got {resize_mode!r}")
|
| 273 |
+
if rh < crop_h or rw < crop_w:
|
| 274 |
+
raise ValueError(f"resize size ({rh},{rw}) smaller than crop ({crop_h},{crop_w})")
|
| 275 |
+
resize_heights.append(rh)
|
| 276 |
+
resize_widths.append(rw)
|
| 277 |
+
crop_tops = [(rh - crop_h) // 2 for rh in resize_heights]
|
| 278 |
+
crop_lefts = [(rw - crop_w) // 2 for rw in resize_widths]
|
| 279 |
+
return _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
|
| 280 |
+
mean, std, rescale, interp, antialias, block)
|
example.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = ">=3.10"
|
| 3 |
+
# dependencies = [
|
| 4 |
+
# "torch",
|
| 5 |
+
# "triton",
|
| 6 |
+
# "kernels",
|
| 7 |
+
# ]
|
| 8 |
+
# ///
|
| 9 |
+
"""Minimal smoke test of the published kernel via get_kernel (run on a CUDA box)."""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from kernels import get_kernel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
kir = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
|
| 16 |
+
|
| 17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
images = [
|
| 19 |
+
torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
|
| 20 |
+
for h, w in [(640, 480), (800, 600), (384, 1024)]
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
pixel_values = kir.resize_normalize(
|
| 24 |
+
images,
|
| 25 |
+
size=384,
|
| 26 |
+
image_mean=[0.5, 0.5, 0.5],
|
| 27 |
+
image_std=[0.5, 0.5, 0.5],
|
| 28 |
+
rescale_factor=1 / 255,
|
| 29 |
+
resample="bicubic",
|
| 30 |
+
antialias=True,
|
| 31 |
+
)
|
| 32 |
+
print(f"{len(images)} ragged images -> {tuple(pixel_values.shape)} {pixel_values.dtype}")
|
example_transformers.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = ">=3.10"
|
| 3 |
+
# dependencies = ["torch", "triton", "kernels", "transformers", "torchvision"]
|
| 4 |
+
# ///
|
| 5 |
+
"""Drop-in: use the kernel as the resize+normalize stage of a transformers fast processor.
|
| 6 |
+
|
| 7 |
+
There is no `use_kernels=True` hook for image processors (that machinery swaps nn.Module
|
| 8 |
+
layer forwards inside the model, not processor code). So the usable path is to read the
|
| 9 |
+
processor's config and call the kernel directly. `preprocess_with_kernel` below is the whole
|
| 10 |
+
adapter — copy it into your code.
|
| 11 |
+
|
| 12 |
+
Run on a CUDA box:
|
| 13 |
+
python example_transformers.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from kernels import get_kernel
|
| 18 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
kernel_image_resize = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)
|
| 22 |
+
|
| 23 |
+
_PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def preprocess_with_kernel(processor, images):
|
| 27 |
+
"""Run the kernel using `processor`'s own config; returns pixel_values like processor(images).
|
| 28 |
+
|
| 29 |
+
Handles fixed-size resize, square-resize + center-crop, and shortest-edge resize + center-crop
|
| 30 |
+
(CLIP / DINOv2). Does not handle padding processors.
|
| 31 |
+
"""
|
| 32 |
+
size = processor.size
|
| 33 |
+
if getattr(processor, "do_pad", False):
|
| 34 |
+
raise ValueError("kernel does not pad; this processor needs a pad step")
|
| 35 |
+
if not getattr(processor, "do_normalize", True):
|
| 36 |
+
raise ValueError("processor does not normalize (rescale only); kernel always normalizes")
|
| 37 |
+
if getattr(processor, "do_flip_channel_order", False):
|
| 38 |
+
raise ValueError("processor flips channels to BGR; kernel keeps RGB")
|
| 39 |
+
resample = _PIL_RESAMPLE[int(processor.resample)]
|
| 40 |
+
antialias = bool(getattr(processor, "antialias", True))
|
| 41 |
+
rescale = float(processor.rescale_factor) if getattr(processor, "do_rescale", True) else 1.0
|
| 42 |
+
mean, std = processor.image_mean, processor.image_std
|
| 43 |
+
crop = processor.crop_size if getattr(processor, "do_center_crop", False) else None
|
| 44 |
+
common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)
|
| 45 |
+
|
| 46 |
+
if "shortest_edge" in size:
|
| 47 |
+
if crop is None:
|
| 48 |
+
raise ValueError("shortest-edge resize without a crop gives variable-size output")
|
| 49 |
+
return kernel_image_resize.resize_normalize(
|
| 50 |
+
images, size["shortest_edge"], mean, std,
|
| 51 |
+
crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common,
|
| 52 |
+
)
|
| 53 |
+
if crop is not None and (crop["height"] != size["height"] or crop["width"] != size["width"]):
|
| 54 |
+
return kernel_image_resize.resize_normalize(
|
| 55 |
+
images, (size["height"], size["width"]), mean, std,
|
| 56 |
+
crop_size=(crop["height"], crop["width"]), resize_mode="square", **common,
|
| 57 |
+
)
|
| 58 |
+
return kernel_image_resize.resize_normalize(images, (size["height"], size["width"]), mean, std, **common)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 63 |
+
model_id = "google/siglip2-base-patch16-224"
|
| 64 |
+
processor = AutoImageProcessor.from_pretrained(model_id, backend="torchvision")
|
| 65 |
+
model = AutoModel.from_pretrained(model_id).to(device).eval()
|
| 66 |
+
|
| 67 |
+
images = [
|
| 68 |
+
torch.randint(0, 256, (3, h, w), dtype=torch.uint8, device=device)
|
| 69 |
+
for h, w in [(640, 480), (800, 600), (384, 1024)]
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
pixel_values = preprocess_with_kernel(processor, images)
|
| 73 |
+
print(f"{len(images)} ragged images -> pixel_values {tuple(pixel_values.shape)} {pixel_values.dtype}")
|
| 74 |
+
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
features = model.vision_model(pixel_values=pixel_values.to(model.dtype)).pooler_output
|
| 77 |
+
print(f"vision features: {tuple(features.shape)}")
|
| 78 |
+
|
| 79 |
+
# parity vs the real processor (float-vs-uint8 resize -> small, expected gap)
|
| 80 |
+
reference = processor(images, return_tensors="pt", device=device)["pixel_values"].to(device)
|
| 81 |
+
print(f"max|Δ| pixel_values vs processor: {(pixel_values - reference).abs().max().item():.2e}")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
publish.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Publish the universal kernel to the Hub as a `kernel` repo type.
|
| 3 |
+
#
|
| 4 |
+
# IMPORTANT: kernel repos are repo_type="kernel" (served under /api/kernels/), NOT model repos.
|
| 5 |
+
# get_kernel() queries repo_type="kernel", so uploading with --repo-type model gives a 404.
|
| 6 |
+
#
|
| 7 |
+
# A universal kernel needs no compilation: the build is just a copy of the source package into the
|
| 8 |
+
# variant directory get_kernel resolves (build/torch-universal/<name>/).
|
| 9 |
+
set -euo pipefail
|
| 10 |
+
|
| 11 |
+
REPO_ID="${1:-Molbap/kernel_image_resize}"
|
| 12 |
+
NAME="kernel_image_resize"
|
| 13 |
+
HERE="$(cd "$(dirname "$0")" && pwd)"
|
| 14 |
+
|
| 15 |
+
rm -rf "$HERE/build"
|
| 16 |
+
mkdir -p "$HERE/build/torch-universal"
|
| 17 |
+
cp -r "$HERE/torch-ext/$NAME" "$HERE/build/torch-universal/$NAME"
|
| 18 |
+
find "$HERE/build" -name __pycache__ -type d -exec rm -rf {} + 2>/dev/null || true
|
| 19 |
+
|
| 20 |
+
echo "built build/torch-universal/$NAME"
|
| 21 |
+
|
| 22 |
+
# Create the kernel repo and upload. Uses the Python API because it reliably accepts
|
| 23 |
+
# repo_type="kernel" (the hf CLI's repo-type choices can be stricter).
|
| 24 |
+
python - "$REPO_ID" "$HERE" <<'PY'
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
# Some huggingface_hub versions know the `kernel` repo type (constant + create_repo) but still
|
| 28 |
+
# validate upload_folder against the old REPO_TYPES list. Register it so the upload validator passes.
|
| 29 |
+
import huggingface_hub.constants as hfc
|
| 30 |
+
if "kernel" not in hfc.REPO_TYPES:
|
| 31 |
+
hfc.REPO_TYPES = list(hfc.REPO_TYPES) + ["kernel"]
|
| 32 |
+
|
| 33 |
+
from huggingface_hub import create_repo, upload_folder
|
| 34 |
+
|
| 35 |
+
repo_id, folder = sys.argv[1], sys.argv[2]
|
| 36 |
+
create_repo(repo_id, repo_type="kernel", exist_ok=True)
|
| 37 |
+
upload_folder(
|
| 38 |
+
repo_id=repo_id,
|
| 39 |
+
repo_type="kernel",
|
| 40 |
+
folder_path=folder,
|
| 41 |
+
ignore_patterns=["__pycache__/*", "*.pyc", "result", ".git/*", "build/torch-universal/*/__pycache__/*"],
|
| 42 |
+
)
|
| 43 |
+
print(f"uploaded {folder} -> {repo_id} (repo_type=kernel)")
|
| 44 |
+
PY
|
resultcompat
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model verdict pixel max|Δ| feature max|Δ| (rel)
|
| 2 |
+
openai/clip-vit-base-patch32 OK 7.53e-03 1.53e-02 (0.2%)
|
| 3 |
+
google/vit-base-patch16-224 OK 3.93e-03 6.87e-03 (0.8%)
|
| 4 |
+
apple/mobilevit-small SKIP: no normalize (rescale only)
|
| 5 |
+
facebook/dinov2-small OK 1.41e-01 1.91e-02 (0.2%)
|
| 6 |
+
google/siglip-so400m-patch14-384 OK 1.58e-01 9.28e-03 (0.1%)
|
| 7 |
+
facebook/dinov3-vitb16-pretrain-lvd1689m OK 4.99e-05 8.08e-05 (0.0%)
|
| 8 |
+
microsoft/swinv2-tiny-patch4-window16-256 OK 8.75e-03 1.27e-02 (0.6%)
|
| 9 |
+
google/siglip2-base-patch16-224 OK 3.93e-03 1.31e-02 (0.2%)
|
| 10 |
+
microsoft/resnet-50 SKIP: shortest_edge without crop (variable output)
|
| 11 |
+
nvidia/segformer-b0-finetuned-ade-512-512 OK 8.75e-03 4.33e-02 (0.4%)
|
| 12 |
+
facebook/convnextv2-tiny-22k-384 SKIP: shortest_edge without crop (variable output)
|
| 13 |
+
google/mobilenet_v2_1.0_224 OK 3.92e-03 3.90e-02 (0.7%)
|
| 14 |
+
facebook/convnext-tiny-224 SKIP: shortest_edge without crop (variable output)
|
| 15 |
+
google/efficientnet-b0 SKIP: resample 0
|
| 16 |
+
microsoft/beit-base-patch16-224-pt22k-ft22k OK 3.93e-03 3.19e-02 (0.9%)
|
tests/test_resize_normalize.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs.
|
| 2 |
+
|
| 3 |
+
Run locally from the repo root with the package on the path:
|
| 4 |
+
PYTHONPATH=torch-ext pytest tests/ -q
|
| 5 |
+
CUDA is required (Triton); tests skip on CPU.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision.transforms.v2.functional as tvF
|
| 11 |
+
from torchvision.transforms import InterpolationMode
|
| 12 |
+
|
| 13 |
+
from kernel_image_resize import resize_normalize
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}
|
| 17 |
+
|
| 18 |
+
MEAN = [0.48145466, 0.4578275, 0.40821073]
|
| 19 |
+
STD = [0.26862954, 0.26130258, 0.27577711]
|
| 20 |
+
RESCALE = 1.0 / 255.0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
|
| 24 |
+
g = torch.Generator(device="cpu").manual_seed(seed)
|
| 25 |
+
images = []
|
| 26 |
+
for _ in range(n):
|
| 27 |
+
h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
|
| 28 |
+
w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
|
| 29 |
+
images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
|
| 30 |
+
return images
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _torchvision_reference(images, out_h, out_w, interp, antialias):
|
| 34 |
+
mode = _TV_INTERP[interp]
|
| 35 |
+
mean = torch.tensor(MEAN, device=images[0].device).view(3, 1, 1)
|
| 36 |
+
std = torch.tensor(STD, device=images[0].device).view(3, 1, 1)
|
| 37 |
+
outs = []
|
| 38 |
+
for img in images:
|
| 39 |
+
r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
|
| 40 |
+
outs.append((r * RESCALE - mean) / std)
|
| 41 |
+
return torch.stack(outs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.mark.kernels_ci
|
| 45 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
|
| 46 |
+
@pytest.mark.parametrize("backend", ["fused", "separable"])
|
| 47 |
+
@pytest.mark.parametrize("interp,antialias", [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)])
|
| 48 |
+
def test_parity_vs_torchvision(backend, interp, antialias):
|
| 49 |
+
device = torch.device("cuda")
|
| 50 |
+
images = _ragged_images(8, device)
|
| 51 |
+
out_h = out_w = 384
|
| 52 |
+
got = resize_normalize(
|
| 53 |
+
images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
|
| 54 |
+
)
|
| 55 |
+
ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
|
| 56 |
+
max_abs = (got - ref).abs().max().item()
|
| 57 |
+
assert max_abs < 3e-3, f"{backend}/{interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@pytest.mark.kernels_ci
|
| 61 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
|
| 62 |
+
def test_stacked_tensor_input():
|
| 63 |
+
device = torch.device("cuda")
|
| 64 |
+
images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8, device=device)
|
| 65 |
+
got = resize_normalize(images, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)
|
| 66 |
+
assert got.shape == (4, 3, 224, 224)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
|
| 70 |
+
mode = _TV_INTERP[interp]
|
| 71 |
+
mean = torch.tensor(MEAN, device=images[0].device).view(3, 1, 1)
|
| 72 |
+
std = torch.tensor(STD, device=images[0].device).view(3, 1, 1)
|
| 73 |
+
outs = []
|
| 74 |
+
for img in images:
|
| 75 |
+
in_h, in_w = img.shape[1], img.shape[2]
|
| 76 |
+
if in_h <= in_w:
|
| 77 |
+
rh, rw = shortest_edge, int(in_w * shortest_edge / in_h)
|
| 78 |
+
else:
|
| 79 |
+
rh, rw = int(in_h * shortest_edge / in_w), shortest_edge
|
| 80 |
+
r = tvF.resize(img.float(), [rh, rw], interpolation=mode, antialias=antialias)
|
| 81 |
+
r = tvF.center_crop(r, [crop, crop])
|
| 82 |
+
outs.append((r * RESCALE - mean) / std)
|
| 83 |
+
return torch.stack(outs)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@pytest.mark.kernels_ci
|
| 87 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
|
| 88 |
+
@pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
|
| 89 |
+
def test_shortest_edge_crop_parity(interp, antialias):
|
| 90 |
+
device = torch.device("cuda")
|
| 91 |
+
images = _ragged_images(8, device)
|
| 92 |
+
shortest_edge, crop = 256, 224
|
| 93 |
+
got = resize_normalize(
|
| 94 |
+
images, shortest_edge, MEAN, STD, RESCALE, resample=interp, antialias=antialias,
|
| 95 |
+
crop_size=(crop, crop), resize_mode="shortest_edge",
|
| 96 |
+
)
|
| 97 |
+
ref = _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias)
|
| 98 |
+
assert got.shape == (8, 3, crop, crop)
|
| 99 |
+
max_abs = (got - ref).abs().max().item()
|
| 100 |
+
assert max_abs < 3e-3, f"shortest_edge+crop {interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@pytest.mark.kernels_ci
|
| 104 |
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
|
| 105 |
+
def test_fused_matches_separable():
|
| 106 |
+
device = torch.device("cuda")
|
| 107 |
+
images = _ragged_images(6, device)
|
| 108 |
+
common = dict(size=(256, 256), image_mean=MEAN, image_std=STD, rescale_factor=RESCALE, resample="bicubic", antialias=True)
|
| 109 |
+
fused = resize_normalize(images, backend="fused", **common)
|
| 110 |
+
separable = resize_normalize(images, backend="separable", **common)
|
| 111 |
+
max_abs = (fused - separable).abs().max().item()
|
| 112 |
+
assert max_abs < 3e-3, f"fused vs separable: max|Δ|={max_abs:.2e}"
|
torch-ext/kernel_image_resize/__init__.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Resize + rescale + normalize for transformers fast image processors, as a Triton kernel.
|
| 2 |
+
|
| 3 |
+
resize -> fold(rescale, normalize) in one GPU pipeline: CHW uint8 images in,
|
| 4 |
+
(N, C, out_h, out_w) normalized float out, no full-resolution float intermediate.
|
| 5 |
+
|
| 6 |
+
- resize_normalize — stacked (N, C, H, W) tensor or a list of CHW images.
|
| 7 |
+
- resize_normalize_ragged — same kernels; takes a list of different-H/W CHW tensors.
|
| 8 |
+
|
| 9 |
+
backend="separable" (default): two-pass uint8, taps+taps. backend="fused": single 2D
|
| 10 |
+
launch, taps*taps. Both parity <=1e-4 vs torchvision-float.
|
| 11 |
+
|
| 12 |
+
from kernels import get_kernel
|
| 13 |
+
kir = get_kernel("Molbap/kernel_image_resize")
|
| 14 |
+
pixel_values = kir.resize_normalize(
|
| 15 |
+
images, size=384, image_mean=[...], image_std=[...], resample="bicubic", antialias=True,
|
| 16 |
+
)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from ._fused import fused_resize_normalize
|
| 20 |
+
from ._pack import PIL_RESAMPLE_TO_INTERP, as_image_list
|
| 21 |
+
from ._separable import separable_resize_crop_normalize, separable_resize_normalize
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _normalize_size(size) -> tuple[int, int]:
|
| 25 |
+
if isinstance(size, int):
|
| 26 |
+
return size, size
|
| 27 |
+
if isinstance(size, dict):
|
| 28 |
+
if "height" in size and "width" in size:
|
| 29 |
+
return int(size["height"]), int(size["width"])
|
| 30 |
+
raise ValueError(f"size dict must hold 'height'/'width' for a fixed resize, got {size}")
|
| 31 |
+
out_h, out_w = size
|
| 32 |
+
return int(out_h), int(out_w)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _normalize_resample(resample) -> str:
|
| 36 |
+
if isinstance(resample, str):
|
| 37 |
+
if resample not in ("bilinear", "bicubic"):
|
| 38 |
+
raise ValueError(f"resample must be 'bilinear' or 'bicubic', got {resample!r}")
|
| 39 |
+
return resample
|
| 40 |
+
interp = PIL_RESAMPLE_TO_INTERP.get(int(resample))
|
| 41 |
+
if interp is None:
|
| 42 |
+
raise ValueError(f"unsupported PIL resample code {resample}")
|
| 43 |
+
return interp
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def resize_normalize(
|
| 47 |
+
images,
|
| 48 |
+
size,
|
| 49 |
+
image_mean,
|
| 50 |
+
image_std,
|
| 51 |
+
rescale_factor: float = 1.0 / 255.0,
|
| 52 |
+
resample="bilinear",
|
| 53 |
+
antialias: bool = False,
|
| 54 |
+
backend: str = "separable",
|
| 55 |
+
block: int = 256,
|
| 56 |
+
crop_size=None,
|
| 57 |
+
resize_mode: str = "square",
|
| 58 |
+
):
|
| 59 |
+
"""Resize, optionally center-crop, rescale and normalize — one GPU pipeline.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
images: a stacked `(N, C, H, W)` uint8/float tensor, or a list of CHW tensors (ragged).
|
| 63 |
+
size: resize target. With no crop: int (square), `(height, width)`, or `{"height","width"}`.
|
| 64 |
+
With `resize_mode="shortest_edge"`: an int, the short side after aspect-preserving resize.
|
| 65 |
+
image_mean, image_std: per-channel normalization stats (length C).
|
| 66 |
+
rescale_factor: folded into mean/std so the kernel does `(x*rescale - mean)/std`.
|
| 67 |
+
resample: "bilinear" / "bicubic", or a PIL resample int (0/2 -> bilinear, 3 -> bicubic).
|
| 68 |
+
antialias: match the ViT/CLIP/SigLIP default (`True` for those processors).
|
| 69 |
+
backend: "separable" (default) or "fused" (2D reference, no crop support).
|
| 70 |
+
crop_size: `None` (no crop), int (square), or `(crop_h, crop_w)`. Center crop after resize.
|
| 71 |
+
resize_mode: "square" (resize to `size`) or "shortest_edge" (aspect-preserving, needs a crop).
|
| 72 |
+
"""
|
| 73 |
+
interp = _normalize_resample(resample)
|
| 74 |
+
image_list = as_image_list(images)
|
| 75 |
+
|
| 76 |
+
if crop_size is not None or resize_mode == "shortest_edge":
|
| 77 |
+
crop_h, crop_w = _normalize_size(crop_size if crop_size is not None else size)
|
| 78 |
+
resize_arg = int(size) if resize_mode == "shortest_edge" else _normalize_size(size)
|
| 79 |
+
return separable_resize_crop_normalize(
|
| 80 |
+
image_list, resize_arg, (crop_h, crop_w), image_mean, image_std, rescale_factor,
|
| 81 |
+
interp, antialias, resize_mode, block,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
out_h, out_w = _normalize_size(size)
|
| 85 |
+
if backend == "fused":
|
| 86 |
+
return fused_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
|
| 87 |
+
if backend == "separable":
|
| 88 |
+
return separable_resize_normalize(image_list, out_h, out_w, image_mean, image_std, rescale_factor, interp, antialias, block)
|
| 89 |
+
raise ValueError(f"backend must be 'fused' or 'separable', got {backend!r}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def resize_normalize_ragged(
|
| 93 |
+
images,
|
| 94 |
+
size,
|
| 95 |
+
image_mean,
|
| 96 |
+
image_std,
|
| 97 |
+
rescale_factor: float = 1.0 / 255.0,
|
| 98 |
+
resample="bilinear",
|
| 99 |
+
antialias: bool = False,
|
| 100 |
+
backend: str = "separable",
|
| 101 |
+
block: int = 256,
|
| 102 |
+
):
|
| 103 |
+
"""Variant taking a list of different-H/W CHW tensors. Same kernels as `resize_normalize`."""
|
| 104 |
+
if isinstance(images, list):
|
| 105 |
+
image_list = images
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError("resize_normalize_ragged expects a list of CHW tensors; use resize_normalize for a stacked tensor")
|
| 108 |
+
return resize_normalize(
|
| 109 |
+
image_list, size, image_mean, image_std, rescale_factor, resample, antialias, backend, block
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
__all__ = ["resize_normalize", "resize_normalize_ragged"]
|
torch-ext/kernel_image_resize/_fused.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fused 2D resize+rescale+normalize over a ragged batch, single launch.
|
| 2 |
+
|
| 3 |
+
One program owns one image and a BLOCK of its output pixels, gathers a
|
| 4 |
+
MAX_TAPS_H × MAX_TAPS_W window, applies the separable weights as a 2D product, then folds
|
| 5 |
+
rescale+normalize. taps×taps loads per output pixel.
|
| 6 |
+
|
| 7 |
+
Resampling-weight formula (PyTorch aten UpSampleKernel):
|
| 8 |
+
scale = in / out
|
| 9 |
+
support = interp_half * (scale if antialias and scale > 1 else 1) # interp_half: 1 linear, 2 cubic
|
| 10 |
+
center = scale * (i + 0.5)
|
| 11 |
+
weight = filter((tap - center + 0.5) / eff), renormalized over the realized window
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import triton
|
| 15 |
+
import triton.language as tl
|
| 16 |
+
|
| 17 |
+
from ._pack import fold_mean_std, max_taps, pack_images
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@triton.jit
|
| 21 |
+
def _resample_weight(arg, cubic_a, CUBIC: tl.constexpr):
|
| 22 |
+
"""Interpolation filter at `arg` (coordinate distance already divided by support)."""
|
| 23 |
+
ax = tl.abs(arg)
|
| 24 |
+
if CUBIC: # Keys cubic convolution kernel, support 2
|
| 25 |
+
ax2 = ax * ax
|
| 26 |
+
ax3 = ax2 * ax
|
| 27 |
+
inner = (cubic_a + 2.0) * ax3 - (cubic_a + 3.0) * ax2 + 1.0 # |x| <= 1
|
| 28 |
+
outer = cubic_a * ax3 - 5.0 * cubic_a * ax2 + 8.0 * cubic_a * ax - 4.0 * cubic_a # 1 < |x| < 2
|
| 29 |
+
return tl.where(ax <= 1.0, inner, tl.where(ax < 2.0, outer, 0.0))
|
| 30 |
+
return tl.maximum(1.0 - ax, 0.0) # triangle (bilinear), support 1
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@triton.jit
|
| 34 |
+
def _resize_normalize_kernel(
|
| 35 |
+
in_ptr, out_ptr, offsets_ptr, heights_ptr, widths_ptr, mean_ptr, std_ptr,
|
| 36 |
+
out_h, out_w, cubic_a,
|
| 37 |
+
C: tl.constexpr, BLOCK: tl.constexpr,
|
| 38 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 39 |
+
MAX_TAPS_H: tl.constexpr, MAX_TAPS_W: tl.constexpr,
|
| 40 |
+
):
|
| 41 |
+
n = tl.program_id(0)
|
| 42 |
+
blk = tl.program_id(1)
|
| 43 |
+
H = tl.load(heights_ptr + n)
|
| 44 |
+
W = tl.load(widths_ptr + n)
|
| 45 |
+
off = tl.load(offsets_ptr + n)
|
| 46 |
+
Hf = H.to(tl.float32)
|
| 47 |
+
Wf = W.to(tl.float32)
|
| 48 |
+
|
| 49 |
+
npix = out_h * out_w
|
| 50 |
+
pos = blk * BLOCK + tl.arange(0, BLOCK)
|
| 51 |
+
mask = pos < npix
|
| 52 |
+
oy = pos // out_w
|
| 53 |
+
ox = pos % out_w
|
| 54 |
+
|
| 55 |
+
interp_half = 2.0 if CUBIC else 1.0
|
| 56 |
+
scale_h = Hf / out_h
|
| 57 |
+
scale_w = Wf / out_w
|
| 58 |
+
eff_h = tl.maximum(scale_h, 1.0) if ANTIALIAS else 1.0
|
| 59 |
+
eff_w = tl.maximum(scale_w, 1.0) if ANTIALIAS else 1.0
|
| 60 |
+
support_h = interp_half * eff_h
|
| 61 |
+
support_w = interp_half * eff_w
|
| 62 |
+
inv_h = 1.0 / eff_h
|
| 63 |
+
inv_w = 1.0 / eff_w
|
| 64 |
+
|
| 65 |
+
center_y = scale_h * (oy.to(tl.float32) + 0.5)
|
| 66 |
+
center_x = scale_w * (ox.to(tl.float32) + 0.5)
|
| 67 |
+
ystart = tl.floor(center_y - support_h + 0.5)
|
| 68 |
+
xstart = tl.floor(center_x - support_w + 0.5)
|
| 69 |
+
|
| 70 |
+
sum_wy = tl.zeros([BLOCK], dtype=tl.float32)
|
| 71 |
+
for ty in tl.static_range(MAX_TAPS_H):
|
| 72 |
+
yy = ystart + ty
|
| 73 |
+
wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
|
| 74 |
+
if ANTIALIAS:
|
| 75 |
+
wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
|
| 76 |
+
sum_wy += wy
|
| 77 |
+
sum_wx = tl.zeros([BLOCK], dtype=tl.float32)
|
| 78 |
+
for tx in tl.static_range(MAX_TAPS_W):
|
| 79 |
+
xx = xstart + tx
|
| 80 |
+
wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
|
| 81 |
+
if ANTIALIAS:
|
| 82 |
+
wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
|
| 83 |
+
sum_wx += wx
|
| 84 |
+
denom = sum_wy * sum_wx
|
| 85 |
+
|
| 86 |
+
plane = (H * W).to(tl.int64)
|
| 87 |
+
Wl = W.to(tl.int64)
|
| 88 |
+
for c in tl.static_range(C):
|
| 89 |
+
base = off + c * plane
|
| 90 |
+
acc = tl.zeros([BLOCK], dtype=tl.float32)
|
| 91 |
+
for ty in tl.static_range(MAX_TAPS_H):
|
| 92 |
+
yy = ystart + ty
|
| 93 |
+
wy = _resample_weight((yy - center_y + 0.5) * inv_h, cubic_a, CUBIC)
|
| 94 |
+
if ANTIALIAS:
|
| 95 |
+
wy = tl.where((yy >= 0.0) & (yy < Hf), wy, 0.0)
|
| 96 |
+
yidx = tl.minimum(tl.maximum(yy.to(tl.int32), 0), H - 1).to(tl.int64)
|
| 97 |
+
row = base + yidx * Wl
|
| 98 |
+
for tx in tl.static_range(MAX_TAPS_W):
|
| 99 |
+
xx = xstart + tx
|
| 100 |
+
wx = _resample_weight((xx - center_x + 0.5) * inv_w, cubic_a, CUBIC)
|
| 101 |
+
if ANTIALIAS:
|
| 102 |
+
wx = tl.where((xx >= 0.0) & (xx < Wf), wx, 0.0)
|
| 103 |
+
xidx = tl.minimum(tl.maximum(xx.to(tl.int32), 0), W - 1).to(tl.int64)
|
| 104 |
+
pix = tl.load(in_ptr + row + xidx, mask=mask, other=0.0)
|
| 105 |
+
acc += wy * wx * pix
|
| 106 |
+
acc = acc / denom
|
| 107 |
+
m = tl.load(mean_ptr + c)
|
| 108 |
+
s = tl.load(std_ptr + c)
|
| 109 |
+
acc = (acc - m) / s
|
| 110 |
+
oidx = ((n * C + c) * out_h + oy) * out_w + ox
|
| 111 |
+
tl.store(out_ptr + oidx, acc, mask=mask)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def fused_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
|
| 115 |
+
"""Single fused launch over a ragged packed buffer -> (N, C, out_h, out_w) normalized float."""
|
| 116 |
+
import torch
|
| 117 |
+
|
| 118 |
+
images = list(images)
|
| 119 |
+
device = images[0].device
|
| 120 |
+
n = len(images)
|
| 121 |
+
cubic_a = -0.5 if antialias else -0.75 # PIL coeff under antialias, Keys coeff otherwise
|
| 122 |
+
max_taps_h = max_taps(images, out_h, 1, interp, antialias)
|
| 123 |
+
max_taps_w = max_taps(images, out_w, 2, interp, antialias)
|
| 124 |
+
mean_t, std_t = fold_mean_std(mean, std, rescale, device)
|
| 125 |
+
|
| 126 |
+
in_buf, offsets_t, heights_t, widths_t, c = pack_images(images)
|
| 127 |
+
out = torch.empty((n, c, out_h, out_w), device=device, dtype=torch.float32)
|
| 128 |
+
grid = (n, triton.cdiv(out_h * out_w, block))
|
| 129 |
+
_resize_normalize_kernel[grid](
|
| 130 |
+
in_buf, out, offsets_t, heights_t, widths_t, mean_t, std_t,
|
| 131 |
+
out_h, out_w, cubic_a, C=c, BLOCK=block,
|
| 132 |
+
CUBIC=(interp == "bicubic"), ANTIALIAS=antialias, MAX_TAPS_H=max_taps_h, MAX_TAPS_W=max_taps_w,
|
| 133 |
+
)
|
| 134 |
+
return out
|
torch-ext/kernel_image_resize/_pack.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ragged packing + resampling helpers shared by the fused and separable backends."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
PIL_RESAMPLE_TO_INTERP = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pack_images(
|
| 12 |
+
images: list[torch.Tensor], dtype: torch.dtype = torch.float32
|
| 13 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
| 14 |
+
"""Concatenate a ragged list of CHW images into one flat buffer of `dtype`.
|
| 15 |
+
|
| 16 |
+
Returns (in_buf, offsets, heights, widths, channels); offsets[n] is the element index
|
| 17 |
+
where image n starts.
|
| 18 |
+
"""
|
| 19 |
+
device = images[0].device
|
| 20 |
+
channels = images[0].shape[0]
|
| 21 |
+
flats, offsets, heights, widths, cur = [], [], [], [], 0
|
| 22 |
+
for img in images:
|
| 23 |
+
ic, ih, iw = img.shape
|
| 24 |
+
if ic != channels:
|
| 25 |
+
raise ValueError(f"all images must share channel count {channels}, got {ic}")
|
| 26 |
+
flats.append(img.reshape(-1).to(dtype))
|
| 27 |
+
offsets.append(cur)
|
| 28 |
+
heights.append(ih)
|
| 29 |
+
widths.append(iw)
|
| 30 |
+
cur += ic * ih * iw
|
| 31 |
+
in_buf = torch.cat(flats)
|
| 32 |
+
offsets_t = torch.tensor(offsets, device=device, dtype=torch.int64)
|
| 33 |
+
heights_t = torch.tensor(heights, device=device, dtype=torch.int32)
|
| 34 |
+
widths_t = torch.tensor(widths, device=device, dtype=torch.int32)
|
| 35 |
+
return in_buf, offsets_t, heights_t, widths_t, channels
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def fold_mean_std(mean, std, rescale: float, device) -> tuple[torch.Tensor, torch.Tensor]:
|
| 39 |
+
"""Fold rescale into mean/std so the kernel does (x - m)/s == (x*rescale - mean)/std."""
|
| 40 |
+
mean_t = (torch.tensor(mean, device=device, dtype=torch.float32) / rescale).contiguous()
|
| 41 |
+
std_t = (torch.tensor(std, device=device, dtype=torch.float32) / rescale).contiguous()
|
| 42 |
+
return mean_t, std_t
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def max_taps(images: list[torch.Tensor], out_size: int, axis_dim: int, interp: str, antialias: bool) -> int:
|
| 46 |
+
"""Batch-wide worst-case tap count for one axis = ceil(support) * 2 + 1."""
|
| 47 |
+
interp_half = 2.0 if interp == "bicubic" else 1.0
|
| 48 |
+
worst = 0
|
| 49 |
+
for img in images:
|
| 50 |
+
scale = img.shape[axis_dim] / out_size
|
| 51 |
+
eff = max(scale, 1.0) if antialias else 1.0
|
| 52 |
+
worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
|
| 53 |
+
return worst
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def as_image_list(images) -> list[torch.Tensor]:
|
| 57 |
+
"""Accept a stacked (N, C, H, W) tensor or a list of CHW tensors; always return a list."""
|
| 58 |
+
if isinstance(images, torch.Tensor):
|
| 59 |
+
if images.dim() != 4:
|
| 60 |
+
raise ValueError(f"stacked input must be (N, C, H, W), got shape {tuple(images.shape)}")
|
| 61 |
+
return list(images)
|
| 62 |
+
return list(images)
|
torch-ext/kernel_image_resize/_separable.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Separable resize + center-crop + normalize over a ragged uint8 batch.
|
| 2 |
+
|
| 3 |
+
WHAT "RESIZE" DOES, CONCRETELY
|
| 4 |
+
Every output pixel is a weighted average of a small window of input pixels. When you shrink
|
| 5 |
+
an image a lot (with antialiasing) that window gets wide — e.g. 13 input pixels across and
|
| 6 |
+
13 down, so 13x13 = 169 input pixels feed one output pixel.
|
| 7 |
+
|
| 8 |
+
FUSED vs SEPARABLE (the two backends in this package)
|
| 9 |
+
- FUSED (see _fused.py): for each output pixel, read the whole 2D window directly -> 169 reads.
|
| 10 |
+
- SEPARABLE (this file): do the resize as two 1D steps instead of one 2D step:
|
| 11 |
+
step 1 (horizontal): resize only the WIDTH -> an intermediate image
|
| 12 |
+
step 2 (vertical): resize only the HEIGHT -> the final image
|
| 13 |
+
Each step's window is 1D, so 13 + 13 = 26 reads per output pixel instead of 169. Same math,
|
| 14 |
+
far fewer reads. This is what PIL and torchvision do.
|
| 15 |
+
|
| 16 |
+
CENTER CROP (folded in, no extra pass)
|
| 17 |
+
Processors like CLIP / DINOv2 resize to a "resize size" and then keep only the centered
|
| 18 |
+
crop. We do not materialize the full resized image and slice it; instead each output pixel
|
| 19 |
+
of the CROP maps to a resize-image coordinate by adding the crop offset, and that maps back
|
| 20 |
+
to the input. So:
|
| 21 |
+
resize is described by (resize_height, resize_width) -- per image
|
| 22 |
+
crop is described by (crop_top, crop_left) -- per image, the centered offset
|
| 23 |
+
output size is (crop_height, crop_width) -- the same for every image
|
| 24 |
+
When there is no crop, resize size == crop size and the offsets are 0 (the plain resize).
|
| 25 |
+
The resize SCALE uses the resize size; only the output coordinate is shifted by the crop.
|
| 26 |
+
|
| 27 |
+
uint8 input + float intermediate; each 1D step renormalizes its own weights (matches
|
| 28 |
+
torchvision). Output is parity-close to torchvision, not bit-identical (torchvision keeps a
|
| 29 |
+
fixed-point uint8 intermediate; ours is more accurate float).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
import triton
|
| 33 |
+
import triton.language as tl
|
| 34 |
+
|
| 35 |
+
from ._fused import _resample_weight
|
| 36 |
+
from ._pack import fold_mean_std, pack_images
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@triton.jit
|
| 40 |
+
def _horizontal_resize_kernel(
|
| 41 |
+
input_pixels, # flat uint8 buffer, all images packed back to back
|
| 42 |
+
intermediate, # flat float32 output: width resized + col-cropped, height untouched
|
| 43 |
+
input_offsets, # input_offsets[image] = where that image starts in input_pixels
|
| 44 |
+
intermediate_offsets, # same idea for the intermediate buffer
|
| 45 |
+
heights, widths, # per-image input height / width
|
| 46 |
+
resize_widths, # per-image width to resize to (before cropping)
|
| 47 |
+
crop_lefts, # per-image left offset of the centered crop
|
| 48 |
+
crop_w, # output (crop) width, same for every image
|
| 49 |
+
cubic_coeff,
|
| 50 |
+
CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
|
| 51 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 52 |
+
MAX_TAPS_COL: tl.constexpr,
|
| 53 |
+
):
|
| 54 |
+
"""Resize width to resize_width, keep only the cropped columns: uint8 (C,H,W) -> float (C,H,crop_w)."""
|
| 55 |
+
image_index = tl.program_id(0)
|
| 56 |
+
block_index = tl.program_id(1)
|
| 57 |
+
in_height = tl.load(heights + image_index)
|
| 58 |
+
in_width = tl.load(widths + image_index)
|
| 59 |
+
resize_width = tl.load(resize_widths + image_index)
|
| 60 |
+
crop_left = tl.load(crop_lefts + image_index)
|
| 61 |
+
input_start = tl.load(input_offsets + image_index)
|
| 62 |
+
intermediate_start = tl.load(intermediate_offsets + image_index)
|
| 63 |
+
in_width_f = in_width.to(tl.float32)
|
| 64 |
+
|
| 65 |
+
num_pixels = in_height * crop_w # every input row x every cropped output column
|
| 66 |
+
flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
|
| 67 |
+
active = flat_index < num_pixels
|
| 68 |
+
input_row = flat_index // crop_w
|
| 69 |
+
out_col = flat_index % crop_w
|
| 70 |
+
resize_col = out_col + crop_left # column in the (uncropped) resized image
|
| 71 |
+
|
| 72 |
+
filter_half = 2.0 if CUBIC else 1.0
|
| 73 |
+
col_scale = in_width_f / resize_width.to(tl.float32)
|
| 74 |
+
col_filter_scale = tl.maximum(col_scale, 1.0) if ANTIALIAS else 1.0
|
| 75 |
+
col_support = filter_half * col_filter_scale
|
| 76 |
+
col_inv_scale = 1.0 / col_filter_scale
|
| 77 |
+
src_center_col = col_scale * (resize_col.to(tl.float32) + 0.5)
|
| 78 |
+
first_tap_col = tl.floor(src_center_col - col_support + 0.5)
|
| 79 |
+
|
| 80 |
+
col_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
|
| 81 |
+
for tap in tl.static_range(MAX_TAPS_COL):
|
| 82 |
+
tap_col = first_tap_col + tap
|
| 83 |
+
weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
|
| 84 |
+
if ANTIALIAS:
|
| 85 |
+
weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
|
| 86 |
+
col_weight_sum += weight
|
| 87 |
+
|
| 88 |
+
input_plane = (in_height * in_width).to(tl.int64)
|
| 89 |
+
intermediate_plane = (in_height * crop_w).to(tl.int64)
|
| 90 |
+
in_width_i64 = in_width.to(tl.int64)
|
| 91 |
+
crop_w_i64 = crop_w.to(tl.int64)
|
| 92 |
+
input_row_i64 = input_row.to(tl.int64)
|
| 93 |
+
for channel in tl.static_range(CHANNELS):
|
| 94 |
+
input_row_base = input_start + channel * input_plane + input_row_i64 * in_width_i64
|
| 95 |
+
accumulator = tl.zeros([BLOCK], dtype=tl.float32)
|
| 96 |
+
for tap in tl.static_range(MAX_TAPS_COL):
|
| 97 |
+
tap_col = first_tap_col + tap
|
| 98 |
+
weight = _resample_weight((tap_col - src_center_col + 0.5) * col_inv_scale, cubic_coeff, CUBIC)
|
| 99 |
+
if ANTIALIAS:
|
| 100 |
+
weight = tl.where((tap_col >= 0.0) & (tap_col < in_width_f), weight, 0.0)
|
| 101 |
+
clamped_tap_col = tl.minimum(tl.maximum(tap_col.to(tl.int32), 0), in_width - 1).to(tl.int64)
|
| 102 |
+
pixel = tl.load(input_pixels + input_row_base + clamped_tap_col, mask=active, other=0).to(tl.float32)
|
| 103 |
+
accumulator += weight * pixel
|
| 104 |
+
accumulator = accumulator / col_weight_sum
|
| 105 |
+
write_index = intermediate_start + channel * intermediate_plane + input_row_i64 * crop_w_i64 + out_col
|
| 106 |
+
tl.store(intermediate + write_index, accumulator, mask=active)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@triton.jit
|
| 110 |
+
def _vertical_resize_normalize_kernel(
|
| 111 |
+
intermediate, # float32 from the horizontal step: (C, H, crop_w) per image
|
| 112 |
+
output, # final (N, C, crop_h, crop_w) float32
|
| 113 |
+
intermediate_offsets,
|
| 114 |
+
heights, # per-image input height (the intermediate still has H rows)
|
| 115 |
+
resize_heights, # per-image height to resize to (before cropping)
|
| 116 |
+
crop_tops, # per-image top offset of the centered crop
|
| 117 |
+
means, stds, # per-channel normalization, rescale already folded in
|
| 118 |
+
crop_h, crop_w,
|
| 119 |
+
cubic_coeff,
|
| 120 |
+
CHANNELS: tl.constexpr, BLOCK: tl.constexpr,
|
| 121 |
+
CUBIC: tl.constexpr, ANTIALIAS: tl.constexpr,
|
| 122 |
+
MAX_TAPS_ROW: tl.constexpr,
|
| 123 |
+
):
|
| 124 |
+
"""Resize height to resize_height, keep cropped rows, normalize: float (C,H,crop_w) -> (C,crop_h,crop_w)."""
|
| 125 |
+
image_index = tl.program_id(0)
|
| 126 |
+
block_index = tl.program_id(1)
|
| 127 |
+
in_height = tl.load(heights + image_index)
|
| 128 |
+
resize_height = tl.load(resize_heights + image_index)
|
| 129 |
+
crop_top = tl.load(crop_tops + image_index)
|
| 130 |
+
intermediate_start = tl.load(intermediate_offsets + image_index)
|
| 131 |
+
in_height_f = in_height.to(tl.float32)
|
| 132 |
+
|
| 133 |
+
num_pixels = crop_h * crop_w
|
| 134 |
+
flat_index = block_index * BLOCK + tl.arange(0, BLOCK)
|
| 135 |
+
active = flat_index < num_pixels
|
| 136 |
+
out_row = flat_index // crop_w
|
| 137 |
+
out_col = flat_index % crop_w
|
| 138 |
+
resize_row = out_row + crop_top # row in the (uncropped) resized image
|
| 139 |
+
|
| 140 |
+
filter_half = 2.0 if CUBIC else 1.0
|
| 141 |
+
row_scale = in_height_f / resize_height.to(tl.float32)
|
| 142 |
+
row_filter_scale = tl.maximum(row_scale, 1.0) if ANTIALIAS else 1.0
|
| 143 |
+
row_support = filter_half * row_filter_scale
|
| 144 |
+
row_inv_scale = 1.0 / row_filter_scale
|
| 145 |
+
src_center_row = row_scale * (resize_row.to(tl.float32) + 0.5)
|
| 146 |
+
first_tap_row = tl.floor(src_center_row - row_support + 0.5)
|
| 147 |
+
|
| 148 |
+
row_weight_sum = tl.zeros([BLOCK], dtype=tl.float32)
|
| 149 |
+
for tap in tl.static_range(MAX_TAPS_ROW):
|
| 150 |
+
tap_row = first_tap_row + tap
|
| 151 |
+
weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
|
| 152 |
+
if ANTIALIAS:
|
| 153 |
+
weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
|
| 154 |
+
row_weight_sum += weight
|
| 155 |
+
|
| 156 |
+
intermediate_plane = (in_height * crop_w).to(tl.int64)
|
| 157 |
+
crop_w_i64 = crop_w.to(tl.int64)
|
| 158 |
+
out_col_i64 = out_col.to(tl.int64)
|
| 159 |
+
for channel in tl.static_range(CHANNELS):
|
| 160 |
+
channel_base = intermediate_start + channel * intermediate_plane
|
| 161 |
+
accumulator = tl.zeros([BLOCK], dtype=tl.float32)
|
| 162 |
+
for tap in tl.static_range(MAX_TAPS_ROW):
|
| 163 |
+
tap_row = first_tap_row + tap
|
| 164 |
+
weight = _resample_weight((tap_row - src_center_row + 0.5) * row_inv_scale, cubic_coeff, CUBIC)
|
| 165 |
+
if ANTIALIAS:
|
| 166 |
+
weight = tl.where((tap_row >= 0.0) & (tap_row < in_height_f), weight, 0.0)
|
| 167 |
+
clamped_tap_row = tl.minimum(tl.maximum(tap_row.to(tl.int32), 0), in_height - 1).to(tl.int64)
|
| 168 |
+
pixel = tl.load(intermediate + channel_base + clamped_tap_row * crop_w_i64 + out_col_i64, mask=active, other=0.0)
|
| 169 |
+
accumulator += weight * pixel
|
| 170 |
+
accumulator = accumulator / row_weight_sum
|
| 171 |
+
mean = tl.load(means + channel)
|
| 172 |
+
std = tl.load(stds + channel)
|
| 173 |
+
accumulator = (accumulator - mean) / std
|
| 174 |
+
write_index = ((image_index * CHANNELS + channel) * crop_h + out_row) * crop_w + out_col
|
| 175 |
+
tl.store(output + write_index, accumulator, mask=active)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _axis_max_taps(in_sizes, resize_sizes, interp, antialias):
|
| 179 |
+
"""Widest 1D window over the batch for one axis = ceil(support) * 2 + 1, support uses in/resize."""
|
| 180 |
+
import math
|
| 181 |
+
|
| 182 |
+
interp_half = 2.0 if interp == "bicubic" else 1.0
|
| 183 |
+
worst = 0
|
| 184 |
+
for in_size, resize_size in zip(in_sizes, resize_sizes):
|
| 185 |
+
scale = in_size / resize_size
|
| 186 |
+
eff = max(scale, 1.0) if antialias else 1.0
|
| 187 |
+
worst = max(worst, math.ceil(interp_half * eff) * 2 + 1)
|
| 188 |
+
return worst
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
|
| 192 |
+
mean, std, rescale, interp, antialias, block):
|
| 193 |
+
"""Core driver: resize each image to its (resize_h, resize_w), keep the centered crop, normalize."""
|
| 194 |
+
import torch
|
| 195 |
+
|
| 196 |
+
device = images[0].device
|
| 197 |
+
num_images = len(images)
|
| 198 |
+
cubic_coeff = -0.5 if antialias else -0.75
|
| 199 |
+
in_heights = [int(img.shape[1]) for img in images]
|
| 200 |
+
in_widths = [int(img.shape[2]) for img in images]
|
| 201 |
+
max_taps_row = _axis_max_taps(in_heights, resize_heights, interp, antialias)
|
| 202 |
+
max_taps_col = _axis_max_taps(in_widths, resize_widths, interp, antialias)
|
| 203 |
+
means, stds = fold_mean_std(mean, std, rescale, device)
|
| 204 |
+
|
| 205 |
+
input_pixels, input_offsets, heights, widths, channels = pack_images(images, dtype=torch.uint8)
|
| 206 |
+
|
| 207 |
+
intermediate_offsets_list, cursor, tallest = [], 0, 0
|
| 208 |
+
for height in in_heights:
|
| 209 |
+
intermediate_offsets_list.append(cursor)
|
| 210 |
+
cursor += channels * height * crop_w
|
| 211 |
+
tallest = max(tallest, height)
|
| 212 |
+
intermediate = torch.empty(cursor, device=device, dtype=torch.float32)
|
| 213 |
+
|
| 214 |
+
intermediate_offsets = torch.tensor(intermediate_offsets_list, device=device, dtype=torch.int64)
|
| 215 |
+
resize_heights_t = torch.tensor(resize_heights, device=device, dtype=torch.int32)
|
| 216 |
+
resize_widths_t = torch.tensor(resize_widths, device=device, dtype=torch.int32)
|
| 217 |
+
crop_tops_t = torch.tensor(crop_tops, device=device, dtype=torch.int32)
|
| 218 |
+
crop_lefts_t = torch.tensor(crop_lefts, device=device, dtype=torch.int32)
|
| 219 |
+
|
| 220 |
+
horizontal_grid = (num_images, triton.cdiv(tallest * crop_w, block))
|
| 221 |
+
_horizontal_resize_kernel[horizontal_grid](
|
| 222 |
+
input_pixels, intermediate, input_offsets, intermediate_offsets, heights, widths,
|
| 223 |
+
resize_widths_t, crop_lefts_t, crop_w, cubic_coeff,
|
| 224 |
+
CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
|
| 225 |
+
MAX_TAPS_COL=max_taps_col,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
output = torch.empty((num_images, channels, crop_h, crop_w), device=device, dtype=torch.float32)
|
| 229 |
+
vertical_grid = (num_images, triton.cdiv(crop_h * crop_w, block))
|
| 230 |
+
_vertical_resize_normalize_kernel[vertical_grid](
|
| 231 |
+
intermediate, output, intermediate_offsets, heights, resize_heights_t, crop_tops_t, means, stds,
|
| 232 |
+
crop_h, crop_w, cubic_coeff,
|
| 233 |
+
CHANNELS=channels, BLOCK=block, CUBIC=(interp == "bicubic"), ANTIALIAS=antialias,
|
| 234 |
+
MAX_TAPS_ROW=max_taps_row,
|
| 235 |
+
)
|
| 236 |
+
return output
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _aspect_preserving_size(in_h, in_w, shortest_edge):
|
| 240 |
+
"""transformers shortest-edge rule: short side -> shortest_edge, long side truncated (int(), not round)."""
|
| 241 |
+
if in_h <= in_w:
|
| 242 |
+
return shortest_edge, int(in_w * shortest_edge / in_h)
|
| 243 |
+
return int(in_h * shortest_edge / in_w), shortest_edge
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def separable_resize_normalize(images, out_h, out_w, mean, std, rescale, interp, antialias, block: int = 256):
|
| 247 |
+
"""Resize to (out_h, out_w) and normalize (no crop)."""
|
| 248 |
+
images = list(images)
|
| 249 |
+
n = len(images)
|
| 250 |
+
return _run_separable(images, [out_h] * n, [out_w] * n, [0] * n, [0] * n, out_h, out_w,
|
| 251 |
+
mean, std, rescale, interp, antialias, block)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def separable_resize_crop_normalize(images, resize_size, crop_size, mean, std, rescale, interp, antialias,
|
| 255 |
+
resize_mode="square", block: int = 256):
|
| 256 |
+
"""Resize then center-crop then normalize.
|
| 257 |
+
|
| 258 |
+
resize_mode="square": resize_size is (resize_h, resize_w) applied to every image.
|
| 259 |
+
resize_mode="shortest_edge": resize_size is an int; each image is resized aspect-preserving
|
| 260 |
+
so its short side equals it, then center-cropped to crop_size.
|
| 261 |
+
"""
|
| 262 |
+
images = list(images)
|
| 263 |
+
crop_h, crop_w = crop_size
|
| 264 |
+
resize_heights, resize_widths = [], []
|
| 265 |
+
for img in images:
|
| 266 |
+
in_h, in_w = int(img.shape[1]), int(img.shape[2])
|
| 267 |
+
if resize_mode == "shortest_edge":
|
| 268 |
+
rh, rw = _aspect_preserving_size(in_h, in_w, int(resize_size))
|
| 269 |
+
elif resize_mode == "square":
|
| 270 |
+
rh, rw = int(resize_size[0]), int(resize_size[1])
|
| 271 |
+
else:
|
| 272 |
+
raise ValueError(f"resize_mode must be 'square' or 'shortest_edge', got {resize_mode!r}")
|
| 273 |
+
if rh < crop_h or rw < crop_w:
|
| 274 |
+
raise ValueError(f"resize size ({rh},{rw}) smaller than crop ({crop_h},{crop_w})")
|
| 275 |
+
resize_heights.append(rh)
|
| 276 |
+
resize_widths.append(rw)
|
| 277 |
+
crop_tops = [(rh - crop_h) // 2 for rh in resize_heights]
|
| 278 |
+
crop_lefts = [(rw - crop_w) // 2 for rw in resize_widths]
|
| 279 |
+
return _run_separable(images, resize_heights, resize_widths, crop_tops, crop_lefts, crop_h, crop_w,
|
| 280 |
+
mean, std, rescale, interp, antialias, block)
|