File size: 1,698 Bytes
4336727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio


def pad_to_factor(x, factor=8):
    """Pad the last two dimensions of ``x`` up to a multiple of ``factor``.

    Padding is added only at the bottom and right, so the original content
    stays anchored at the top-left and can be recovered by slicing
    ``[..., :h, :w]`` (see ``crop_to_size``).

    Args:
        x: Tensor whose last two dimensions are treated as height and width.
            NOTE(review): assumes an image-like (C, H, W) or (N, C, H, W)
            tensor, which is what F.pad's reflect/replicate modes expect --
            TODO confirm with callers.
        factor: Required multiple for height and width. ``factor <= 1``
            disables padding.

    Returns:
        ``(padded, (h, w))`` where ``(h, w)`` is the ORIGINAL spatial size of
        ``x`` in every branch, so it can always be passed to ``crop_to_size``.
    """
    h, w = x.shape[-2:]
    if factor <= 1:
        # Report the real size: returning (0, 0) here made crop_to_size
        # produce an empty tensor.
        return x, (h, w)

    pad_h = (factor - h % factor) % factor
    pad_w = (factor - w % factor) % factor
    if pad_h == 0 and pad_w == 0:
        return x, (h, w)

    # Reflect padding requires each pad to be smaller than the dimension it
    # extends; fall back to edge replication for tiny inputs instead of
    # raising.
    mode = "reflect" if pad_h < h and pad_w < w else "replicate"
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), (h, w)


def crop_to_size(x, size):
    """Keep only the top-left ``size`` window of ``x``'s last two dimensions.

    Args:
        x: Tensor (or array) indexed as ``[..., height, width]``.
        size: ``(height, width)`` pair; sizes larger than ``x`` are clipped
            by slicing semantics rather than raising.

    Returns:
        A view of ``x`` restricted to ``[..., :height, :width]``.
    """
    height, width = size
    rows = slice(None, height)
    cols = slice(None, width)
    return x[..., rows, cols]


def _tta_inputs(x):
    return [
        (x, lambda y: y),
        (torch.flip(x, dims=[-1]), lambda y: torch.flip(y, dims=[-1])),
        (torch.flip(x, dims=[-2]), lambda y: torch.flip(y, dims=[-2])),
        (torch.rot90(x, k=1, dims=(-2, -1)), lambda y: torch.rot90(y, k=3, dims=(-2, -1))),
    ]


def _unwrap_output(output):
    if isinstance(output, tuple):
        return output[0]
    return output


@torch.no_grad()
def run_model(model, x, pad_factor=8, tta=False):
    """Run ``model`` on ``x`` without gradients, optionally with TTA.

    Each forward pass pads its input with ``pad_to_factor(..., pad_factor)``,
    unwraps tuple outputs via ``_unwrap_output``, and crops the result back
    with ``crop_to_size`` to the size ``pad_to_factor`` reported.

    With ``tta=True`` every view from ``_tta_inputs`` is processed that way,
    mapped back to ``x``'s orientation with its inverse transform, and the
    results are averaged element-wise.

    NOTE: this function does not call ``model.eval()``; the caller controls
    the model's mode.
    """
    def _forward(inp):
        # Padding is computed per view: the rotated view swaps H and W.
        padded, orig_size = pad_to_factor(inp, pad_factor)
        return crop_to_size(_unwrap_output(model(padded)), orig_size)

    if tta:
        restored = [invert(_forward(view)) for view, invert in _tta_inputs(x)]
        return torch.stack(restored, dim=0).mean(dim=0)
    return _forward(x)


def batch_rgb_psnr(pred, target):
    """Return the mean per-item PSNR over the first dimension of a batch.

    Each ``(pred[i], target[i])`` pair is scored with
    ``skimage.metrics.peak_signal_noise_ratio`` using ``data_range=1.0``, so
    inputs are presumably scaled to [0, 1]; values are not clamped here.
    Iteration is over the first dimension, so an unbatched (C, H, W) input
    would be scored per channel -- assumes batch-first input, TODO confirm.

    Args:
        pred: Predicted tensor.
        target: Ground-truth tensor with exactly the same shape as ``pred``.

    Returns:
        Mean PSNR over the batch; ``0.0`` for an empty batch.
        NOTE(review): skimage can return ``inf`` for identical pairs, which
        would propagate to the mean -- confirm whether callers handle that.

    Raises:
        ValueError: If ``pred`` and ``target`` shapes differ. Previously a
            batch-size mismatch was silently truncated by ``zip`` while the
            sum was still divided by ``pred``'s batch size.
    """
    if pred.shape != target.shape:
        raise ValueError(
            "pred and target must have the same shape, got "
            f"{tuple(pred.shape)} and {tuple(target.shape)}"
        )

    def _to_numpy(t):
        t = t.detach()
        # NumPy has no bfloat16 dtype, so .numpy() would fail; upcast first.
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.cpu().numpy()

    pred_np = _to_numpy(pred)
    target_np = _to_numpy(target)
    total = 0.0
    for pred_img, target_img in zip(pred_np, target_np):
        total += peak_signal_noise_ratio(target_img, pred_img, data_range=1.0)
    return total / max(pred_np.shape[0], 1)