File size: 4,869 Bytes
e199518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Parity tests vs torchvision for both backends, all interp×antialias combos, ragged inputs.

Run locally from the repo root with the package on the path:
    PYTHONPATH=torch-ext pytest tests/ -q
CUDA is required (Triton); tests skip on CPU.
"""

import pytest
import torch
import torchvision.transforms.v2.functional as tvF
from torchvision.transforms import InterpolationMode

from kernel_image_resize import resize_normalize


# Maps each `resample` name exercised by these tests to torchvision's interpolation enum.
_TV_INTERP = dict(
    bilinear=InterpolationMode.BILINEAR,
    bicubic=InterpolationMode.BICUBIC,
)

# Per-channel normalization statistics passed to both the kernel and the reference.
# NOTE(review): these appear to be the OpenAI CLIP image mean/std values.
MEAN = [0.48145466, 0.4578275, 0.40821073]
STD = [0.26862954, 0.26130258, 0.27577711]
# Multiplier applied to raw uint8 pixel values (0..255) before mean/std normalization.
RESCALE = 1 / 255


def _ragged_images(n, device, min_res=384, max_res=1024, seed=0):
    g = torch.Generator(device="cpu").manual_seed(seed)
    images = []
    for _ in range(n):
        h = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
        w = int(torch.randint(min_res, max_res + 1, (1,), generator=g).item())
        images.append(torch.randint(0, 256, (3, h, w), generator=g, dtype=torch.uint8).to(device))
    return images


def _torchvision_reference(images, out_h, out_w, interp, antialias):
    """Resize every image to ``(out_h, out_w)`` with torchvision, then rescale and normalize.

    Images are processed one at a time (so ragged sizes are fine) after conversion to
    float. Each result is computed as ``(resized * RESCALE - MEAN) / STD`` per channel,
    and the outputs are stacked into a single batch tensor.
    """
    interpolation = _TV_INTERP[interp]
    dev = images[0].device
    # Shape (3, 1, 1) so the statistics broadcast over H and W.
    mean_t = torch.tensor(MEAN, device=dev).reshape(3, 1, 1)
    std_t = torch.tensor(STD, device=dev).reshape(3, 1, 1)

    def _one(img):
        resized = tvF.resize(img.float(), [out_h, out_w], interpolation=interpolation, antialias=antialias)
        return (resized * RESCALE - mean_t) / std_t

    return torch.stack([_one(img) for img in images])


@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("backend", ["fused", "separable"])
@pytest.mark.parametrize("interp,antialias", [("bilinear", False), ("bilinear", True), ("bicubic", False), ("bicubic", True)])
def test_parity_vs_torchvision(backend, interp, antialias):
    """Each backend must match the per-image torchvision resize + normalize pipeline.

    Eight ragged uint8 images (sides in [384, 1024]) are resized to 384x384 and
    compared element-wise against ``_torchvision_reference`` with a 3e-3 tolerance.
    """
    device = torch.device("cuda")
    images = _ragged_images(8, device)
    out_h = out_w = 384
    got = resize_normalize(
        images, (out_h, out_w), MEAN, STD, RESCALE, resample=interp, antialias=antialias, backend=backend
    )
    ref = _torchvision_reference(images, out_h, out_w, interp, antialias)
    # Check the shape explicitly: the subtraction below broadcasts, so an output with a
    # missing or size-1 batch dimension would otherwise be compared without error.
    assert got.shape == ref.shape, (
        f"{backend}/{interp}/aa={antialias}: shape {tuple(got.shape)} != {tuple(ref.shape)}"
    )
    max_abs = (got - ref).abs().max().item()
    assert max_abs < 3e-3, f"{backend}/{interp}/aa={antialias}: max|Δ|={max_abs:.2e}"


@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_stacked_tensor_input():
    """A pre-stacked uint8 batch of shape (N, 3, H, W) is accepted and matches torchvision.

    Checks both the output shape and the numeric values, so bugs that are specific to
    the stacked-tensor input path (as opposed to a list of images) are caught.
    """
    device = torch.device("cuda")
    # Seeded CPU generator keeps the input reproducible across runs, like _ragged_images.
    g = torch.Generator(device="cpu").manual_seed(0)
    images = torch.randint(0, 256, (4, 3, 512, 512), generator=g, dtype=torch.uint8).to(device)
    got = resize_normalize(images, 224, MEAN, STD, RESCALE, resample="bicubic", antialias=True)
    assert got.shape == (4, 3, 224, 224)
    # The inputs are square, so the expected result is a plain 224x224 resize of each image.
    # Iterating the 4-D tensor yields (3, H, W) slices, which the reference helper accepts.
    ref = _torchvision_reference(images, 224, 224, "bicubic", True)
    max_abs = (got - ref).abs().max().item()
    assert max_abs < 3e-3, f"stacked input: max|Δ|={max_abs:.2e}"


def _shortest_edge_crop_reference(images, shortest_edge, crop, interp, antialias):
    """torchvision reference for shortest-edge resize, center crop, then normalize.

    For each (3, H, W) image, the shorter side is scaled to ``shortest_edge`` and the
    longer side is scaled proportionally, with ``int()`` truncation. When H == W the
    height is treated as the short side. The resized image is center-cropped to
    ``crop x crop`` and then normalized as ``(x * RESCALE - MEAN) / STD``.
    """
    interpolation = _TV_INTERP[interp]
    dev = images[0].device
    mean_t = torch.tensor(MEAN, device=dev).reshape(3, 1, 1)
    std_t = torch.tensor(STD, device=dev).reshape(3, 1, 1)
    cropped = []
    for img in images:
        h, w = img.shape[1], img.shape[2]
        target = (
            [shortest_edge, int(w * shortest_edge / h)]
            if h <= w
            else [int(h * shortest_edge / w), shortest_edge]
        )
        resized = tvF.resize(img.float(), target, interpolation=interpolation, antialias=antialias)
        patch = tvF.center_crop(resized, [crop, crop])
        cropped.append((patch * RESCALE - mean_t) / std_t)
    return torch.stack(cropped)


@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
@pytest.mark.parametrize("interp,antialias", [("bilinear", True), ("bicubic", True)])
def test_shortest_edge_crop_parity(interp, antialias):
    """The shortest-edge resize plus square center crop mode matches the torchvision reference."""
    batch, shortest_edge, crop = 8, 256, 224
    imgs = _ragged_images(batch, torch.device("cuda"))
    out = resize_normalize(
        imgs,
        shortest_edge,
        MEAN,
        STD,
        RESCALE,
        resample=interp,
        antialias=antialias,
        crop_size=(crop, crop),
        resize_mode="shortest_edge",
    )
    expected = _shortest_edge_crop_reference(imgs, shortest_edge, crop, interp, antialias)
    assert out.shape == (batch, 3, crop, crop)
    err = (out - expected).abs().max().item()
    assert err < 3e-3, f"shortest_edge+crop {interp}/aa={antialias}: max|Δ|={err:.2e}"


@pytest.mark.kernels_ci
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel needs CUDA")
def test_fused_matches_separable():
    """The fused and separable backends agree on identical inputs and settings."""
    images = _ragged_images(6, torch.device("cuda"))
    kwargs = {
        "size": (256, 256),
        "image_mean": MEAN,
        "image_std": STD,
        "rescale_factor": RESCALE,
        "resample": "bicubic",
        "antialias": True,
    }
    # Run the backends in the order fused, then separable.
    outputs = {name: resize_normalize(images, backend=name, **kwargs) for name in ("fused", "separable")}
    max_abs = (outputs["fused"] - outputs["separable"]).abs().max().item()
    assert max_abs < 3e-3, f"fused vs separable: max|Δ|={max_abs:.2e}"