Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
"""Image quality metrics for FSampler research."""
import torch
import numpy as np
from PIL import Image
def compute_metrics(image1, image2):
"""
Compute SSIM, RMSE, and MAE between two images.
Args:
image1: First image (torch.Tensor, numpy array, or PIL.Image)
image2: Second image (torch.Tensor, numpy array, or PIL.Image)
Returns:
dict: {"ssim": float, "rmse": float, "mae": float}
"""
# Convert both images to numpy arrays [H, W, C] in range [0, 1]
img1_np = _to_numpy(image1)
img2_np = _to_numpy(image2)
# Ensure same shape
if img1_np.shape != img2_np.shape:
raise ValueError(f"Images must have same shape. Got {img1_np.shape} and {img2_np.shape}")
# Compute metrics
ssim = compute_ssim(img1_np, img2_np)
rmse = compute_rmse(img1_np, img2_np)
mae = compute_mae(img1_np, img2_np)
return {
"ssim": float(ssim),
"rmse": float(rmse),
"mae": float(mae)
}
def _to_numpy(image):
"""Convert image to numpy array [H, W, C] in range [0, 1]."""
if isinstance(image, torch.Tensor):
# Handle torch tensor [B, H, W, C] or [H, W, C] or [B, C, H, W] or [C, H, W]
img = image.detach().cpu()
# Remove batch dimension if present
if img.dim() == 4:
img = img[0]
# Convert channels-first to channels-last if needed
if img.dim() == 3 and img.shape[0] in [1, 3, 4]:
img = img.permute(1, 2, 0)
img = img.numpy()
elif isinstance(image, Image.Image):
# PIL Image to numpy
img = np.array(image).astype(np.float32) / 255.0
elif isinstance(image, np.ndarray):
img = image.astype(np.float32)
# If in range [0, 255], normalize to [0, 1]
if img.max() > 1.0:
img = img / 255.0
else:
raise TypeError(f"Unsupported image type: {type(image)}")
# Ensure 3D array [H, W, C]
if img.ndim == 2:
img = np.expand_dims(img, axis=-1)
return img
def compute_mae(img1, img2):
"""Mean Absolute Error (average absolute pixel difference)."""
return np.mean(np.abs(img1 - img2))
def compute_rmse(img1, img2):
"""Root Mean Squared Error (root-mean-square pixel difference)."""
mse = np.mean((img1 - img2) ** 2)
return np.sqrt(mse)
def compute_ssim(img1, img2, window_size=11, k1=0.01, k2=0.03):
"""
Structural Similarity Index (SSIM).
Reference implementation based on the original SSIM paper:
Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
Image quality assessment: from error visibility to structural similarity.
IEEE transactions on image processing, 13(4), 600-612.
"""
# Convert to grayscale if color image
if img1.ndim == 3 and img1.shape[2] > 1:
# RGB to grayscale: 0.299*R + 0.587*G + 0.114*B
img1 = 0.299 * img1[:, :, 0] + 0.587 * img1[:, :, 1] + 0.114 * img1[:, :, 2]
img2 = 0.299 * img2[:, :, 0] + 0.587 * img2[:, :, 1] + 0.114 * img2[:, :, 2]
else:
img1 = img1.squeeze()
img2 = img2.squeeze()
# Constants
C1 = (k1 * 1.0) ** 2 # (k1*L)^2, L=1 for normalized images
C2 = (k2 * 1.0) ** 2 # (k2*L)^2
# Create Gaussian window
window = _gaussian_window(window_size, 1.5)
# Compute local means
mu1 = _convolve(img1, window)
mu2 = _convolve(img2, window)
# Compute local variances and covariance
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = _convolve(img1 ** 2, window) - mu1_sq
sigma2_sq = _convolve(img2 ** 2, window) - mu2_sq
sigma12 = _convolve(img1 * img2, window) - mu1_mu2
# SSIM formula
numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
ssim_map = numerator / denominator
# Return mean SSIM
return np.mean(ssim_map)
def _gaussian_window(size, sigma):
"""Create a 2D Gaussian window."""
coords = np.arange(size) - size // 2
g = np.exp(-(coords ** 2) / (2 * sigma ** 2))
g = g / g.sum()
return np.outer(g, g)
def _convolve(image, window):
"""Apply 2D convolution with the given window."""
from scipy.ndimage import convolve
return convolve(image, window, mode='constant', cval=0.0)