File size: 3,924 Bytes
e199518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# /// script
# requires-python = ">=3.10"
# dependencies = ["torch", "triton", "kernels", "transformers", "torchvision"]
# ///
"""Drop-in: use the kernel as the resize+normalize stage of a transformers fast processor.

There is no `use_kernels=True` hook for image processors (that machinery swaps nn.Module
layer forwards inside the model, not processor code). So the usable path is to read the
processor's config and call the kernel directly. `preprocess_with_kernel` below is the whole
adapter — copy it into your code.

Run on a CUDA box:
    python example_transformers.py
"""

import torch
from kernels import get_kernel
from transformers import AutoImageProcessor, AutoModel


# Kernel module obtained through `kernels.get_kernel` at import time; `preprocess_with_kernel`
# calls its `resize_normalize` entry point.
# NOTE(review): `revision="main"` is a moving ref and `trust_remote_code=True` is set —
# consider pinning to a fixed revision; confirm against the `kernels` documentation.
kernel_image_resize = get_kernel("Molbap/kernel_image_resize", revision="main", trust_remote_code=True)

# Integer resample code (read from `processor.resample`) -> kernel `resample` mode name.
# In Pillow's numbering 0 is NEAREST, 2 is BILINEAR and 3 is BICUBIC.
# NOTE(review): code 0 (NEAREST) is mapped to "bilinear" — presumably a deliberate fallback
# because the kernel has no nearest mode; confirm. Codes not listed here are unsupported.
_PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}


def preprocess_with_kernel(processor, images):
    """Run the kernel using `processor`'s own config; returns pixel_values like processor(images).

    Reads the resize, crop, rescale and normalization settings off `processor` and forwards them
    to `kernel_image_resize.resize_normalize`. Three configurations are dispatched:

    * `size` contains "shortest_edge" and center-crop is enabled: the kernel is called with
      `resize_mode="shortest_edge"` and the crop size (CLIP / DINOv2 style).
    * `size` has "height"/"width" and the crop size differs from it: the kernel is called with
      `resize_mode="square"` and the crop size.
    * otherwise: plain fixed-size resize to `(height, width)`.

    Does not handle padding processors.

    Args:
        processor: image processor exposing `size`, `resample`, `image_mean` and `image_std`,
            plus `rescale_factor` / `crop_size` when the matching `do_*` flag is on.
        images: forwarded to the kernel unchanged.

    Returns:
        The result of `kernel_image_resize.resize_normalize`.

    Raises:
        ValueError: if the processor pads, does not normalize, flips channel order, uses a
            resample filter that is not in `_PIL_RESAMPLE`, or does a shortest-edge resize
            without a center crop.
    """
    size = processor.size
    if getattr(processor, "do_pad", False):
        raise ValueError("kernel does not pad; this processor needs a pad step")
    if not getattr(processor, "do_normalize", True):
        raise ValueError("processor does not normalize (rescale only); kernel always normalizes")
    if getattr(processor, "do_flip_channel_order", False):
        raise ValueError("processor flips channels to BGR; kernel keeps RGB")

    # Report an unsupported filter the same way as the other unsupported configurations,
    # instead of leaking a bare `KeyError: <int>` from the lookup table.
    resample = _PIL_RESAMPLE.get(int(processor.resample))
    if resample is None:
        raise ValueError(
            f"unsupported resample filter {processor.resample!r}; "
            f"kernel supports resample codes {sorted(_PIL_RESAMPLE)}"
        )

    antialias = bool(getattr(processor, "antialias", True))
    # A factor of 1.0 leaves pixel values unscaled when the processor does not rescale.
    rescale = float(processor.rescale_factor) if getattr(processor, "do_rescale", True) else 1.0
    mean, std = processor.image_mean, processor.image_std
    crop = processor.crop_size if getattr(processor, "do_center_crop", False) else None
    common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)

    if "shortest_edge" in size:
        if crop is None:
            raise ValueError("shortest-edge resize without a crop gives variable-size output")
        return kernel_image_resize.resize_normalize(
            images, size["shortest_edge"], mean, std,
            crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common,
        )
    # A crop equal to the resize target would be a no-op, so only take the crop path when the
    # two sizes actually differ.
    if crop is not None and (crop["height"] != size["height"] or crop["width"] != size["width"]):
        return kernel_image_resize.resize_normalize(
            images, (size["height"], size["width"]), mean, std,
            crop_size=(crop["height"], crop["width"]), resize_mode="square", **common,
        )
    return kernel_image_resize.resize_normalize(images, (size["height"], size["width"]), mean, std, **common)


def main():
    """Demo: kernel-preprocess three differently sized random images, run the vision model on
    the result, and print the max absolute gap against the stock processor's pixel_values."""
    model_id = "google/siglip2-base-patch16-224"
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    processor = AutoImageProcessor.from_pretrained(model_id, backend="torchvision")
    model = AutoModel.from_pretrained(model_id).to(device).eval()

    # Random uint8 CHW images of mixed sizes, created directly on the target device.
    shapes = [(640, 480), (800, 600), (384, 1024)]
    images = []
    for height, width in shapes:
        images.append(torch.randint(0, 256, (3, height, width), dtype=torch.uint8, device=device))

    kernel_out = preprocess_with_kernel(processor, images)
    print(f"{len(images)} ragged images -> pixel_values {tuple(kernel_out.shape)} {kernel_out.dtype}")

    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=kernel_out.to(model.dtype))
        features = vision_out.pooler_output
    print(f"vision features: {tuple(features.shape)}")

    # Compare with the processor's own output; a small difference is expected because the two
    # paths resize in different dtypes (float vs uint8).
    processed = processor(images, return_tensors="pt", device=device)
    reference = processed["pixel_values"].to(device)
    max_gap = (kernel_out - reference).abs().max().item()
    print(f"max|Δ| pixel_values vs processor: {max_gap:.2e}")


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()