| """Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs. |
| |
| Run locally from the repo root with the package on the path: |
| PYTHONPATH=torch-ext pytest tests/ -q |
| CUDA is required (Triton); tests skip on CPU. |
| """ |
|
|
| import pytest |
| import torch |
| import torchvision.transforms.v2.functional as tvF |
| from torchvision.transforms import InterpolationMode |
|
|
| from kernel_image_resize import resize_normalize |
|
|
|
|
# Translates the string ``resample`` names passed to resize_normalize into the
# torchvision interpolation enum used when building reference outputs.
_TV_INTERP = dict(
    bilinear=InterpolationMode.BILINEAR,
    bicubic=InterpolationMode.BICUBIC,
)
|
|
# Per-channel normalization statistics (3 channels), shared by the kernel call
# and by the torchvision reference implementations below.
MEAN = [
    0.48145466,
    0.4578275,
    0.40821073,
]
STD = [
    0.26862954,
    0.26130258,
    0.27577711,
]
# Scales uint8 pixel values from [0, 255] into [0, 1] before mean/std normalization.
RESCALE = 1 / 255
|
|
|
|
def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
    """Return a list of ``n`` random uint8 images of shape ``(3, h, w)``.

    Height and width are drawn independently and uniformly from the inclusive
    range ``[min_res, max_res]``, so the batch is ragged. All sampling uses a
    CPU generator seeded with ``seed``, which makes the set reproducible. Each
    tensor is moved to ``device`` after it is generated.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)

    def _draw_side():
        # randint's upper bound is exclusive, hence the +1 to include max_res.
        return int(torch.randint(min_res, max_res + 1, (1,), generator=rng).item())

    def _draw_image():
        # Draw order (height, width, pixels) must stay fixed for reproducibility.
        height = _draw_side()
        width = _draw_side()
        pixels = torch.randint(0, 256, (3, height, width), generator=rng, dtype=torch.uint8)
        return pixels.to(device)

    return [_draw_image() for _ in range(n)]
|
|
|
|
def _torchvision_reference(images, out_h, out_w, interp, antialias):
    """Build the expected output with torchvision: fixed-size float resize, then rescale + normalize.

    Each image is resized on its own (inputs may have different sizes). All
    results share ``(out_h, out_w)``, so they can be stacked into one batch.
    Assumes every image is a 3-channel CHW tensor, as implied by the
    ``(3, 1, 1)`` mean/std views.
    """
    mode = _TV_INTERP[interp]
    device = images[0].device
    mean = torch.tensor(MEAN, device=device).view(3, 1, 1)
    std = torch.tensor(STD, device=device).view(3, 1, 1)
    normalized = [
        (tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias) * RESCALE - mean) / std
        for img in images
    ]
    return torch.stack(normalized)
|
|
|
|
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("backend", ["fused", "separable"])
@pytest.mark.parametrize("interp,antialias", [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)])
def test_parity_vs_torchvision(backend, interp, antialias):
    """Each backend matches the torchvision float reference on ragged inputs (max abs error < 3e-3)."""
    device = torch.device("cuda")
    images = _ragged_images(8, device)
    out_h = out_w = 384
    got = resize_normalize(
        images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
    )
    ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
    # Check the shape explicitly. ``got - ref`` broadcasts, so a wrongly shaped
    # result (e.g. a single (1, 3, H, W) image) could otherwise pass the value check.
    assert got.shape == ref.shape, (
        f"{backend}/{interp}/aa={antialias}: shape {tuple(got.shape)} != {tuple(ref.shape)}"
    )
    max_abs = (got - ref).abs().max().item()
    # A NaN in the output makes max_abs NaN, and the comparison below fails as desired.
    assert max_abs < 3e-3, f"{backend}/{interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
|
|
|
|
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_stacked_tensor_input():
    """A pre-stacked ``(N, 3, H, W)`` uint8 tensor is accepted and gives correct values.

    The input is square (512x512), so an int ``size`` of 224 means a direct
    512 -> 224 resize under any resize mode. The torchvision reference can
    therefore be used as-is.
    """
    device = torch.device("cuda")
    images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8, device=device)
    got = resize_normalize(images, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)
    assert got.shape == (4, 3, 224, 224)
    # Checking the shape alone would pass even if the stacked-input path
    # produced wrong pixels, so also compare values with the reference used
    # by the list-input tests. Iterating a 4D tensor yields (3, H, W) images.
    ref = _torchvision_reference(images, 224, 224, "bicubic", True)
    max_abs = (got - ref).abs().max().item()
    assert max_abs < 3e-3, f"stacked input bicubic/aa=True: max|Δ|={max_abs:.2e}"
|
|
|
|
def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
    """Torchvision reference for a shortest-edge resize, a square center crop, then rescale + normalize.

    The shorter side of each image is resized to ``shortest_edge``. The longer
    side is scaled by the same factor and truncated with ``int``. When height
    equals width, the height is treated as the shorter side. The resized image
    is center-cropped to ``crop x crop`` before normalization.
    """
    mode = _TV_INTERP[interp]
    device = images[0].device
    mean = torch.tensor(MEAN, device=device).view(3, 1, 1)
    std = torch.tensor(STD, device=device).view(3, 1, 1)

    def _target_size(in_h, in_w):
        # Landscape or square: pin the height. Portrait: pin the width.
        if in_h > in_w:
            return int(in_h * shortest_edge / in_w), shortest_edge
        return shortest_edge, int(in_w * shortest_edge / in_h)

    results = []
    for img in images:
        rh, rw = _target_size(img.shape[1], img.shape[2])
        resized = tvF.resize(img.float(), [rh, rw], interpolation=mode, antialias=antialias)
        cropped = tvF.center_crop(resized, [crop, crop])
        results.append((cropped * RESCALE - mean) / std)
    return torch.stack(results)
|
|
|
|
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
def test_shortest_edge_crop_parity(interp, antialias):
    """Shortest-edge resize plus a square center crop matches the torchvision reference on ragged inputs."""
    batch = _ragged_images(8, torch.device("cuda"))
    shortest_edge = 256
    crop = 224
    options = dict(
        resample=interp,
        antialias=antialias,
        crop_size=(crop, crop),
        resize_mode="shortest_edge",
    )
    out = resize_normalize(batch, shortest_edge, MEAN, STD, RESCALE, **options)
    expected = _shortest_edge_crop_reference(batch, shortest_edge, crop, interp, antialias)
    assert out.shape == (8, 3, crop, crop)
    err = (out - expected).abs().max().item()
    assert err < 3e-3, f"shortest_edge+crop {interp}/aa={antialias}: max|Δ|={err:.2e}"
|
|
|
|
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_fused_matches_separable():
    """The fused and separable backends agree (max abs error < 3e-3) on ragged bicubic+antialias inputs."""
    device = torch.device("cuda")
    images = _ragged_images(6, device)
    common = dict(size=(256, 256), image_mean=MEAN, image_std=STD, rescale_factor=RESCALE, resample="bicubic", antialias=True)
    fused = resize_normalize(images, backend="fused", **common)
    separable = resize_normalize(images, backend="separable", **common)
    # Check the shapes explicitly. The subtraction below broadcasts, so a shape
    # mismatch between backends (or a collapsed batch) would otherwise go unnoticed.
    assert fused.shape == separable.shape == (6, 3, 256, 256), (
        f"shape mismatch: fused {tuple(fused.shape)} vs separable {tuple(separable.shape)}"
    )
    max_abs = (fused - separable).abs().max().item()
    assert max_abs < 3e-3, f"fused vs separable: max|Δ|={max_abs:.2e}"
|
|