# kernel_image_resize / example_transformers.py
# Hugging Face Hub file by Molbap (HF Staff); commit e199518 "Upload folder using
# huggingface_hub" (verified), 3.92 kB.
# /// script
# requires-python = ">=3.10"
# dependencies = ["torch", "triton", "kernels", "transformers", "torchvision"]
# ///
"""Drop-in: use the kernel as the resize+normalize stage of a transformers fast processor.
There is no `use_kernels=True` hook for image processors (that machinery swaps nn.Module
layer forwards inside the model, not processor code). So the usable path is to read the
processor's config and call the kernel directly. `preprocess_with_kernel` below is the whole
adapter — copy it into your code.
Run on a CUDA box:
python example_transformers.py
"""
import torch
from kernels import get_kernel
from transformers import AutoImageProcessor, AutoModel
# Load the resize+normalize kernel from the Hub repo `Molbap/kernel_image_resize`.
# This executes at import time and tracks the moving `main` revision.
# NOTE(review): pin `revision` to a commit hash if reproducibility matters.
kernel_image_resize = get_kernel(
    "Molbap/kernel_image_resize",
    revision="main",
    trust_remote_code=True,
)
_PIL_RESAMPLE = {0: "bilinear", 2: "bilinear", 3: "bicubic"}
def preprocess_with_kernel(processor, images):
    """Run the kernel using `processor`'s own config; returns pixel_values like processor(images).

    Handles fixed-size resize, square-resize + center-crop, and shortest-edge resize + center-crop
    (CLIP / DINOv2). Does not handle padding processors.

    Args:
        processor: a transformers image processor. Only its config attributes are read:
            `size`, `crop_size`, `resample`, `rescale_factor`, `image_mean`, `image_std`,
            `antialias`, and the `do_*` flags.
        images: forwarded unchanged to `kernel_image_resize.resize_normalize`. The demo in
            this file passes a list of uint8 (3, H, W) tensors of differing sizes.

    Returns:
        The kernel's output, intended as a drop-in for `processor(images)["pixel_values"]`.

    Raises:
        ValueError: if the processor config requires something the kernel cannot reproduce:
            no resize, padding, no normalization, RGB->BGR flip, an unmapped resample mode,
            a `longest_edge` or otherwise unknown size config, shortest-edge resize without
            a center crop, or center crop enabled with no `crop_size`.
    """
    # Refuse any step the kernel cannot replicate, so the result never silently diverges
    # from what the processor itself would produce.
    if not getattr(processor, "do_resize", True):
        raise ValueError("processor skips resizing (do_resize=False); kernel always resizes")
    if getattr(processor, "do_pad", False):
        raise ValueError("kernel does not pad; this processor needs a pad step")
    if not getattr(processor, "do_normalize", True):
        raise ValueError("processor does not normalize (rescale only); kernel always normalizes")
    if getattr(processor, "do_flip_channel_order", False):
        raise ValueError("processor flips channels to BGR; kernel keeps RGB")

    size = processor.size
    crop = None
    if getattr(processor, "do_center_crop", False):
        crop = getattr(processor, "crop_size", None)
        if crop is None:
            raise ValueError("do_center_crop is set but crop_size is None")

    # Validate the size config up front instead of failing later with a bare KeyError.
    if "longest_edge" in size:
        # None of the kernel calls below accept a longest-edge limit, so it would be ignored.
        raise ValueError(f"longest_edge resize is not supported by the kernel (size={size!r})")
    if "shortest_edge" in size:
        if crop is None:
            raise ValueError("shortest-edge resize without a crop gives variable-size output")
    elif "height" not in size or "width" not in size:
        raise ValueError(f"unsupported size config {size!r}; expected height/width or shortest_edge")

    resample_code = getattr(processor, "resample", None)
    try:
        resample = _PIL_RESAMPLE[int(resample_code)]
    except (KeyError, TypeError, ValueError) as err:
        # Unmapped PIL codes, None, or a non-integer mode (e.g. a torchvision enum).
        raise ValueError(
            f"unsupported resample {resample_code!r}; kernel path maps PIL codes {sorted(_PIL_RESAMPLE)}"
        ) from err

    antialias = bool(getattr(processor, "antialias", True))
    # When the processor skips rescaling, pass an identity factor so only normalization applies.
    rescale = float(processor.rescale_factor) if getattr(processor, "do_rescale", True) else 1.0
    mean, std = processor.image_mean, processor.image_std
    common = dict(rescale_factor=rescale, resample=resample, antialias=antialias)

    if "shortest_edge" in size:
        # Shortest-edge resize followed by a center crop to crop_size.
        return kernel_image_resize.resize_normalize(
            images, size["shortest_edge"], mean, std,
            crop_size=(crop["height"], crop["width"]), resize_mode="shortest_edge", **common,
        )
    target = (size["height"], size["width"])
    if crop is not None and (crop["height"], crop["width"]) != target:
        # Fixed (height, width) resize followed by a center crop to a different size.
        return kernel_image_resize.resize_normalize(
            images, target, mean, std,
            crop_size=(crop["height"], crop["width"]), resize_mode="square", **common,
        )
    # Plain fixed-size resize; a crop equal to the resize target is skipped as a no-op.
    return kernel_image_resize.resize_normalize(images, target, mean, std, **common)
def main():
    """Demo: kernel preprocessing feeding SigLIP2's vision tower, plus a parity check.

    Builds three random uint8 images of different sizes and runs them through
    `preprocess_with_kernel`. It then prints the resulting shape/dtype and the pooled
    vision feature shape, and finally reports the max absolute difference against the
    stock processor's `pixel_values`.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "google/siglip2-base-patch16-224"

    processor = AutoImageProcessor.from_pretrained(checkpoint, backend="torchvision")
    model = AutoModel.from_pretrained(checkpoint).to(target_device).eval()

    # Three uint8 (3, H, W) images with different spatial sizes ("ragged" input).
    ragged_hw = ((640, 480), (800, 600), (384, 1024))
    images = [
        torch.randint(0, 256, (3, *hw), dtype=torch.uint8, device=target_device)
        for hw in ragged_hw
    ]

    pixel_values = preprocess_with_kernel(processor, images)
    print(f"{len(images)} ragged images -> pixel_values {tuple(pixel_values.shape)} {pixel_values.dtype}")

    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=pixel_values.to(model.dtype))
    features = vision_out.pooler_output
    print(f"vision features: {tuple(features.shape)}")

    # Compare with the real processor. The two paths resize differently
    # (float-vs-uint8 resize), so a small non-zero gap is expected.
    reference = processor(images, return_tensors="pt", device=target_device)["pixel_values"].to(target_device)
    max_abs_diff = (pixel_values - reference).abs().max().item()
    print(f"max|Δ| pixel_values vs processor: {max_abs_diff:.2e}")


if __name__ == "__main__":
    main()