# Source path: llir/utils/inference_utils.py
# Uploaded by linxin02: "Upload portable Low_light_rainy_new code export" (commit 4336727, verified)
import torch
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio
def pad_to_factor(x, factor=8):
    """Pad the last two dimensions of ``x`` up to a multiple of ``factor``.

    Padding is appended only after the existing content of each of the two
    trailing dimensions, so the original values stay at the start of each
    dimension and can be recovered by slicing with the returned size.

    Args:
        x: Tensor whose last two dimensions are padded.
        factor: Required divisor of the padded dimensions. Values ``<= 1``
            disable padding.

    Returns:
        A ``(padded, (h, w))`` tuple where ``(h, w)`` are the sizes of the
        last two dimensions of ``x`` *before* padding.
    """
    h, w = x.shape[-2:]
    if factor <= 1:
        # Report the real size rather than (0, 0): callers slice the model
        # output back to this size, and a zero size would discard everything.
        return x, (h, w)
    pad_h = (factor - h % factor) % factor
    pad_w = (factor - w % factor) % factor
    if pad_h == 0 and pad_w == 0:
        return x, (h, w)
    # Reflect padding requires each pad amount to be smaller than the
    # dimension it extends; fall back to edge replication for tiny inputs.
    mode = "reflect" if pad_h < h and pad_w < w else "replicate"
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), (h, w)
def crop_to_size(x, size):
    """Keep the leading ``size`` entries of the last two dimensions of ``x``.

    Args:
        x: Indexable object (e.g. a tensor) with at least two dimensions.
        size: Two-element iterable ``(h, w)`` giving how many entries to keep
            along the second-to-last and last dimensions respectively.

    Returns:
        ``x`` sliced to ``[..., :h, :w]``; all leading dimensions are kept.
    """
    rows, cols = size
    window = (Ellipsis, slice(None, rows), slice(None, cols))
    return x[window]
def _tta_inputs(x):
return [
(x, lambda y: y),
(torch.flip(x, dims=[-1]), lambda y: torch.flip(y, dims=[-1])),
(torch.flip(x, dims=[-2]), lambda y: torch.flip(y, dims=[-2])),
(torch.rot90(x, k=1, dims=(-2, -1)), lambda y: torch.rot90(y, k=3, dims=(-2, -1))),
]
def _unwrap_output(output):
if isinstance(output, tuple):
return output[0]
return output
@torch.no_grad()
def run_model(model, x, pad_factor=8, tta=False):
    """Run ``model`` on ``x`` with padding and optional test-time augmentation.

    Each input is padded with ``pad_to_factor``, passed through ``model``,
    unwrapped with ``_unwrap_output`` and cropped back with ``crop_to_size``
    using the size reported by ``pad_to_factor``. Gradient tracking is
    disabled for the whole call.

    Args:
        model: Callable applied to the padded input.
        x: Input tensor.
        pad_factor: Factor forwarded to ``pad_to_factor``.
        tta: When true, the model is run on every variant produced by
            ``_tta_inputs`` and the de-augmented outputs are averaged.

    Returns:
        The cropped model output, or the mean of the de-augmented outputs
        when ``tta`` is enabled.
    """

    def _infer(inp):
        # Single pad -> forward -> unwrap -> crop pass shared by both modes.
        padded_inp, original_hw = pad_to_factor(inp, pad_factor)
        prediction = _unwrap_output(model(padded_inp))
        return crop_to_size(prediction, original_hw)

    if not tta:
        return _infer(x)

    restored = [undo(_infer(view)) for view, undo in _tta_inputs(x)]
    return torch.stack(restored, dim=0).mean(dim=0)
def batch_rgb_psnr(pred, target):
    """Return the mean PSNR over the leading dimension of ``pred``/``target``.

    Both tensors are detached, moved to the CPU and converted to NumPy, then
    paired item by item along their first dimension. PSNR is computed with
    ``data_range=1.0`` (presumably the inputs are scaled to [0, 1]; confirm
    against the caller).

    The sum is divided by ``pred``'s first-dimension size (at least 1), so an
    empty batch yields ``0.0``.
    """
    preds = pred.detach().cpu().numpy()
    targets = target.detach().cpu().numpy()
    psnr_sum = sum(
        (
            peak_signal_noise_ratio(reference, estimate, data_range=1.0)
            for estimate, reference in zip(preds, targets)
        ),
        0.0,
    )
    batch_size = max(preds.shape[0], 1)
    return psnr_sum / batch_size