Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Dict, Tuple | |
| from .utils import pil_to_np | |
| from skimage.metrics import structural_similarity as ssim | |
def calculate_mse(original: Image.Image, reconstructed: Image.Image) -> float:
    """
    Return the mean squared error between two images.

    Both images are turned into arrays with ``pil_to_np``. If the resulting
    arrays differ in shape, the reconstructed image is resampled with the
    Lanczos filter to the original's size and converted again before the
    comparison.

    Args:
        original: Reference PIL image.
        reconstructed: PIL image that is evaluated against the reference.

    Returns:
        Mean of the squared element-wise differences, as a built-in float.
    """
    reference = pil_to_np(original)
    candidate = pil_to_np(reconstructed)

    if reference.shape != candidate.shape:
        # Bring the reconstruction onto the reference's pixel grid.
        # NOTE(review): resizing only changes width/height; if the shapes
        # differ in channel count the subtraction below may still fail --
        # confirm what pil_to_np returns for non-RGB modes.
        candidate = pil_to_np(reconstructed.resize(original.size, Image.LANCZOS))

    residual = reference - candidate
    return float(np.mean(residual ** 2))
def calculate_psnr(original: Image.Image, reconstructed: Image.Image) -> float:
    """
    Return the peak signal-to-noise ratio of a reconstruction, in dB.

    The ratio is derived from ``calculate_mse`` using a peak signal value
    of 1.0 (this presumes pixel values scaled to [0, 1] -- confirm against
    ``pil_to_np``).

    Args:
        original: Reference PIL image.
        reconstructed: PIL image that is evaluated against the reference.

    Returns:
        PSNR in decibels, or ``inf`` when the mean squared error is zero.
    """
    error = calculate_mse(original, reconstructed)

    # Zero error would divide by zero below; report an unbounded ratio.
    if error == 0:
        return float('inf')

    peak_over_rmse = 1.0 / np.sqrt(error)
    return float(20 * np.log10(peak_over_rmse))
def calculate_ssim(original: Image.Image, reconstructed: Image.Image) -> float:
    """
    Return the structural similarity index of two images.

    Both images are converted with ``pil_to_np``; when the array shapes
    differ, the reconstructed image is Lanczos-resampled to the original's
    size first. Three-dimensional arrays are collapsed to a single plane by
    averaging over the last axis, and the score is computed on those planes
    with ``data_range=1.0``.

    Args:
        original: Reference PIL image.
        reconstructed: PIL image that is evaluated against the reference.

    Returns:
        The score produced by scikit-image's ``structural_similarity``,
        as a built-in float.
    """
    reference = pil_to_np(original)
    candidate = pil_to_np(reconstructed)

    if reference.shape != candidate.shape:
        # Bring the reconstruction onto the reference's pixel grid.
        candidate = pil_to_np(reconstructed.resize(original.size, Image.LANCZOS))

    if reference.ndim == 3:
        # Average the channel axis so SSIM runs on one 2-D plane per image.
        reference_plane = np.mean(reference, axis=2)
        candidate_plane = np.mean(candidate, axis=2)
    else:
        reference_plane, candidate_plane = reference, candidate

    return float(ssim(reference_plane, candidate_plane, data_range=1.0))
def calculate_color_similarity(original: Image.Image, reconstructed: Image.Image) -> Dict[str, float]:
    """
    Calculate color-based similarity metrics.

    Both images are converted with ``pil_to_np``; when the array shapes
    differ, the reconstructed image is Lanczos-resampled to the original's
    size. Arrays without a channel axis (2-D) are replicated to three
    channels so the per-channel pass below works for them as well.

    Args:
        original: Original PIL Image
        reconstructed: Reconstructed PIL Image

    Returns:
        Dictionary with the keys ``color_mse`` (mean of the three channel
        MSEs), ``red_channel_mse``, ``green_channel_mse``,
        ``blue_channel_mse`` and ``histogram_correlation`` (0.0 when the
        correlation is NaN). All values are built-in floats.
    """
    orig_array = pil_to_np(original)
    recon_array = pil_to_np(reconstructed)

    # Ensure same size
    if orig_array.shape != recon_array.shape:
        recon_pil = reconstructed.resize(original.size, Image.LANCZOS)
        recon_array = pil_to_np(recon_pil)

    # A 2-D array has no channel axis, so indexing [:, :, channel] would
    # raise IndexError. Replicating the plane three times keeps the
    # per-channel code valid and leaves the normalised histogram unchanged.
    if orig_array.ndim == 2:
        orig_array = np.repeat(orig_array[:, :, np.newaxis], 3, axis=2)
    if recon_array.ndim == 2:
        recon_array = np.repeat(recon_array[:, :, np.newaxis], 3, axis=2)

    # Per-channel mean squared error for the first three channels
    channel_diffs = [
        np.mean((orig_array[:, :, channel] - recon_array[:, :, channel]) ** 2)
        for channel in range(3)
    ]

    # Overall color difference is the average of the channel errors
    color_mse = np.mean(channel_diffs)

    # One 256-bin histogram per image, pooled over every array element and
    # restricted to the value range [0, 1]
    orig_hist = np.histogram(orig_array.flatten(), bins=256, range=(0, 1))[0]
    recon_hist = np.histogram(recon_array.flatten(), bins=256, range=(0, 1))[0]

    # Normalize histograms so they sum to one
    orig_hist = orig_hist / np.sum(orig_hist)
    recon_hist = recon_hist / np.sum(recon_hist)

    # Pearson correlation between the two normalised histograms
    hist_correlation = np.corrcoef(orig_hist, recon_hist)[0, 1]

    return {
        'color_mse': float(color_mse),
        'red_channel_mse': float(channel_diffs[0]),
        'green_channel_mse': float(channel_diffs[1]),
        'blue_channel_mse': float(channel_diffs[2]),
        # corrcoef yields NaN when it cannot be computed; report 0.0 instead
        'histogram_correlation': float(hist_correlation) if not np.isnan(hist_correlation) else 0.0,
    }
def calculate_comprehensive_metrics(original: Image.Image, reconstructed: Image.Image) -> Dict[str, float]:
    """
    Calculate comprehensive similarity metrics.

    Combines the results of ``calculate_mse``, ``calculate_psnr``,
    ``calculate_ssim``, ``calculate_color_similarity`` and
    ``calculate_mae``, and adds ``rmse`` as the square root of the MSE.

    Args:
        original: Original PIL Image
        reconstructed: Reconstructed PIL Image

    Returns:
        Dictionary with all similarity metrics; every value is a built-in
        float.
    """
    # Basic metrics
    metrics: Dict[str, float] = {
        'mse': calculate_mse(original, reconstructed),
        'psnr': calculate_psnr(original, reconstructed),
        'ssim': calculate_ssim(original, reconstructed),
    }

    # Color metrics
    metrics.update(calculate_color_similarity(original, reconstructed))

    # Additional derived metrics.
    # np.sqrt returns a NumPy scalar; convert it so rmse has the same type
    # as every other entry (and prints as a plain number).
    metrics['rmse'] = float(np.sqrt(metrics['mse']))
    metrics['mae'] = calculate_mae(original, reconstructed)

    return metrics
def calculate_mae(original: Image.Image, reconstructed: Image.Image) -> float:
    """
    Return the mean absolute error between two images.

    Both images are turned into arrays with ``pil_to_np``. If the resulting
    arrays differ in shape, the reconstructed image is resampled with the
    Lanczos filter to the original's size and converted again before the
    comparison.

    Args:
        original: Reference PIL image.
        reconstructed: PIL image that is evaluated against the reference.

    Returns:
        Mean of the absolute element-wise differences, as a built-in float.
    """
    reference = pil_to_np(original)
    candidate = pil_to_np(reconstructed)

    if reference.shape != candidate.shape:
        # Bring the reconstruction onto the reference's pixel grid.
        candidate = pil_to_np(reconstructed.resize(original.size, Image.LANCZOS))

    deviation = np.abs(reference - candidate)
    return float(np.mean(deviation))
def interpret_metrics(metrics: Dict[str, float]) -> Dict[str, str]:
    """
    Translate metric values into short human-readable verdicts.

    Only the ``mse``, ``psnr`` and ``ssim`` entries are read; a missing
    entry is treated as 0. MSE is graded with strict "less than" limits
    (lower is better), PSNR and SSIM with strict "greater than" limits
    (higher is better).

    Args:
        metrics: Dictionary of metric values.

    Returns:
        Dictionary with one verdict string for each of ``mse``, ``psnr``
        and ``ssim``.
    """
    # Each ladder is scanned top to bottom; the first rung that matches
    # wins, and the trailing default covers values beyond the last rung.
    mse_value = metrics.get('mse', 0)
    mse_ladder = (
        (0.01, "Excellent similarity"),
        (0.05, "Good similarity"),
        (0.1, "Moderate similarity"),
    )
    mse_verdict = next(
        (text for limit, text in mse_ladder if mse_value < limit),
        "Poor similarity",
    )

    psnr_value = metrics.get('psnr', 0)
    psnr_ladder = (
        (40, "Excellent quality"),
        (30, "Good quality"),
        (20, "Acceptable quality"),
    )
    psnr_verdict = next(
        (text for limit, text in psnr_ladder if psnr_value > limit),
        "Poor quality",
    )

    ssim_value = metrics.get('ssim', 0)
    ssim_ladder = (
        (0.9, "Very similar structure"),
        (0.7, "Similar structure"),
        (0.5, "Moderately similar structure"),
    )
    ssim_verdict = next(
        (text for limit, text in ssim_ladder if ssim_value > limit),
        "Different structure",
    )

    return {
        'mse': mse_verdict,
        'psnr': psnr_verdict,
        'ssim': ssim_verdict,
    }