import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio


def pad_to_factor(x, factor=8):
    """Pad the last two dimensions of ``x`` up to a multiple of ``factor``.

    Padding is appended on the bottom and right only, so the original
    content stays anchored at the top-left corner and can be recovered with
    :func:`crop_to_size`.

    Args:
        x: Tensor whose last two dimensions are treated as height and width.
        factor: Alignment factor. Values ``<= 1`` disable padding.

    Returns:
        A ``(padded, (h, w))`` pair where ``(h, w)`` is the size of the last
        two dimensions of ``x`` *before* padding.
    """
    h, w = x.shape[-2:]
    if factor <= 1:
        # Report the true size (not a placeholder) so that a later
        # crop_to_size() is a no-op instead of producing an empty tensor.
        return x, (h, w)

    pad_h = -h % factor
    pad_w = -w % factor
    if pad_h == 0 and pad_w == 0:
        return x, (h, w)

    # Reflect padding needs the pad amount to be smaller than the padded
    # dimension; for very small inputs fall back to edge replication rather
    # than raising.
    mode = "reflect" if (pad_h < h and pad_w < w) else "replicate"
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), (h, w)


def crop_to_size(x, size):
    """Return the top-left ``size = (h, w)`` window of the last two dims of ``x``."""
    height, width = size
    return x[..., :height, :width]


def _tta_inputs(x):
    """Build the test-time-augmentation variants of ``x``.

    Returns a list of ``(augmented_input, inverse)`` pairs, where ``inverse``
    maps a tensor in the augmented orientation back to the orientation of
    ``x``. The variants are: identity, flip along the last dim, flip along
    the second-to-last dim, and a single 90-degree rotation over the last two
    dims (undone by three further 90-degree rotations).
    """
    return [
        (x, lambda y: y),
        (torch.flip(x, dims=[-1]), lambda y: torch.flip(y, dims=[-1])),
        (torch.flip(x, dims=[-2]), lambda y: torch.flip(y, dims=[-2])),
        (
            torch.rot90(x, k=1, dims=(-2, -1)),
            lambda y: torch.rot90(y, k=3, dims=(-2, -1)),
        ),
    ]


def _unwrap_output(output):
    """Return the first element when the model returns a tuple or list.

    Any other output is returned unchanged.
    """
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def _forward_cropped(model, x, pad_factor):
    """Pad ``x``, run ``model`` on it, and crop the result to ``x``'s size."""
    padded, size = pad_to_factor(x, pad_factor)
    return crop_to_size(_unwrap_output(model(padded)), size)


@torch.no_grad()
def run_model(model, x, pad_factor=8, tta=False):
    """Run ``model`` on ``x`` with size alignment and optional TTA.

    The input is padded so its last two dimensions are multiples of
    ``pad_factor``; the model output is cropped to the unpadded size of the
    input it was given. Gradients are disabled for the whole call.

    Args:
        model: Callable taking a tensor and returning a tensor, or a
            tuple/list whose first element is the tensor of interest.
        x: Input tensor; the last two dimensions are height and width.
        pad_factor: Alignment factor forwarded to :func:`pad_to_factor`.
        tta: When true, average the de-augmented outputs over the variants
            produced by :func:`_tta_inputs`.

    Returns:
        The cropped model output, or the mean over the TTA variants.
    """
    if not tta:
        return _forward_cropped(model, x, pad_factor)

    # Each variant is padded separately because the rotation swaps the
    # height and width of the input.
    restored = [
        undo(_forward_cropped(model, variant, pad_factor))
        for variant, undo in _tta_inputs(x)
    ]
    return torch.stack(restored, dim=0).mean(dim=0)


def batch_rgb_psnr(pred, target):
    """Average PSNR over the leading dimension of ``pred`` and ``target``.

    Both tensors are detached, moved to CPU and converted to NumPy; PSNR is
    computed per leading-dimension slice with ``data_range=1.0``.
    NOTE(review): the fixed ``data_range`` presumably means values are
    expected in ``[0, 1]`` -- confirm against callers.

    Slices are paired with ``zip`` (so extra slices in the longer input are
    ignored) and the sum is divided by ``pred``'s leading size, with a floor
    of 1 so that an empty batch yields ``0.0`` rather than dividing by zero.
    """
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    total = 0.0
    for pred_img, target_img in zip(pred_np, target_np):
        total += peak_signal_noise_ratio(target_img, pred_img, data_range=1.0)
    return total / max(pred_np.shape[0], 1)