File size: 4,869 Bytes
"""Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs.

Run locally from the repo root with the package on the path:

    PYTHONPATH=torch-ext pytest tests/ -q

CUDA is required (Triton); tests skip on CPU.
"""
import pytest
import torch
import torchvision.transforms.v2.functional as tvF
from torchvision.transforms import InterpolationMode
from kernel_image_resize import resize_normalize
# Translates the resample names used by these tests into torchvision modes.
_TV_INTERP = {
    "bilinear": InterpolationMode.BILINEAR,
    "bicubic": InterpolationMode.BICUBIC,
}

# Per-channel normalization statistics (one entry per image channel).
# NOTE(review): these look like the CLIP preprocessing statistics — confirm.
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]

# Multiplier applied to raw pixel values before mean/std normalization.
RESCALE = 1.0 / 255.0
def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
    """Return ``n`` random uint8 images of shape (3, H, W) with varying H and W.

    Heights and widths are drawn independently from ``[min_res, max_res]``
    (both ends inclusive). All randomness comes from one CPU generator seeded
    with ``seed``, so the same arguments always produce the same images; each
    image is created on the CPU and then moved to ``device``.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)

    def _draw_side():
        # randint's upper bound is exclusive, hence the +1 to include max_res.
        return int(torch.randint(min_res, max_res + 1, (1,), generator=rng).item())

    batch = []
    for _ in range(n):
        # Draw order (height, width, pixels) determines the generated data.
        height = _draw_side()
        width = _draw_side()
        pixels = torch.randint(0, 256, (3, height, width), generator=rng, dtype=torch.uint8)
        batch.append(pixels.to(device))
    return batch
def _torchvision_reference(images, out_h, out_w, interp, antialias):
    """Compute the torchvision reference output for a fixed-size resize.

    Each image is converted to float, resized to ``(out_h, out_w)`` with the
    torchvision mode matching ``interp``, multiplied by ``RESCALE`` and then
    normalized with ``MEAN``/``STD``. The per-image results are stacked into
    a single tensor.
    """
    tv_mode = _TV_INTERP[interp]
    target_device = images[0].device
    # Shaped (3, 1, 1) so the statistics broadcast over height and width.
    channel_mean = torch.tensor(MEAN, device=target_device).view(3, 1, 1)
    channel_std = torch.tensor(STD, device=target_device).view(3, 1, 1)

    def _resize_and_normalize(img):
        resized = tvF.resize(img.float(), [out_h, out_w], interpolation=tv_mode, antialias=antialias)
        return (resized * RESCALE - channel_mean) / channel_std

    return torch.stack([_resize_and_normalize(img) for img in images])
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("backend", ["fused", "separable"])
@pytest.mark.parametrize(
    "interp,antialias",
    [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)],
)
def test_parity_vs_torchvision(backend, interp, antialias):
    """Check ``resize_normalize`` against torchvision on ragged inputs.

    Runs for every backend × interpolation × antialias combination: 8 images
    of varying size are resized to 384×384 and the result must match the
    torchvision reference in shape and to within 3e-3 max absolute error.
    """
    device = torch.device("cuda")
    images = _ragged_images(8, device)
    out_h = out_w = 384
    got = resize_normalize(
        images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
    )
    ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
    # Check shapes before subtracting: a wrong-but-broadcastable output shape
    # would otherwise be broadcast by `got - ref` instead of raising.
    assert got.shape == ref.shape, (
        f"{backend}/{interp}/aa={antialias}: shape {tuple(got.shape)} != {tuple(ref.shape)}"
    )
    max_abs = (got - ref).abs().max().item()
    assert max_abs < 3e-3, f"{backend}/{interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_stacked_tensor_input():
    """A pre-stacked (4, 3, 512, 512) uint8 tensor yields a (4, 3, 224, 224) output.

    Only the output shape is checked here; the call uses an integer size of
    224 with bicubic resampling and antialiasing enabled.
    """
    cuda = torch.device("cuda")
    input_shape = (4, 3, 512, 512)
    stacked = torch.randint(0, 256, input_shape, dtype=torch.uint8, device=cuda)

    result = resize_normalize(stacked, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)

    expected_shape = (4, 3, 224, 224)
    assert result.shape == expected_shape
def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
    """Compute the torchvision reference for shortest-edge resize + center crop.

    For each image the shorter side is resized to ``shortest_edge`` and the
    longer side is scaled by the same ratio (truncated to an int). The result
    is center-cropped to ``crop`` × ``crop``, multiplied by ``RESCALE`` and
    normalized with ``MEAN``/``STD``. The per-image results are stacked.
    """
    tv_mode = _TV_INTERP[interp]
    target_device = images[0].device
    # Shaped (3, 1, 1) so the statistics broadcast over height and width.
    channel_mean = torch.tensor(MEAN, device=target_device).view(3, 1, 1)
    channel_std = torch.tensor(STD, device=target_device).view(3, 1, 1)

    def _resized_hw(in_h, in_w):
        # Portrait images pin the width; everything else (including squares)
        # pins the height.
        if in_h > in_w:
            return int(in_h * shortest_edge / in_w), shortest_edge
        return shortest_edge, int(in_w * shortest_edge / in_h)

    normalized = []
    for img in images:
        new_h, new_w = _resized_hw(img.shape[1], img.shape[2])
        resized = tvF.resize(img.float(), [new_h, new_w], interpolation=tv_mode, antialias=antialias)
        cropped = tvF.center_crop(resized, [crop, crop])
        normalized.append((cropped * RESCALE - channel_mean) / channel_std)
    return torch.stack(normalized)
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
def test_shortest_edge_crop_parity(interp, antialias):
    """Check shortest-edge resize + center crop against the torchvision reference.

    8 ragged images are resized so the shortest edge is 256 and then cropped
    to 224×224; the output must have shape (8, 3, 224, 224) and match the
    reference to within 3e-3 max absolute error.
    """
    shortest_edge, crop = 256, 224
    images = _ragged_images(8, torch.device("cuda"))

    kernel_options = dict(
        resample=interp,
        antialias=antialias,
        crop_size=(crop, crop),
        resize_mode="shortest_edge",
    )
    got = resize_normalize(images, shortest_edge, MEAN, STD, RESCALE, **kernel_options)
    ref = _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias)

    expected_shape = (8, 3, crop, crop)
    assert got.shape == expected_shape
    max_abs = torch.abs(got - ref).max().item()
    assert max_abs < 3e-3, f"shortest_edge+crop {interp}/aa={antialias}: max|Δ|={max_abs:.2e}"
@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_fused_matches_separable():
    """Check that the fused and separable backends agree with each other.

    Both backends process the same 6 ragged images (bicubic, antialiased,
    256×256 output); the results must share the expected shape and differ by
    less than 3e-3 in max absolute error.
    """
    device = torch.device("cuda")
    images = _ragged_images(6, device)
    common = dict(
        size=(256, 256),
        image_mean=MEAN,
        image_std=STD,
        rescale_factor=RESCALE,
        resample="bicubic",
        antialias=True,
    )
    fused = resize_normalize(images, backend="fused", **common)
    separable = resize_normalize(images, backend="separable", **common)
    # Check shapes before subtracting: a wrong-but-broadcastable output shape
    # would otherwise be broadcast by `fused - separable` instead of raising.
    assert fused.shape == (6, 3, 256, 256), f"fused: unexpected shape {tuple(fused.shape)}"
    assert separable.shape == (6, 3, 256, 256), f"separable: unexpected shape {tuple(separable.shape)}"
    max_abs = (fused - separable).abs().max().item()
    assert max_abs < 3e-3, f"fused vs separable: max|Δ|={max_abs:.2e}"
|