| import torch |
| import numpy as np |
| import skimage.metrics |
| import lpips |
| from PIL import Image |
| from .sifid import SIFID |
|
|
|
|
| def resize_array(x, size=256): |
| """ |
| Resize image array to given size. |
| Args: |
| x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255]. |
| size (int): Size of output image. |
| Returns: |
| (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255]. |
| """ |
| if x.shape[1] != size: |
| x = [Image.fromarray(x[i]).resize((size, size), resample=Image.BILINEAR) for i in range(x.shape[0])] |
| x = np.array([np.array(i) for i in x]) |
| return x |
|
|
|
|
| def resize_tensor(x, size=256): |
| """ |
| Resize image tensor to given size. |
| Args: |
| x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| size (int): Size of output image. |
| Returns: |
| (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| """ |
| if x.shape[2] != size: |
| x = torch.nn.functional.interpolate(x, size=(size, size), mode='bilinear', align_corners=False) |
| return x |
|
|
|
|
| def normalise(x): |
| """ |
| Normalise image array to range [-1, 1] and tensor. |
| Args: |
| x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255]. |
| Returns: |
| (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| """ |
| x = x.astype(np.float32) |
| x = x / 255 |
| x = (x - 0.5) / 0.5 |
| x = torch.from_numpy(x) |
| x = x.permute(0, 3, 1, 2) |
| return x |
|
|
|
|
| def unormalise(x, vrange=[-1, 1]): |
| """ |
| Unormalise image tensor to range [0, 255] and RGB array. |
| Args: |
| x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| Returns: |
| (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255]. |
| """ |
| x = (x - vrange[0])/(vrange[1] - vrange[0]) |
| x = x * 255 |
| x = x.permute(0, 2, 3, 1) |
| x = x.cpu().numpy().astype(np.uint8) |
| return x |
|
|
|
|
| def compute_mse(x, y): |
| """ |
| Compute mean squared error between two image arrays. |
| Args: |
| x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| Returns: |
| (1darray): Mean squared error. |
| """ |
| return np.square(x - y).reshape(x.shape[0], -1).mean(axis=1) |
|
|
|
|
| def compute_psnr(x, y): |
| """ |
| Compute peak signal-to-noise ratio between two images. |
| Args: |
| x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| Returns: |
| (float): Peak signal-to-noise ratio. |
| """ |
| return 10 * np.log10(255 ** 2 / compute_mse(x, y)) |
|
|
|
|
| def compute_ssim(x, y): |
| """ |
| Compute structural similarity index between two images. |
| Args: |
| x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255]. |
| Returns: |
| (float): Structural similarity index. |
| """ |
| return np.array([skimage.metrics.structural_similarity(xi, yi, channel_axis=2, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=255) for xi, yi in zip(x, y)]) |
|
|
|
|
| def compute_lpips(x, y, net='alex'): |
| """ |
| Compute LPIPS between two images. |
| Args: |
| x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| Returns: |
| (float): LPIPS. |
| """ |
| lpips_fn = lpips.LPIPS(net=net, verbose=False).cuda() if isinstance(net, str) else net |
| x, y = x.cuda(), y.cuda() |
| return lpips_fn(x, y).detach().cpu().numpy().squeeze() |
|
|
|
|
| def compute_sifid(x, y, net=None): |
| """ |
| Compute SIFID between two images. |
| Args: |
| x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1]. |
| Returns: |
| (float): SIFID. |
| """ |
| fn = SIFID() if net is None else net |
| out = [fn(xi, yi) for xi, yi in zip(x, y)] |
| return np.array(out) |