"""Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs.

Also covers shortest-edge resize + center crop, stacked (N, C, H, W) tensor input,
and agreement between the fused and separable backends.

Run locally from the repo root with the package on the path:

    PYTHONPATH=torch-ext pytest tests/ -q

CUDA is required (Triton); tests skip on CPU.
"""

import pytest
import torch
import torchvision.transforms.v2.functional as tvF
from torchvision.transforms import InterpolationMode

from kernel_image_resize import resize_normalize

# Map the kernel's string resample names onto torchvision's interpolation enum.
_TV_INTERP = {"bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC}

# Per-channel normalization mean/std (one entry per image channel) and the
# factor that maps uint8 pixel values 0..255 onto 0..1.
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]
RESCALE = 1.0 / 255.0

# Max allowed absolute difference, measured in normalized (post mean/std) units.
_ATOL = 3e-3

# Shared skip condition: every test here launches the Triton kernel.
_requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")


def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
    """Return ``n`` random uint8 images of shape (3, h, w) with independently drawn h and w.

    Sizes and pixel values come from a seeded CPU generator, so every run (and
    every backend) sees identical inputs; each image is then moved to ``device``.
    Both h and w lie in the inclusive range [min_res, max_res].
    """
    g = torch.Generator(device="cpu").manual_seed(seed)
    images = []
    for _ in range(n):
        # Draw order (h, w, pixels) is kept fixed so the generated data is stable.
        h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
        w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
        images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
    return images


def _mean_std(device):
    """Return MEAN and STD as (3, 1, 1) tensors on ``device`` so they broadcast over (3, H, W)."""
    mean = torch.tensor(MEAN, device=device).view(3, 1, 1)
    std = torch.tensor(STD, device=device).view(3, 1, 1)
    return mean, std


def _assert_close(got, ref, label):
    """Assert ``got`` matches ``ref`` in shape and within ``_ATOL`` elementwise.

    The explicit shape check matters: ``got - ref`` broadcasts, so without it a
    wrongly shaped kernel output (e.g. a dropped batch dim) could silently pass.
    """
    assert got.shape == ref.shape, f"{label}: shape {tuple(got.shape)} != {tuple(ref.shape)}"
    assert torch.isfinite(got).all(), f"{label}: non-finite values in output"
    max_abs = (got - ref).abs().max().item()
    assert max_abs < _ATOL, f"{label}: max|Δ|={max_abs:.2e}"


def _torchvision_reference(images, out_h, out_w, interp, antialias):
    """Resize each (3, H, W) image to (out_h, out_w) with torchvision, then rescale + normalize.

    Resizing is done in float (``img.float()``), so no intermediate rounding to
    uint8 happens. Returns a stacked (N, 3, out_h, out_w) tensor.
    """
    mode = _TV_INTERP[interp]
    mean, std = _mean_std(images[0].device)
    outs = []
    for img in images:
        r = tvF.resize(img.float(), [out_h, out_w], interpolation=mode, antialias=antialias)
        outs.append((r * RESCALE - mean) / std)
    return torch.stack(outs)


@pytest.mark.kernels_ci
@_requires_cuda
@pytest.mark.parametrize("backend", ["fused", "separable"])
@pytest.mark.parametrize(
    "interp,antialias",
    [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)],
)
def test_parity_vs_torchvision(backend, interp, antialias):
    """Fixed-size resize of ragged inputs must match torchvision for every backend/mode combo."""
    device = torch.device("cuda")
    images = _ragged_images(8, device)
    out_h = out_w = 384
    got = resize_normalize(
        images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
    )
    ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
    _assert_close(got, ref, f"{backend}/{interp}/aa={antialias}")


@pytest.mark.kernels_ci
@_requires_cuda
def test_stacked_tensor_input():
    """A stacked (N, 3, H, W) uint8 tensor is accepted and matches the per-image reference."""
    device = torch.device("cuda")
    images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8, device=device)
    got = resize_normalize(images, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)
    assert got.shape == (4, 3, 224, 224)
    # Inputs are square, so a scalar size of 224 means a 224x224 resize either way.
    # Iterating the 4-D tensor yields the (3, H, W) images the reference expects.
    ref = _torchvision_reference(images, 224, 224, "bicubic", True)
    _assert_close(got, ref, "stacked/bicubic/aa=True")


def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
    """Torchvision reference for shortest-edge resize, square center crop, then normalize.

    The shorter side is resized to ``shortest_edge``. The longer side is scaled
    by the same ratio and truncated with ``int()``. The result is center-cropped
    to ``crop`` x ``crop``. Returns a stacked (N, 3, crop, crop) tensor.
    """
    mode = _TV_INTERP[interp]
    mean, std = _mean_std(images[0].device)
    outs = []
    for img in images:
        in_h, in_w = img.shape[1], img.shape[2]
        if in_h <= in_w:
            rh, rw = shortest_edge, int(in_w * shortest_edge / in_h)
        else:
            rh, rw = int(in_h * shortest_edge / in_w), shortest_edge
        r = tvF.resize(img.float(), [rh, rw], interpolation=mode, antialias=antialias)
        r = tvF.center_crop(r, [crop, crop])
        outs.append((r * RESCALE - mean) / std)
    return torch.stack(outs)


@pytest.mark.kernels_ci
@_requires_cuda
@pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
def test_shortest_edge_crop_parity(interp, antialias):
    """``resize_mode="shortest_edge"`` + ``crop_size`` must match the torchvision pipeline."""
    device = torch.device("cuda")
    images = _ragged_images(8, device)
    shortest_edge, crop = 256, 224
    got = resize_normalize(
        images,
        shortest_edge,
        MEAN,
        STD,
        RESCALE,
        resample=interp,
        antialias=antialias,
        crop_size=(crop, crop),
        resize_mode="shortest_edge",
    )
    ref = _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias)
    assert got.shape == (8, 3, crop, crop)
    _assert_close(got, ref, f"shortest_edge+crop {interp}/aa={antialias}")


@pytest.mark.kernels_ci
@_requires_cuda
def test_fused_matches_separable():
    """The fused and separable backends must agree on the same ragged inputs."""
    device = torch.device("cuda")
    images = _ragged_images(6, device)
    common = dict(
        size=(256, 256),
        image_mean=MEAN,
        image_std=STD,
        rescale_factor=RESCALE,
        resample="bicubic",
        antialias=True,
    )
    fused = resize_normalize(images, backend="fused", **common)
    separable = resize_normalize(images, backend="separable", **common)
    _assert_close(fused, separable, "fused vs separable")