| """Image quality metrics for FSampler research.""" |
| import torch |
| import numpy as np |
| from PIL import Image |
|
|
|
|
def compute_metrics(image1, image2):
    """
    Compare two images and report SSIM, RMSE, and MAE.

    Both inputs are first normalised with ``_to_numpy``; the converted
    arrays must have exactly the same shape before any metric is computed.

    Args:
        image1: First image (torch.Tensor, numpy array, or PIL.Image)
        image2: Second image (torch.Tensor, numpy array, or PIL.Image)

    Returns:
        dict: {"ssim": float, "rmse": float, "mae": float}

    Raises:
        ValueError: If the two converted images differ in shape.
    """
    first, second = (_to_numpy(candidate) for candidate in (image1, image2))

    # Guard clause: every metric below compares the images element-wise.
    if first.shape != second.shape:
        raise ValueError(f"Images must have same shape. Got {first.shape} and {second.shape}")

    # Insertion order fixes both the evaluation order and the key order.
    scorers = {
        "ssim": compute_ssim,
        "rmse": compute_rmse,
        "mae": compute_mae,
    }

    # float() turns numpy scalars into plain Python floats.
    return {label: float(scorer(first, second)) for label, scorer in scorers.items()}
|
|
|
|
def _to_numpy(image):
    """Convert image to a float32 numpy array [H, W, C] in range [0, 1].

    Args:
        image: torch.Tensor, numpy array, or PIL.Image.
            * Tensors: a 4-D tensor is treated as a batch and only its first
              element is used. A 3-D tensor whose leading axis has size 1, 3
              or 4 is treated as channel-first and permuted to channel-last.
              Floating-point tensors are cast to float32 without rescaling;
              integer tensors are treated as 8-bit data and divided by 255.
            * numpy arrays: integer arrays are treated as 8-bit data and
              divided by 255. Floating-point arrays are divided by 255 only
              when their maximum exceeds 1.0.
            * PIL images: always divided by 255.

    Returns:
        np.ndarray: float32 array; a 2-D input gains a trailing channel axis.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, torch.Tensor):
        img = image.detach().cpu()

        # Batched input: keep only the first image of the batch.
        if img.dim() == 4:
            img = img[0]

        # Heuristic: a leading axis of size 1, 3 or 4 is taken to be the
        # channel axis (CHW) and is moved to the end (HWC).
        if img.dim() == 3 and img.shape[0] in (1, 3, 4):
            img = img.permute(1, 2, 0)

        if img.is_floating_point():
            # Cast before .numpy(): bfloat16 has no numpy equivalent, and
            # float32 matches what the other branches return.
            img = img.to(torch.float32).numpy()
        else:
            img = img.numpy()
            integer_pixels = np.issubdtype(img.dtype, np.integer)
            img = img.astype(np.float32)
            # Integer tensors hold 0..255 intensities, not [0, 1] values.
            if integer_pixels:
                img = img / 255.0
    elif isinstance(image, np.ndarray):
        # Decide by dtype first: an 8-bit image that is almost black may have
        # max() <= 1 and would otherwise be mistaken for [0, 1] data, giving
        # a different result than the same image passed as a PIL.Image.
        integer_pixels = np.issubdtype(image.dtype, np.integer)
        img = image.astype(np.float32)

        # Float arrays keep the value-based heuristic. The size check avoids
        # calling max() on an empty array, which raises.
        if integer_pixels or (img.size > 0 and img.max() > 1.0):
            img = img / 255.0
    elif isinstance(image, Image.Image):
        img = np.array(image).astype(np.float32) / 255.0
    else:
        raise TypeError(f"Unsupported image type: {type(image)}")

    # Grayscale without a channel axis: append one so callers always see
    # at least [H, W, C].
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)

    return img
|
|
|
|
def compute_mae(img1, img2):
    """Mean Absolute Error (average absolute pixel difference).

    Args:
        img1: First image as an array-like of pixel values.
        img2: Second image, broadcast-compatible with ``img1``.

    Returns:
        numpy scalar: mean of ``|img1 - img2|`` over all elements.
    """
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)

    # Integer subtraction wraps around (uint8: 0 - 1 == 255), so promote
    # integer/bool inputs to float first. Float inputs are left untouched so
    # their results are unchanged.
    if not np.issubdtype(img1.dtype, np.inexact):
        img1 = img1.astype(np.float64)
    if not np.issubdtype(img2.dtype, np.inexact):
        img2 = img2.astype(np.float64)

    return np.mean(np.abs(img1 - img2))
|
|
|
|
def compute_rmse(img1, img2):
    """Root Mean Squared Error (root-mean-square pixel difference).

    Args:
        img1: First image as an array-like of pixel values.
        img2: Second image, broadcast-compatible with ``img1``.

    Returns:
        numpy scalar: ``sqrt(mean((img1 - img2) ** 2))`` over all elements.
    """
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)

    # Integer arithmetic wraps on both the subtraction and the square
    # (uint8: 20 ** 2 == 144), so promote integer/bool inputs to float first.
    # Float inputs are left untouched so their results are unchanged.
    if not np.issubdtype(img1.dtype, np.inexact):
        img1 = img1.astype(np.float64)
    if not np.issubdtype(img2.dtype, np.inexact):
        img2 = img2.astype(np.float64)

    mse = np.mean((img1 - img2) ** 2)
    return np.sqrt(mse)
|
|
|
|
def _ssim_luminance(img):
    """Reduce an image to a 2-D float64 luminance plane for SSIM."""
    # float64 keeps the E[x^2] - E[x]^2 variance estimate from losing
    # precision and stops integer inputs from overflowing when squared.
    img = np.asarray(img, dtype=np.float64)

    if img.ndim == 3 and img.shape[2] >= 3:
        # ITU-R BT.601 luma weights; channels after the third (e.g. alpha)
        # are ignored.
        return 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]

    if img.ndim == 3 and img.shape[2] >= 1:
        # One or two channels: use the first channel. Indexing (rather than
        # squeeze) keeps a genuine length-1 height or width intact.
        return img[:, :, 0]

    return img.squeeze()


def compute_ssim(img1, img2, window_size=11, k1=0.01, k2=0.03):
    """
    Structural Similarity Index (SSIM).

    Reference implementation based on the original SSIM paper:
    Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
    Image quality assessment: from error visibility to structural similarity.
    IEEE transactions on image processing, 13(4), 600-612.

    Colour images are reduced to a single luminance plane before comparison.
    Local statistics are gathered with a Gaussian window (sigma 1.5) built by
    ``_gaussian_window`` and applied with ``_convolve``.

    Args:
        img1: First image, [H, W] or [H, W, C], values expected in [0, 1].
        img2: Second image with the same spatial size as ``img1``.
        window_size: Side length of the Gaussian window.
        k1: Stabilising constant for the luminance term.
        k2: Stabilising constant for the contrast/structure term.

    Returns:
        numpy scalar: mean of the per-pixel SSIM map.

    Raises:
        ValueError: If the two luminance planes differ in shape.
    """
    # Each image is reduced on its own, so a colour image can be compared
    # with a single-channel one of the same size.
    img1 = _ssim_luminance(img1)
    img2 = _ssim_luminance(img2)

    if img1.shape != img2.shape:
        raise ValueError(f"Images must have same shape. Got {img1.shape} and {img2.shape}")

    # Stabilisers (k * L) ** 2 with dynamic range L = 1.0, i.e. the inputs
    # are assumed to be scaled to [0, 1].
    C1 = (k1 * 1.0) ** 2
    C2 = (k2 * 1.0) ** 2

    window = _gaussian_window(window_size, 1.5)

    # Local (windowed) means.
    mu1 = _convolve(img1, window)
    mu2 = _convolve(img2, window)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance via E[xy] - E[x]E[y].
    sigma1_sq = _convolve(img1 ** 2, window) - mu1_sq
    sigma2_sq = _convolve(img2 ** 2, window) - mu2_sq
    sigma12 = _convolve(img1 * img2, window) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)

    ssim_map = numerator / denominator

    return np.mean(ssim_map)
|
|
|
|
def _gaussian_window(size, sigma):
    """Create a normalised 2D Gaussian window.

    Args:
        size: Side length of the square window (positive integer).
        sigma: Standard deviation of the Gaussian, in pixels (positive).

    Returns:
        np.ndarray: ``(size, size)`` array whose entries sum to 1.

    Raises:
        ValueError: If ``size`` is less than 1 or ``sigma`` is not positive.
    """
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Centre on (size - 1) / 2 so the window is symmetric for even sizes too;
    # for odd sizes this equals size // 2, so those windows are unchanged.
    coords = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()

    # The outer product of the normalised 1-D profile also sums to 1.
    return np.outer(g, g)
|
|
|
|
def _convolve(image, window):
    """Filter ``image`` with ``window`` using SciPy's N-d convolution.

    The output has the same shape as ``image``. Samples outside the image are
    treated as 0.0 (``mode="constant"``), so positions near the border only
    see the part of the window that overlaps the image.
    """
    # Imported lazily so SciPy is only required when a convolution runs.
    import scipy.ndimage as ndimage

    filtered = ndimage.convolve(image, window, mode="constant", cval=0.0)
    return filtered
|
|