| import torch |
| import torch.nn.functional as F |
| from skimage.metrics import peak_signal_noise_ratio |
|
|
|
|
def pad_to_factor(x, factor=8):
    """Pad the last two dims of ``x`` up to the next multiple of ``factor``.

    Padding is added only at the end of each of the two trailing dims, using
    reflection. Reflection padding requires the pad amount to be smaller than
    the dim being padded, so for inputs too small for that (e.g. a 3-row input
    padded to a multiple of 8) replicate padding is used instead of raising.

    Args:
        x: Tensor whose last two dims are padded (presumably ``(N, C, H, W)``;
            confirm against callers).
        factor: Target multiple. Values ``<= 1`` disable padding.

    Returns:
        ``(padded, (h, w))`` where ``(h, w)`` is the original size of the last
        two dims, suitable for passing to ``crop_to_size``. When no padding is
        needed, ``padded`` is ``x`` itself.
    """
    h, w = x.shape[-2:]
    # Always report the real size: callers crop with it, and (0, 0) would
    # crop the model output down to an empty tensor.
    if factor <= 1:
        return x, (h, w)

    pad_h = (factor - h % factor) % factor
    pad_w = (factor - w % factor) % factor
    if pad_h == 0 and pad_w == 0:
        return x, (h, w)

    # F.pad(mode="reflect") rejects pads >= the padded dimension's size.
    mode = "reflect" if pad_h < h and pad_w < w else "replicate"
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), (h, w)
|
|
|
|
def crop_to_size(x, size):
    """Keep only the leading ``size[0]`` x ``size[1]`` region of the last two dims.

    Slicing clips at the tensor's bounds, so a ``size`` larger than ``x`` in a
    dim keeps that dim's full extent. The result is a view of ``x``.
    """
    rows, cols = size
    keep = (Ellipsis, slice(None, rows), slice(None, cols))
    return x[keep]
|
|
|
|
| def _tta_inputs(x): |
| return [ |
| (x, lambda y: y), |
| (torch.flip(x, dims=[-1]), lambda y: torch.flip(y, dims=[-1])), |
| (torch.flip(x, dims=[-2]), lambda y: torch.flip(y, dims=[-2])), |
| (torch.rot90(x, k=1, dims=(-2, -1)), lambda y: torch.rot90(y, k=3, dims=(-2, -1))), |
| ] |
|
|
|
|
| def _unwrap_output(output): |
| if isinstance(output, tuple): |
| return output[0] |
| return output |
|
|
|
|
@torch.no_grad()
def run_model(model, x, pad_factor=8, tta=False):
    """Run ``model`` on ``x`` with padding and optional test-time augmentation.

    The input is padded (via ``pad_to_factor``) so its last two dims are
    multiples of ``pad_factor``, the model is applied, the first element is
    taken if the model returns a tuple, and the output is cropped back to the
    unpadded size of the last two dims.

    Args:
        model: Callable applied to the padded input tensor.
        x: Input tensor whose last two dims are spatial.
        pad_factor: Spatial multiple to pad to; values <= 1 disable padding.
        tta: If True, run each view from ``_tta_inputs``, map every output
            back with its inverse transform, and return their mean.

    Returns:
        The model output (averaged over views when ``tta`` is set), cropped
        to ``x``'s size in the last two dims.
    """
    if not tta:
        return _forward_cropped(model, x, pad_factor)

    outputs = [
        deaug(_forward_cropped(model, aug_x, pad_factor))
        for aug_x, deaug in _tta_inputs(x)
    ]
    return torch.stack(outputs, dim=0).mean(dim=0)


def _forward_cropped(model, x, pad_factor):
    """Pad ``x``, run ``model`` on it, and crop the output back to ``x``'s size."""
    padded, _ = pad_to_factor(x, pad_factor)
    out = _unwrap_output(model(padded))
    # Crop size is read from x directly so the crop is correct even when
    # padding is disabled, independent of what pad_to_factor reports.
    return crop_to_size(out, tuple(x.shape[-2:]))
|
|
|
|
def batch_rgb_psnr(pred, target):
    """Return the mean per-image PSNR between two batches of tensors.

    Each item along the first dim of ``pred`` is compared with the matching
    item of ``target`` using ``peak_signal_noise_ratio`` with
    ``data_range=1.0`` (so values are presumably expected in [0, 1]; confirm
    with callers).

    Args:
        pred: Predicted batch; the first dim indexes images.
        target: Reference batch with the same first-dim length as ``pred``.

    Returns:
        The mean PSNR over the batch; 0.0 for an empty batch.

    Raises:
        ValueError: If ``pred`` and ``target`` have different batch sizes.
    """
    if pred.shape[0] != target.shape[0]:
        # zip() would silently truncate and the mean would be divided by the
        # wrong count, so reject mismatched batches explicitly.
        raise ValueError(
            f"batch size mismatch: pred has {pred.shape[0]} items, "
            f"target has {target.shape[0]}"
        )

    def to_numpy(t):
        t = t.detach().cpu()
        # NumPy has no bfloat16; upcasting to float32 is lossless.
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.numpy()

    pred_np = to_numpy(pred)
    target_np = to_numpy(target)
    total = sum(
        (
            peak_signal_noise_ratio(target_img, pred_img, data_range=1.0)
            for pred_img, target_img in zip(pred_np, target_np)
        ),
        0.0,
    )
    return total / max(pred_np.shape[0], 1)
|
|