Saumith devarsetty
Updated Lab5 modular code
3fffbdc
"""
metrics.py
Provides image similarity/quality metrics for the mosaic generator:
- Mean Squared Error (MSE)
- Structural Similarity Index (SSIM) averaged over RGB channels
"""
import numpy as np
from skimage.metrics import structural_similarity as ssim
def mse(a, b):
"""
Compute Mean Squared Error between two RGB images.
Parameters
----------
a : np.ndarray
First RGB image array (H, W, 3).
b : np.ndarray
Second RGB image array (H, W, 3).
Returns
-------
float
Scalar MSE value.
Raises
------
ValueError
If the input images are not the same shape or not valid RGB arrays.
"""
if a is None or b is None:
raise ValueError("mse(): both input images must be provided.")
if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
raise ValueError("mse(): inputs must be NumPy arrays.")
if a.shape != b.shape:
raise ValueError(
f"mse(): image size mismatch. Got {a.shape} vs {b.shape}."
)
if a.ndim != 3 or a.shape[2] != 3:
raise ValueError(f"mse(): expected RGB images, got shape {a.shape}.")
diff = a.astype(np.float32) - b.astype(np.float32)
return float(np.mean(diff ** 2))
def ssim_rgb(a, b, *, data_range=255, win_size=7):
    """
    Compute SSIM (Structural Similarity Index) for RGB images.

    SSIM is computed independently on each of the three channels with
    ``skimage.metrics.structural_similarity`` (Gaussian-weighted), and the
    three per-channel scores are averaged into a single score.

    Parameters
    ----------
    a : np.ndarray
        First RGB image array (H, W, 3).
    b : np.ndarray
        Second RGB image array (H, W, 3), same shape as ``a``.
    data_range : float, optional (keyword-only)
        Dynamic range of the pixel values (max possible - min possible).
        Defaults to 255, which suits 8-bit images; pass 1.0 for float images
        scaled to [0, 1].
    win_size : int, optional (keyword-only)
        Preferred odd sliding-window side length, default 7. If either
        spatial dimension of the images is smaller than this, the window is
        shrunk to the largest odd size that fits. Small tiles therefore get a
        score instead of an error.

    Returns
    -------
    float
        Mean SSIM across the 3 RGB channels.

    Raises
    ------
    ValueError
        If either input is None, is not a NumPy array, the shapes differ,
        the arrays are not shaped (H, W, 3), or the images are smaller than
        3x3 (too small for any SSIM window).
    """
    if a is None or b is None:
        raise ValueError("ssim_rgb(): both input images must be provided.")
    if not isinstance(a, np.ndarray) or not isinstance(b, np.ndarray):
        raise ValueError("ssim_rgb(): inputs must be NumPy arrays.")
    if a.shape != b.shape:
        raise ValueError(
            f"ssim_rgb(): image size mismatch. Got {a.shape} vs {b.shape}."
        )
    if a.ndim != 3 or a.shape[2] != 3:
        raise ValueError(f"ssim_rgb(): expected RGB images, got shape {a.shape}.")

    # skimage refuses a window larger than the image ("win_size exceeds
    # image extent"), which previously made every tile under 7 px crash.
    # Shrink the window to the largest odd size that fits; skimage requires
    # odd windows, and 3 is the smallest meaningful one.
    win = win_size
    smallest_side = min(a.shape[0], a.shape[1])
    if smallest_side < win:
        win = smallest_side if smallest_side % 2 == 1 else smallest_side - 1
        if win < 3:
            raise ValueError(
                "ssim_rgb(): images too small for SSIM (need at least 3x3), "
                f"got shape {a.shape}."
            )

    # Compute SSIM per channel. An explicit window (instead of skimage's
    # Gaussian default of 11) lets small mosaic tiles be scored.
    vals = [
        ssim(
            a[..., c],
            b[..., c],
            data_range=data_range,
            win_size=win,
            gaussian_weights=True,
        )
        for c in range(3)
    ]
    return float(sum(vals) / 3)