|
|
"""FLIP metric functions""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class HDRFLIPLoss(nn.Module):
    """HDR-FLIP loss: LDR-FLIP evaluated at several exposures, max-pooled per pixel."""

    def __init__(self):
        """Store the FLIP pipeline constants and the exposure-range targets."""
        super().__init__()
        # Parameters of the LDR-FLIP color (qc, pc, pt) and feature (qf) pipelines.
        self.qc, self.qf, self.pc, self.pt = 0.7, 0.5, 0.4, 0.95
        # Tone-mapped target values used to derive the start / stop exposures.
        self.tmax, self.tmin = 0.85, 0.85
        # Small constant used for numerical stability during training.
        self.eps = 1e-15

    def forward(
        self,
        test,
        reference,
        pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180,
        tone_mapper="aces",
        start_exposure=None,
        stop_exposure=None,
    ):
        """
        Computes the mean HDR-FLIP error between batches of HDR images, assuming
        the images are observed at a given number of pixels per degree of visual angle.

        :param test: test tensor (NxCxHxW layout, nonnegative values)
        :param reference: reference tensor (NxCxHxW layout, nonnegative values)
        :param pixels_per_degree: float, pixels per degree of visual angle of the observer;
                                  the default corresponds to a 0.7 m wide 4K monitor viewed from 0.7 m
        :param tone_mapper: (optional) "reinhard", "hable" or, for anything else, ACES
        :param start_exposure: (optional) tensor (Nx1x1x1 layout) with per-pair start exposures;
                               computed from the reference when omitted
        :param stop_exposure: (optional) tensor (Nx1x1x1 layout) with per-pair stop exposures;
                              computed from the reference when omitted
        :return: scalar tensor holding the mean HDR-FLIP error (in [0, 1]) over the batch
        """
        # Keep inputs nonnegative and bounded before tone mapping.
        test = torch.clamp(test, 0, 65536.0)
        reference = torch.clamp(reference, 0, 65536.0)

        # Derive only the exposure bounds that the caller did not supply.
        if start_exposure is None or stop_exposure is None:
            auto_start, auto_stop = compute_start_stop_exposures(
                reference, tone_mapper, self.tmax, self.tmin
            )
            start_exposure = auto_start if start_exposure is None else start_exposure
            stop_exposure = auto_stop if stop_exposure is None else stop_exposure

        # At least two exposures per image; one exposure per stop of range otherwise.
        exposure_span = stop_exposure - start_exposure
        min_count = torch.tensor([2.0], requires_grad=False).cuda()
        num_exposures = torch.max(min_count, torch.ceil(exposure_span))
        most_exposures = int(torch.amax(num_exposures, dim=0).item())
        unit = torch.tensor([1.0], requires_grad=False).cuda()
        step_size = exposure_span / torch.max(num_exposures - 1, unit)

        # One error slice per exposure, sized by the largest per-image count.
        # NOTE(review): images whose own exposure count is smaller than
        # most_exposures are still evaluated at exposures beyond their stop
        # exposure (start + i * step_size keeps growing) — confirm intended.
        batch, _, height, width = reference.size()
        all_errors = torch.zeros(size=(batch, most_exposures, height, width)).cuda()

        for index in range(most_exposures):
            exposure = start_exposure + index * step_size
            ref_opponent, test_opponent = (
                color_space_transform(tone_map(img, tone_mapper, exposure), "linrgb2ycxcz")
                for img in (reference, test)
            )
            all_errors[:, index] = compute_ldrflip(
                test_opponent,
                ref_opponent,
                pixels_per_degree,
                self.qc,
                self.qf,
                self.pc,
                self.pt,
                self.eps,
            ).squeeze(1)

        # HDR-FLIP is the per-pixel maximum over all exposures.
        return torch.mean(torch.amax(all_errors, dim=1, keepdim=True))
|
|
|
|
|
|
|
|
class LDRFLIPLoss(nn.Module):
    """LDR-FLIP loss between sRGB-encoded images with values in [0, 1]."""

    def __init__(self):
        """Store the FLIP pipeline constants."""
        super().__init__()
        # Parameters of the LDR-FLIP color (qc, pc, pt) and feature (qf) pipelines.
        self.qc, self.qf, self.pc, self.pt = 0.7, 0.5, 0.4, 0.95
        # Small constant used for numerical stability during training.
        self.eps = 1e-15

    def forward(
        self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180
    ):
        """
        Computes the mean LDR-FLIP error between batches of LDR images, assuming
        the images are observed at a given number of pixels per degree of visual angle.

        :param test: test tensor (NxCxHxW layout, sRGB values in [0, 1])
        :param reference: reference tensor (NxCxHxW layout, sRGB values in [0, 1])
        :param pixels_per_degree: float, pixels per degree of visual angle of the observer;
                                  the default corresponds to a 0.7 m wide 4K monitor viewed from 0.7 m
        :return: scalar tensor holding the mean LDR-FLIP error (in [0, 1]) over the batch
        """
        # Clamp to the valid sRGB range, then move both images to YCxCz.
        ref_opponent, test_opponent = (
            color_space_transform(torch.clamp(img, 0, 1), "srgb2ycxcz")
            for img in (reference, test)
        )

        error_map = compute_ldrflip(
            test_opponent,
            ref_opponent,
            pixels_per_degree,
            self.qc,
            self.qf,
            self.pc,
            self.pt,
            self.eps,
        )
        return torch.mean(error_map)
|
|
|
|
|
|
|
|
def compute_ldrflip(test, reference, pixels_per_degree, qc, qf, pc, pt, eps):
    """
    Computes the per-pixel LDR-FLIP error map between two images in YCxCz.

    The color term is the spatially filtered, Hunt-adjusted HyAB distance,
    raised to qc and redistributed to [0, 1]. The feature term compares edge
    and point responses of the achromatic channel. The returned map is
    color_error ** (1 - feature_error).

    :param test: test tensor (NxCxHxW layout, values in the YCxCz color space)
    :param reference: reference tensor (NxCxHxW layout, values in the YCxCz color space)
    :param pixels_per_degree: float, pixels per degree of visual angle of the observer
    :param qc: float, q_c exponent of the color pipeline (see FLIP paper)
    :param qf: float, q_f exponent of the feature pipeline (see FLIP paper)
    :param pc: float, p_c multiplier of cmax in the color redistribution (see FLIP paper)
    :param pt: float, p_t target value of the color redistribution (see FLIP paper)
    :param eps: float, small value used to improve training stability
    :return: tensor of per-pixel FLIP errors (Nx1xHxW layout, values in [0, 1])
    """
    # ---- Color pipeline ----
    # One contrast-sensitivity kernel per opponent channel; pad with the largest radius.
    kernels_and_radii = [
        generate_spatial_filter(pixels_per_degree, name) for name in ("A", "RG", "BY")
    ]
    s_a, s_rg, s_by = [kernel for kernel, _ in kernels_and_radii]
    radius = max(r for _, r in kernels_and_radii)

    def to_hunt_lab(img):
        # Spatially filter in opponent space, then express in Hunt-adjusted L*a*b*.
        filtered = spatial_filter(img, s_a, s_rg, s_by, radius)
        return hunt_adjustment(color_space_transform(filtered, "linrgb2lab"))

    ref_lab = to_hunt_lab(reference)
    test_lab = to_hunt_lab(test)
    power_hyab = torch.pow(hyab(ref_lab, test_lab, eps), qc)

    def primary_lab(rgb):
        # Hunt-adjusted L*a*b* of a single linear-RGB color given as a 3x1x1 nested list.
        return hunt_adjustment(
            color_space_transform(torch.tensor(rgb).unsqueeze(0), "linrgb2lab")
        )

    # Exponentiated HyAB distance between pure green and pure blue, used as cmax.
    green = primary_lab([[[0.0]], [[1.0]], [[0.0]]])
    blue = primary_lab([[[0.0]], [[0.0]], [[1.0]]])
    cmax = torch.pow(hyab(green, blue, eps), qc).item()
    deltaE_c = redistribute_errors(power_hyab, cmax, pc, pt)

    # ---- Feature pipeline ----
    # Undo the YCxCz luminance encoding (116 * Y - 16) to get normalized Y.
    ref_y, test_y = ((img[:, 0:1, :, :] + 16) / 116 for img in (reference, test))

    def magnitude(img_y, kind):
        # Per-pixel L2 norm of the x/y feature responses.
        return torch.norm(
            feature_detection(img_y, pixels_per_degree, kind), dim=1, keepdim=True
        )

    edge_gap = torch.abs(magnitude(ref_y, "edge") - magnitude(test_y, "edge"))
    point_gap = torch.abs(magnitude(test_y, "point") - magnitude(ref_y, "point"))
    deltaE_f = torch.clamp(torch.max(edge_gap, point_gap), min=eps)
    deltaE_f = torch.pow((1 / np.sqrt(2)) * deltaE_f, qf)

    # ---- Final error ----
    return torch.pow(deltaE_c, 1 - deltaE_f)
|
|
|
|
|
|
|
|
def tone_map(img, tone_mapper, exposure):
    """
    Applies exposure compensation followed by tone mapping.

    Refer to the "Visualizing Errors in Rendered High Dynamic Range Images"
    paper for details about the formulas. All tone-curve constants are plain
    Python numbers, so the result lives on the same device as ``img``.

    :param img: float tensor (NxCxHxW layout) containing nonnegative values
    :param tone_mapper: "reinhard", "hable"; any other value selects the ACES approximation
    :param exposure: float, or float tensor (Nx1x1x1 layout), exposure compensation in stops
    :return: tone-mapped tensor (NxCxHxW layout) clamped to [0, 1]
    """
    # Exposure compensation: scale by 2**exposure.
    x = (2**exposure) * img

    if tone_mapper == "reinhard":
        luminance = (
            x[:, 0:1, :, :] * 0.2126
            + x[:, 1:2, :, :] * 0.7152
            + x[:, 2:3, :, :] * 0.0722
        )
        return torch.clamp(torch.div(x, 1 + luminance), 0.0, 1.0)

    # Both remaining curves are rational: (k0 x^2 + k1 x + k2) / (k3 x^2 + k4 x + k5).
    if tone_mapper == "hable":
        A, B, C, D, E, F = 0.15, 0.50, 0.10, 0.20, 0.02, 0.30
        k0 = A * F - A * E
        k1 = C * B * F - B * E
        k2 = 0
        k3 = A * F
        k4 = B * F
        k5 = D * F * F

        # Normalize so that the white point W maps to 1.
        W = 11.2
        white_scale = (k3 * W**2 + k4 * W + k5) / (k0 * W**2 + k1 * W + k2)

        # Fold the white scale and a 2x exposure bias (x -> 2x) into the coefficients.
        k0 = 4 * k0 * white_scale
        k1 = 2 * k1 * white_scale
        k2 = k2 * white_scale
        k3 = 4 * k3
        k4 = 2 * k4
    else:
        # ACES approximation, with a 0.6 pre-exposure factor folded into the constants.
        k0 = 0.6 * 0.6 * 2.51
        k1 = 0.6 * 0.03
        k2 = 0
        k3 = 0.6 * 0.6 * 2.43
        k4 = 0.6 * 0.59
        k5 = 0.14

    x2 = torch.pow(x, 2)
    nom = k0 * x2 + k1 * x + k2
    denom = k3 * x2 + k4 * x + k5
    # An infinite denominator implies an infinite numerator (inf/inf = nan);
    # replacing it by 1 yields inf, which the clamp below maps to 1.
    denom = torch.where(torch.isinf(denom), torch.ones_like(denom), denom)
    return torch.clamp(torch.div(nom, denom), 0.0, 1.0)
|
|
|
|
|
|
|
|
def compute_start_stop_exposures(reference, tone_mapper, tmax, tmin): |
|
|
""" |
|
|
Computes start and stop exposure for HDR-FLIP based on given tone mapper and reference image. |
|
|
Refer to the Visualizing Errors in Rendered High Dynamic Range Images |
|
|
paper for details about the formulas |
|
|
|
|
|
:param reference: float tensor (with NxCxHxW layout) containing reference images (nonnegative values) |
|
|
:param tone_mapper: string describing which tone mapper should be assumed |
|
|
:param tmax: float describing the t value used to find the start exposure |
|
|
:param tmin: float describing the t value used to find the stop exposure |
|
|
:return: two float tensors (with Nx1x1x1 layout) containing start and stop exposures, respectively, to use for HDR-FLIP |
|
|
""" |
|
|
if tone_mapper == "reinhard": |
|
|
k0 = 0 |
|
|
k1 = 1 |
|
|
k2 = 0 |
|
|
k3 = 0 |
|
|
k4 = 1 |
|
|
k5 = 1 |
|
|
|
|
|
x_max = tmax * k5 / (k1 - tmax * k4) |
|
|
x_min = tmin * k5 / (k1 - tmin * k4) |
|
|
elif tone_mapper == "hable": |
|
|
|
|
|
A = 0.15 |
|
|
B = 0.50 |
|
|
C = 0.10 |
|
|
D = 0.20 |
|
|
E = 0.02 |
|
|
F = 0.30 |
|
|
k0 = A * F - A * E |
|
|
k1 = C * B * F - B * E |
|
|
k2 = 0 |
|
|
k3 = A * F |
|
|
k4 = B * F |
|
|
k5 = D * F * F |
|
|
|
|
|
W = 11.2 |
|
|
nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2 |
|
|
denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5 |
|
|
white_scale = torch.div(denom, nom) |
|
|
|
|
|
|
|
|
k0 = 4 * k0 * white_scale |
|
|
k1 = 2 * k1 * white_scale |
|
|
k2 = k2 * white_scale |
|
|
k3 = 4 * k3 |
|
|
k4 = 2 * k4 |
|
|
|
|
|
|
|
|
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax) |
|
|
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax) |
|
|
x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1) |
|
|
|
|
|
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin) |
|
|
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin) |
|
|
x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1) |
|
|
else: |
|
|
|
|
|
|
|
|
k0 = 0.6 * 0.6 * 2.51 |
|
|
k1 = 0.6 * 0.03 |
|
|
k2 = 0 |
|
|
k3 = 0.6 * 0.6 * 2.43 |
|
|
k4 = 0.6 * 0.59 |
|
|
k5 = 0.14 |
|
|
|
|
|
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax) |
|
|
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax) |
|
|
x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1) |
|
|
|
|
|
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin) |
|
|
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin) |
|
|
x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1) |
|
|
|
|
|
|
|
|
lum_coeff_r = 0.2126 |
|
|
lum_coeff_g = 0.7152 |
|
|
lum_coeff_b = 0.0722 |
|
|
Y_reference = ( |
|
|
reference[:, 0:1, :, :] * lum_coeff_r |
|
|
+ reference[:, 1:2, :, :] * lum_coeff_g |
|
|
+ reference[:, 2:3, :, :] * lum_coeff_b |
|
|
) |
|
|
|
|
|
|
|
|
Y_hi = torch.amax(Y_reference, dim=(2, 3), keepdim=True) |
|
|
start_exposure = torch.log2(x_max / Y_hi) |
|
|
|
|
|
|
|
|
dim = Y_reference.size() |
|
|
Y_ref = Y_reference.view(dim[0], dim[1], dim[2] * dim[3]) |
|
|
Y_lo = torch.median(Y_ref, dim=2).values.unsqueeze(2).unsqueeze(3) |
|
|
stop_exposure = torch.log2(x_min / Y_lo) |
|
|
|
|
|
return start_exposure, stop_exposure |
|
|
|
|
|
|
|
|
def generate_spatial_filter(pixels_per_degree, channel, device="cuda"):
    """
    Generates a spatial contrast sensitivity filter for one opponent channel.

    The kernel is a sum of two Gaussians whose parameters depend on the
    channel. Its radius is derived from the largest scale parameter over all
    channels, so every channel gets a kernel of the same size.

    :param pixels_per_degree: float indicating number of pixels per degree of visual angle
    :param channel: "A" (achromatic), "RG" (red-green) or "BY" (blue-yellow)
    :param device: torch device to place the kernel on (defaults to "cuda")
    :return: tuple (kernel, r): kernel is a float tensor with 1x1x(2r+1)x(2r+1) layout
             whose weights sum to 1, r is the integer kernel radius in pixels
    :raises ValueError: if channel is not one of "A", "RG", "BY"
    """
    # (a1, b1, a2, b2) per channel. A zero a2 disables the second Gaussian;
    # its b2 stays nonzero because it appears in denominators below.
    csf_parameters = {
        "A": (1, 0.0047, 0, 1e-5),
        "RG": (1, 0.0053, 0, 1e-5),
        "BY": (34.1, 0.04, 13.5, 0.025),
    }
    if channel not in csf_parameters:
        raise ValueError(
            f"Unknown spatial filter channel {channel!r}; expected one of {sorted(csf_parameters)}"
        )
    a1, b1, a2, b2 = csf_parameters[channel]

    # Evaluation domain: shared radius from the widest Gaussian of any channel.
    max_scale_parameter = max(max(p[1], p[3]) for p in csf_parameters.values())
    r = int(
        np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree)
    )
    deltaX = 1.0 / pixels_per_degree
    x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1))
    z = (x * deltaX) ** 2 + (y * deltaX) ** 2

    # Sum of two Gaussians, normalized to unit sum.
    g = a1 * np.sqrt(np.pi / b1) * np.exp(-(np.pi**2) * z / b1) + a2 * np.sqrt(
        np.pi / b2
    ) * np.exp(-(np.pi**2) * z / b2)
    g = g / np.sum(g)
    g = torch.Tensor(g).unsqueeze(0).unsqueeze(0).to(device)

    return g, r
|
|
|
|
|
|
|
|
def spatial_filter(img, s_a, s_rg, s_by, radius):
    """
    Filters each opponent channel with its contrast sensitivity kernel, converts
    the result to linear RGB and clamps it to the unit cube.

    :param img: image tensor to filter (NxCxHxW layout, YCxCz color space)
    :param s_a: spatial filter kernel for the achromatic channel
    :param s_rg: spatial filter kernel for the red-green channel
    :param s_by: spatial filter kernel for the blue-yellow channel
    :param radius: int, number of pixels of replicate padding on each side
    :return: filtered image (NxCxHxW layout) in linear RGB, clamped to [0, 1]
    """
    n, c, h, w = img.size()
    pad = (radius, radius, radius, radius)

    # Buffers on the default CUDA device.
    img_pad = torch.zeros((n, c, h + 2 * radius, w + 2 * radius), device="cuda")
    img_tilde_opponent = torch.zeros((n, c, h, w), device="cuda")

    # All padding is written before any convolution reads img_pad; interleaving
    # the two would modify tensors autograd has already saved.
    for ch in range(3):
        img_pad[:, ch : ch + 1] = nn.functional.pad(
            img[:, ch : ch + 1], pad, mode="replicate"
        )
    for ch, kernel in enumerate((s_a, s_rg, s_by)):
        img_tilde_opponent[:, ch : ch + 1] = nn.functional.conv2d(
            img_pad[:, ch : ch + 1], kernel.cuda(), padding=0
        )

    # Clamp in linear RGB.
    linear_rgb = color_space_transform(img_tilde_opponent, "ycxcz2linrgb")
    return torch.clamp(linear_rgb, 0.0, 1.0)
|
|
|
|
|
|
|
|
def hunt_adjustment(img):
    """
    Applies the Hunt adjustment: the a* and b* channels are scaled by 0.01 * L*.

    The output is built from the input itself, so it keeps the input's
    device and dtype.

    :param img: image tensor to adjust (Nx3xHxW layout, L*a*b* color space)
    :return: Hunt-adjusted image tensor (Nx3xHxW layout, Hunt-adjusted L*A*B* color space)
    """
    L = img[:, 0:1, :, :]
    scale = 0.01 * L
    return torch.cat((L, scale * img[:, 1:2, :, :], scale * img[:, 2:3, :, :]), dim=1)
|
|
|
|
|
|
|
|
def hyab(reference, test, eps):
    """
    Computes the per-pixel HyAB distance: |dL| + ||(da, db)||_2.

    The lightness term is sqrt(max(dL^2, eps)) rather than a plain abs, so that
    its gradient stays finite where dL == 0.

    :param reference: reference image tensor (NxCxHxW layout, standard or Hunt-adjusted L*a*b*)
    :param test: test image tensor (NxCxHxW layout, standard or Hunt-adjusted L*a*b*)
    :param eps: float, lower bound applied to dL^2
    :return: tensor (Nx1xHxW layout) of per-pixel HyAB distances
    """
    diff = reference - test
    lightness_term = torch.clamp(diff[:, 0:1, :, :] ** 2, min=eps).sqrt()
    chroma_term = diff[:, 1:3, :, :].norm(dim=1, keepdim=True)
    return lightness_term + chroma_term
|
|
|
|
|
|
|
|
def redistribute_errors(power_deltaE_hyab, cmax, pc, pt):
    """
    Redistributes exponentiated HyAB errors to the [0, 1] range.

    Values below pc * cmax are mapped linearly to [0, pt). The remaining
    values are mapped linearly from [pc * cmax, cmax] to [pt, 1]. The result
    is computed directly from the input, so it lives on the input's device.

    :param power_deltaE_hyab: float tensor (Nx1xHxW layout) containing the exponentiated HyAB distance
    :param cmax: float, the exponentiated maximum HyAB difference between two colors in Hunt-adjusted L*A*B*
    :param pc: float, the cmax multiplier p_c (see FLIP paper)
    :param pt: float, the target value p_t for p_c * cmax (see FLIP paper)
    :return: tensor (Nx1xHxW layout) of redistributed per-pixel errors, in [0, 1] for inputs in [0, cmax]
    """
    pccmax = pc * cmax
    return torch.where(
        power_deltaE_hyab < pccmax,
        (pt / pccmax) * power_deltaE_hyab,
        pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt),
    )
|
|
|
|
|
|
|
|
def feature_detection(img_y, pixels_per_degree, feature_type):
    """
    Detects edges or points (features) in the achromatic image.

    A Gaussian-derivative kernel is built in x; its transpose is used for y.
    Positive weights are normalized to sum to 1 and negative weights to sum
    to -1. The kernel is placed on ``img_y``'s device with ``img_y``'s dtype.

    :param img_y: achromatic image tensor (Nx1xHxW layout, normalized Y values from YCxCz)
    :param pixels_per_degree: float, pixels per degree of visual angle of the observer
    :param feature_type: "edge" for first-derivative (edge) detection; any other value
                         selects second-derivative (point) detection
    :return: tensor (Nx2xHxW layout) with the x- and y-direction feature responses
    """
    # Filter width in degrees of visual angle; sd is half of it, in pixels.
    w = 0.082
    sd = 0.5 * w * pixels_per_degree
    radius = int(np.ceil(3 * sd))

    # 2D Gaussian on a (2*radius+1)^2 grid.
    x, y = np.meshgrid(range(-radius, radius + 1), range(-radius, radius + 1))
    g = np.exp(-(x**2 + y**2) / (2 * sd * sd))

    if feature_type == "edge":
        # First partial derivative in x.
        Gx = np.multiply(-x, g)
    else:
        # Second partial derivative in x.
        Gx = np.multiply(x**2 / (sd * sd) - 1, g)

    # Normalize positive weights to sum to 1 and negative weights to sum to -1.
    negative_weights_sum = -np.sum(Gx[Gx < 0])
    positive_weights_sum = np.sum(Gx[Gx > 0])
    Gx = torch.Tensor(Gx)
    Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum)
    Gx = Gx.unsqueeze(0).unsqueeze(0).to(device=img_y.device, dtype=img_y.dtype)

    # Pad once and reuse for both directions.
    padded = nn.functional.pad(img_y, (radius, radius, radius, radius), mode="replicate")
    featuresX = nn.functional.conv2d(padded, Gx, padding=0)
    featuresY = nn.functional.conv2d(padded, torch.transpose(Gx, 2, 3), padding=0)
    return torch.cat((featuresX, featuresY), dim=1)
|
|
|
|
|
|
|
|
def color_space_transform(input_color, fromSpace2toSpace):
    """
    Transforms a batch of images between color spaces.

    Elementary transforms: "srgb2linrgb", "linrgb2srgb", "linrgb2xyz",
    "xyz2linrgb", "xyz2ycxcz", "ycxcz2xyz", "xyz2lab", "lab2xyz".
    Composite transforms ("srgb2xyz", "srgb2ycxcz", "linrgb2ycxcz",
    "srgb2lab", "linrgb2lab", "ycxcz2linrgb", "lab2srgb", "ycxcz2lab")
    are evaluated as chains of elementary transforms.

    :param input_color: tensor of colors to transform (NxCxHxW layout with C == 3)
    :param fromSpace2toSpace: string naming the transform
    :return: transformed tensor (NxCxHxW layout)
    :raises ValueError: if fromSpace2toSpace is not a supported transform
    """
    composite_transforms = {
        "srgb2xyz": ("srgb2linrgb", "linrgb2xyz"),
        "srgb2ycxcz": ("srgb2linrgb", "linrgb2xyz", "xyz2ycxcz"),
        "linrgb2ycxcz": ("linrgb2xyz", "xyz2ycxcz"),
        "srgb2lab": ("srgb2linrgb", "linrgb2xyz", "xyz2lab"),
        "linrgb2lab": ("linrgb2xyz", "xyz2lab"),
        "ycxcz2linrgb": ("ycxcz2xyz", "xyz2linrgb"),
        "lab2srgb": ("lab2xyz", "xyz2linrgb", "linrgb2srgb"),
        "ycxcz2lab": ("ycxcz2xyz", "xyz2lab"),
    }
    if fromSpace2toSpace in composite_transforms:
        transformed_color = input_color
        for step in composite_transforms[fromSpace2toSpace]:
            transformed_color = color_space_transform(transformed_color, step)
        return transformed_color

    def _white(values):
        # Reference white (presumably D65) or its reciprocal as a 3x1x1 tensor that
        # broadcasts over NxCxHxW; created only when needed, on the input's device.
        return torch.tensor(values, device=input_color.device)

    if fromSpace2toSpace == "srgb2linrgb":
        limit = 0.04045
        # The clamp keeps the branch discarded by torch.where from seeing a negative base.
        return torch.where(
            input_color > limit,
            torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
            input_color / 12.92,
        )

    if fromSpace2toSpace == "linrgb2srgb":
        limit = 0.0031308
        return torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
            12.92 * input_color,
        )

    if fromSpace2toSpace in ("linrgb2xyz", "xyz2linrgb"):
        if fromSpace2toSpace == "linrgb2xyz":
            matrix = [
                [10135552 / 24577794, 8788810 / 24577794, 4435075 / 24577794],
                [2613072 / 12288897, 8788810 / 12288897, 887015 / 12288897],
                [1425312 / 73733382, 8788810 / 73733382, 70074185 / 73733382],
            ]
        else:
            # Approximate inverse of the linrgb2xyz matrix.
            matrix = [
                [3.241003275, -1.537398934, -0.498615861],
                [-0.969224334, 1.875930071, 0.041554224],
                [0.055639423, -0.204011202, 1.057148933],
            ]
        A = torch.Tensor(matrix).cuda()
        n, c, h, w = input_color.size()
        # NOTE: this step moves data to the default CUDA device, which the rest of
        # the pipeline relies on. reshape (not view) also accepts non-contiguous inputs.
        flat = input_color.reshape(n, c, h * w).cuda()
        return torch.matmul(A, flat).view(n, c, h, w)

    if fromSpace2toSpace == "xyz2ycxcz":
        scaled = torch.mul(input_color, _white([[[1.052156925]], [[1.000000000]], [[0.918357670]]]))
        y = 116 * scaled[:, 1:2, :, :] - 16
        cx = 500 * (scaled[:, 0:1, :, :] - scaled[:, 1:2, :, :])
        cz = 200 * (scaled[:, 1:2, :, :] - scaled[:, 2:3, :, :])
        return torch.cat((y, cx, cz), 1)

    if fromSpace2toSpace == "ycxcz2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200
        xyz = torch.cat((y + cx, y, y - cz), 1)
        return torch.mul(xyz, _white([[[0.950428545]], [[1.000000000]], [[1.088900371]]]))

    if fromSpace2toSpace == "xyz2lab":
        scaled = torch.mul(input_color, _white([[[1.052156925]], [[1.000000000]], [[0.918357670]]]))
        delta = 6 / 29
        delta_square = delta * delta
        delta_cube = delta * delta_square
        factor = 1 / (3 * delta_square)

        # The clamp keeps the cube root finite in the discarded branch.
        clamped_term = torch.pow(
            torch.clamp(scaled, min=delta_cube), 1.0 / 3.0
        ).to(dtype=scaled.dtype)
        div = (factor * scaled + (4 / 29)).to(dtype=scaled.dtype)
        f = torch.where(scaled > delta_cube, clamped_term, div)

        L = 116 * f[:, 1:2, :, :] - 16
        a = 500 * (f[:, 0:1, :, :] - f[:, 1:2, :, :])
        b = 200 * (f[:, 1:2, :, :] - f[:, 2:3, :, :])
        return torch.cat((L, a, b), 1)

    if fromSpace2toSpace == "lab2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200
        xyz = torch.cat((y + a, y, y - b), 1)

        delta = 6 / 29
        factor = 3 * delta * delta
        xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))
        return torch.mul(xyz, _white([[[0.950428545]], [[1.000000000]], [[1.088900371]]]))

    raise ValueError("Error: The color transform %s is not defined!" % fromSpace2toSpace)
|
|
|