"""FLIP metric functions"""
#################################################################################
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-FileCopyrightText: Copyright (c) 2020-2024 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: BSD-3-Clause
#################################################################################
# Visualizing and Communicating Errors in Rendered Images
# Ray Tracing Gems II, 2021,
# by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller.
# Pointer to the chapter: https://research.nvidia.com/publication/2021-08_Visualizing-and-Communicating.
# Visualizing Errors in Rendered High Dynamic Range Images
# Eurographics 2021,
# by Pontus Andersson, Jim Nilsson, Peter Shirley, and Tomas Akenine-Moller.
# Pointer to the paper: https://research.nvidia.com/publication/2021-05_HDR-FLIP.
# FLIP: A Difference Evaluator for Alternating Images
# High Performance Graphics 2020,
# by Pontus Andersson, Jim Nilsson, Tomas Akenine-Moller,
# Magnus Oskarsson, Kalle Astrom, and Mark D. Fairchild.
# Pointer to the paper: https://research.nvidia.com/publication/2020-07_FLIP.
# Code by Pontus Ebelin (formerly Andersson), Jim Nilsson, and Tomas Akenine-Moller.
import sys
import numpy as np
import torch
import torch.nn as nn
class HDRFLIPLoss(nn.Module):
"""Class for computing HDR-FLIP"""
def __init__(self):
"""Init"""
super().__init__()
self.qc = 0.7
self.qf = 0.5
self.pc = 0.4
self.pt = 0.95
self.tmax = 0.85
self.tmin = 0.85
self.eps = 1e-15
def forward(
self,
test,
reference,
pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180,
tone_mapper="aces",
start_exposure=None,
stop_exposure=None,
):
"""
Computes the HDR-FLIP error map between two HDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param test: test tensor (with NxCxHxW layout with nonnegative values)
:param reference: reference tensor (with NxCxHxW layout with nonnegative values)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:param tone_mapper: (optional) string describing what tone mapper HDR-FLIP should assume
        :param start_exposure: (optional) tensor (with Nx1x1x1 layout) with start exposures corresponding to each HDR reference/test pair
:param stop_exposure: (optional) tensor (with Nx1x1x1 layout) with stop exposures corresponding to each HDR reference/test pair
:return: float containing the mean FLIP error (in the range [0,1]) between the HDR reference and test images in the batch
"""
# HDR-FLIP expects nonnegative and non-NaN values in the input
reference = torch.clamp(reference, 0, 65536.0)
test = torch.clamp(test, 0, 65536.0)
# Compute start and stop exposures, if they are not given
if start_exposure is None or stop_exposure is None:
c_start, c_stop = compute_start_stop_exposures(
reference, tone_mapper, self.tmax, self.tmin
)
if start_exposure is None:
start_exposure = c_start
if stop_exposure is None:
stop_exposure = c_stop
# Compute number of exposures
num_exposures = torch.max(
torch.tensor([2.0], requires_grad=False).cuda(),
torch.ceil(stop_exposure - start_exposure),
)
most_exposures = int(torch.amax(num_exposures, dim=0).item())
# Compute exposure step size
step_size = (stop_exposure - start_exposure) / torch.max(
num_exposures - 1, torch.tensor([1.0], requires_grad=False).cuda()
)
# Set the depth of the error tensor to the number of exposures given by the largest exposure range any reference image yielded.
# This allows us to do one loop for each image in our batch, while not affecting the HDR-FLIP error, as we fill up the error tensor with 0s.
# Note that the step size still depends on num_exposures and is therefore independent of most_exposures
dim = reference.size()
all_errors = torch.zeros(size=(dim[0], most_exposures, dim[2], dim[3])).cuda()
# Loop over exposures and compute LDR-FLIP for each pair of LDR reference and test
for i in range(0, most_exposures):
exposure = start_exposure + i * step_size
reference_tone_mapped = tone_map(reference, tone_mapper, exposure)
test_tone_mapped = tone_map(test, tone_mapper, exposure)
reference_opponent = color_space_transform(
reference_tone_mapped, "linrgb2ycxcz"
)
test_opponent = color_space_transform(test_tone_mapped, "linrgb2ycxcz")
all_errors[:, i, :, :] = compute_ldrflip(
test_opponent,
reference_opponent,
pixels_per_degree,
self.qc,
self.qf,
self.pc,
self.pt,
self.eps,
).squeeze(1)
# Take per-pixel maximum over all LDR-FLIP errors to get HDR-FLIP
hdrflip_error = torch.amax(all_errors, dim=1, keepdim=True)
return torch.mean(hdrflip_error)
class LDRFLIPLoss(nn.Module):
"""Class for computing LDR FLIP loss"""
def __init__(self):
"""Init"""
super().__init__()
self.qc = 0.7
self.qf = 0.5
self.pc = 0.4
self.pt = 0.95
self.eps = 1e-15
def forward(
self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180
):
"""
Computes the LDR-FLIP error map between two LDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param test: test tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
:param reference: reference tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:return: float containing the mean FLIP error (in the range [0,1]) between the LDR reference and test images in the batch
"""
# LDR-FLIP expects non-NaN values in [0,1] as input
reference = torch.clamp(reference, 0, 1)
test = torch.clamp(test, 0, 1)
# Transform reference and test to opponent color space
reference_opponent = color_space_transform(reference, "srgb2ycxcz")
test_opponent = color_space_transform(test, "srgb2ycxcz")
deltaE = compute_ldrflip(
test_opponent,
reference_opponent,
pixels_per_degree,
self.qc,
self.qf,
self.pc,
self.pt,
self.eps,
)
return torch.mean(deltaE)
def compute_ldrflip(test, reference, pixels_per_degree, qc, qf, pc, pt, eps):
"""
Computes the LDR-FLIP error map between two LDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param reference: reference tensor (with NxCxHxW layout with values in the YCxCz color space)
:param test: test tensor (with NxCxHxW layout with values in the YCxCz color space)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:param qc: float describing the q_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param qf: float describing the q_f exponent in the LDR-FLIP feature pipeline (see FLIP paper for details)
:param pc: float describing the p_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param pt: float describing the p_t exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param eps: float containing a small value used to improve training stability
:return: tensor containing the per-pixel FLIP errors (with Nx1xHxW layout and values in the range [0, 1]) between LDR reference and test images
"""
# --- Color pipeline ---
# Spatial filtering
s_a, radius_a = generate_spatial_filter(pixels_per_degree, "A")
s_rg, radius_rg = generate_spatial_filter(pixels_per_degree, "RG")
s_by, radius_by = generate_spatial_filter(pixels_per_degree, "BY")
radius = max(radius_a, radius_rg, radius_by)
filtered_reference = spatial_filter(reference, s_a, s_rg, s_by, radius)
filtered_test = spatial_filter(test, s_a, s_rg, s_by, radius)
# Perceptually Uniform Color Space
preprocessed_reference = hunt_adjustment(
color_space_transform(filtered_reference, "linrgb2lab")
)
preprocessed_test = hunt_adjustment(
color_space_transform(filtered_test, "linrgb2lab")
)
# Color metric
deltaE_hyab = hyab(preprocessed_reference, preprocessed_test, eps)
power_deltaE_hyab = torch.pow(deltaE_hyab, qc)
hunt_adjusted_green = hunt_adjustment(
color_space_transform(
torch.tensor([[[0.0]], [[1.0]], [[0.0]]]).unsqueeze(0), "linrgb2lab"
)
)
hunt_adjusted_blue = hunt_adjustment(
color_space_transform(
torch.tensor([[[0.0]], [[0.0]], [[1.0]]]).unsqueeze(0), "linrgb2lab"
)
)
cmax = torch.pow(hyab(hunt_adjusted_green, hunt_adjusted_blue, eps), qc).item()
deltaE_c = redistribute_errors(power_deltaE_hyab, cmax, pc, pt)
# --- Feature pipeline ---
# Extract and normalize Yy component
ref_y = (reference[:, 0:1, :, :] + 16) / 116
test_y = (test[:, 0:1, :, :] + 16) / 116
# Edge and point detection
edges_reference = feature_detection(ref_y, pixels_per_degree, "edge")
points_reference = feature_detection(ref_y, pixels_per_degree, "point")
edges_test = feature_detection(test_y, pixels_per_degree, "edge")
points_test = feature_detection(test_y, pixels_per_degree, "point")
# Feature metric
deltaE_f = torch.max(
torch.abs(
torch.norm(edges_reference, dim=1, keepdim=True)
- torch.norm(edges_test, dim=1, keepdim=True)
),
torch.abs(
torch.norm(points_test, dim=1, keepdim=True)
- torch.norm(points_reference, dim=1, keepdim=True)
),
)
deltaE_f = torch.clamp(deltaE_f, min=eps) # clamp to stabilize training
deltaE_f = torch.pow(((1 / np.sqrt(2)) * deltaE_f), qf)
# --- Final error ---
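    # FLIP error per pixel: deltaE_c^(1 - deltaE_f). With no feature difference
    # (deltaE_f = 0) the error equals the color error; strong feature differences
    # drive the exponent toward 0 and push the error toward 1.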
return torch.pow(deltaE_c, 1 - deltaE_f)
def tone_map(img, tone_mapper, exposure):
"""
Applies exposure compensation and tone mapping.
Refer to the Visualizing Errors in Rendered High Dynamic Range Images
paper for details about the formulas.
:param img: float tensor (with NxCxHxW layout) containing nonnegative values
:param tone_mapper: string describing the tone mapper to apply
:param exposure: float tensor (with Nx1x1x1 layout) describing the exposure compensation factor
"""
# Exposure compensation
x = (2**exposure) * img
# Set tone mapping coefficients depending on tone_mapper
if tone_mapper == "reinhard":
lum_coeff_r = 0.2126
lum_coeff_g = 0.7152
lum_coeff_b = 0.0722
Y = (
x[:, 0:1, :, :] * lum_coeff_r
+ x[:, 1:2, :, :] * lum_coeff_g
+ x[:, 2:3, :, :] * lum_coeff_b
)
return torch.clamp(torch.div(x, 1 + Y), 0.0, 1.0)
if tone_mapper == "hable":
# Source: https://64.github.io/tonemapping/
A = 0.15
B = 0.50
C = 0.10
D = 0.20
E = 0.02
F = 0.30
k0 = A * F - A * E
k1 = C * B * F - B * E
k2 = 0
k3 = A * F
k4 = B * F
k5 = D * F * F
W = 11.2
nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
# Include white scale and exposure bias in rational polynomial coefficients
k0 = 4 * k0 * white_scale
k1 = 2 * k1 * white_scale
k2 = k2 * white_scale
k3 = 4 * k3
k4 = 2 * k4
# k5 = k5 # k5 is not changed
else:
# Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
# Include pre-exposure cancelation in constants
k0 = 0.6 * 0.6 * 2.51
k1 = 0.6 * 0.03
k2 = 0
k3 = 0.6 * 0.6 * 2.43
k4 = 0.6 * 0.59
k5 = 0.14
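    # Both the Hable and ACES branches are evaluated below as the rational
    # polynomial y = (k0*x^2 + k1*x + k2) / (k3*x^2 + k4*x + k5), applied per
    # channel and clamped to [0, 1]; the Reinhard branch returned early above.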
x2 = torch.pow(x, 2)
nom = k0 * x2 + k1 * x + k2
denom = k3 * x2 + k4 * x + k5
    denom = torch.where(
        torch.isinf(denom), torch.Tensor([1.0]).cuda(), denom
    )  # an inf denom implies an inf nom (nan ratio) for a very bright pixel; setting denom to 1 makes y inf, which the clamp below maps to 1
y = torch.div(nom, denom)
return torch.clamp(y, 0.0, 1.0)
def compute_start_stop_exposures(reference, tone_mapper, tmax, tmin):
"""
Computes start and stop exposure for HDR-FLIP based on given tone mapper and reference image.
Refer to the Visualizing Errors in Rendered High Dynamic Range Images
paper for details about the formulas
:param reference: float tensor (with NxCxHxW layout) containing reference images (nonnegative values)
:param tone_mapper: string describing which tone mapper should be assumed
:param tmax: float describing the t value used to find the start exposure
:param tmin: float describing the t value used to find the stop exposure
:return: two float tensors (with Nx1x1x1 layout) containing start and stop exposures, respectively, to use for HDR-FLIP
"""
if tone_mapper == "reinhard":
k0 = 0
k1 = 1
k2 = 0
k3 = 0
k4 = 1
k5 = 1
x_max = tmax * k5 / (k1 - tmax * k4)
x_min = tmin * k5 / (k1 - tmin * k4)
elif tone_mapper == "hable":
# Source: https://64.github.io/tonemapping/
A = 0.15
B = 0.50
C = 0.10
D = 0.20
E = 0.02
F = 0.30
k0 = A * F - A * E
k1 = C * B * F - B * E
k2 = 0
k3 = A * F
k4 = B * F
k5 = D * F * F
W = 11.2
nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
# Include white scale and exposure bias in rational polynomial coefficients
k0 = 4 * k0 * white_scale
k1 = 2 * k1 * white_scale
k2 = k2 * white_scale
k3 = 4 * k3
k4 = 2 * k4
# k5 = k5 # k5 is not changed
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
else:
# Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
# Include pre-exposure cancelation in constants
k0 = 0.6 * 0.6 * 2.51
k1 = 0.6 * 0.03
k2 = 0
k3 = 0.6 * 0.6 * 2.43
k4 = 0.6 * 0.59
k5 = 0.14
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
x_max = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
x_min = -0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
# Convert reference to luminance
lum_coeff_r = 0.2126
lum_coeff_g = 0.7152
lum_coeff_b = 0.0722
Y_reference = (
reference[:, 0:1, :, :] * lum_coeff_r
+ reference[:, 1:2, :, :] * lum_coeff_g
+ reference[:, 2:3, :, :] * lum_coeff_b
)
# Compute start exposure
Y_hi = torch.amax(Y_reference, dim=(2, 3), keepdim=True)
start_exposure = torch.log2(x_max / Y_hi)
# Compute stop exposure
dim = Y_reference.size()
Y_ref = Y_reference.view(dim[0], dim[1], dim[2] * dim[3])
Y_lo = torch.median(Y_ref, dim=2).values.unsqueeze(2).unsqueeze(3)
stop_exposure = torch.log2(x_min / Y_lo)
return start_exposure, stop_exposure
def generate_spatial_filter(pixels_per_degree, channel):
"""
Generates spatial contrast sensitivity filters with width depending on
the number of pixels per degree of visual angle of the observer
:param pixels_per_degree: float indicating number of pixels per degree of visual angle
:param channel: string describing what filter should be generated
    :return: filter kernel corresponding to the spatial contrast sensitivity function of the given channel, together with the kernel's radius
"""
a1_A = 1
b1_A = 0.0047
a2_A = 0
b2_A = 1e-5 # avoid division by 0
a1_rg = 1
b1_rg = 0.0053
a2_rg = 0
b2_rg = 1e-5 # avoid division by 0
a1_by = 34.1
b1_by = 0.04
a2_by = 13.5
b2_by = 0.025
if channel == "A": # Achromatic CSF
a1 = a1_A
b1 = b1_A
a2 = a2_A
b2 = b2_A
elif channel == "RG": # Red-Green CSF
a1 = a1_rg
b1 = b1_rg
a2 = a2_rg
b2 = b2_rg
elif channel == "BY": # Blue-Yellow CSF
a1 = a1_by
b1 = b1_by
a2 = a2_by
b2 = b2_by
# Determine evaluation domain
max_scale_parameter = max([b1_A, b2_A, b1_rg, b2_rg, b1_by, b2_by])
r = np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree)
r = int(r)
deltaX = 1.0 / pixels_per_degree
x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1))
z = (x * deltaX) ** 2 + (y * deltaX) ** 2
# Generate weights
g = a1 * np.sqrt(np.pi / b1) * np.exp(-(np.pi**2) * z / b1) + a2 * np.sqrt(
np.pi / b2
) * np.exp(-(np.pi**2) * z / b2)
g = g / np.sum(g)
g = torch.Tensor(g).unsqueeze(0).unsqueeze(0).cuda()
return g, r
def spatial_filter(img, s_a, s_rg, s_by, radius):
"""
Filters an image with channel specific spatial contrast sensitivity functions
and clips result to the unit cube in linear RGB
:param img: image tensor to filter (with NxCxHxW layout in the YCxCz color space)
:param s_a: spatial filter matrix for the achromatic channel
:param s_rg: spatial filter matrix for the red-green channel
:param s_by: spatial filter matrix for the blue-yellow channel
    :param radius: radius of the largest spatial filter kernel, used to pad the image before filtering
    :return: input image (with NxCxHxW layout) transformed to linear RGB after filtering with spatial contrast sensitivity functions
"""
dim = img.size()
# Prepare image for convolution
img_pad = torch.zeros(
(dim[0], dim[1], dim[2] + 2 * radius, dim[3] + 2 * radius), device="cuda"
)
img_pad[:, 0:1, :, :] = nn.functional.pad(
img[:, 0:1, :, :], (radius, radius, radius, radius), mode="replicate"
)
img_pad[:, 1:2, :, :] = nn.functional.pad(
img[:, 1:2, :, :], (radius, radius, radius, radius), mode="replicate"
)
img_pad[:, 2:3, :, :] = nn.functional.pad(
img[:, 2:3, :, :], (radius, radius, radius, radius), mode="replicate"
)
# Apply Gaussian filters
img_tilde_opponent = torch.zeros((dim[0], dim[1], dim[2], dim[3]), device="cuda")
img_tilde_opponent[:, 0:1, :, :] = nn.functional.conv2d(
img_pad[:, 0:1, :, :], s_a.cuda(), padding=0
)
img_tilde_opponent[:, 1:2, :, :] = nn.functional.conv2d(
img_pad[:, 1:2, :, :], s_rg.cuda(), padding=0
)
img_tilde_opponent[:, 2:3, :, :] = nn.functional.conv2d(
img_pad[:, 2:3, :, :], s_by.cuda(), padding=0
)
# Transform to linear RGB for clamp
img_tilde_linear_rgb = color_space_transform(img_tilde_opponent, "ycxcz2linrgb")
# Clamp to RGB box
return torch.clamp(img_tilde_linear_rgb, 0.0, 1.0)
def hunt_adjustment(img):
"""
Applies Hunt-adjustment to an image
:param img: image tensor to adjust (with NxCxHxW layout in the L*a*b* color space)
:return: Hunt-adjusted image tensor (with NxCxHxW layout in the Hunt-adjusted L*A*B* color space)
"""
# Extract luminance component
L = img[:, 0:1, :, :]
# Apply Hunt adjustment
img_h = torch.zeros(img.size(), device="cuda")
img_h[:, 0:1, :, :] = L
img_h[:, 1:2, :, :] = torch.mul((0.01 * L), img[:, 1:2, :, :])
img_h[:, 2:3, :, :] = torch.mul((0.01 * L), img[:, 2:3, :, :])
return img_h
def hyab(reference, test, eps):
"""
Computes the HyAB distance between reference and test images
:param reference: reference image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*A*B* color space)
:param test: test image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*a*b* color space)
:param eps: float containing a small value used to improve training stability
:return: image tensor (with Nx1xHxW layout) containing the per-pixel HyAB distances between reference and test images
"""
delta = reference - test
root = torch.sqrt(torch.clamp(torch.pow(delta[:, 0:1, :, :], 2), min=eps))
delta_norm = torch.norm(delta[:, 1:3, :, :], dim=1, keepdim=True)
    return root + delta_norm  # the clamped sqrt above replaces abs to stabilize training
def redistribute_errors(power_deltaE_hyab, cmax, pc, pt):
"""
Redistributes exponentiated HyAB errors to the [0,1] range
    :param power_deltaE_hyab: float tensor (with Nx1xHxW layout) containing the exponentiated HyAB distance
:param cmax: float containing the exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space
:param pc: float containing the cmax multiplier p_c (see FLIP paper)
:param pt: float containing the target value, p_t, for p_c * cmax (see FLIP paper)
:return: image tensor (with Nx1xHxW layout) containing redistributed per-pixel HyAB distances (in range [0,1])
"""
# Re-map error to 0-1 range. Values between 0 and
# pccmax are mapped to the range [0, pt],
# while the rest are mapped to the range (pt, 1]
deltaE_c = torch.zeros(power_deltaE_hyab.size(), device="cuda")
pccmax = pc * cmax
deltaE_c = torch.where(
power_deltaE_hyab < pccmax,
(pt / pccmax) * power_deltaE_hyab,
pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt),
)
return deltaE_c
def feature_detection(img_y, pixels_per_degree, feature_type):
"""
Detects edges and points (features) in the achromatic image
    :param img_y: achromatic image tensor (with Nx1xHxW layout, containing normalized Y-values from YCxCz)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer
:param feature_type: string indicating the type of feature to detect
:return: image tensor (with Nx2xHxW layout, with values in range [0,1]) containing large values where features were detected
"""
# Set peak to trough value (2x standard deviations) of human edge
# detection filter
w = 0.082
# Compute filter radius
sd = 0.5 * w * pixels_per_degree
radius = int(np.ceil(3 * sd))
# Compute 2D Gaussian
[x, y] = np.meshgrid(range(-radius, radius + 1), range(-radius, radius + 1))
g = np.exp(-(x**2 + y**2) / (2 * sd * sd))
if feature_type == "edge": # Edge detector
# Compute partial derivative in x-direction
Gx = np.multiply(-x, g)
else: # Point detector
# Compute second partial derivative in x-direction
Gx = np.multiply(x**2 / (sd * sd) - 1, g)
# Normalize positive weights to sum to 1 and negative weights to sum to -1
negative_weights_sum = -np.sum(Gx[Gx < 0])
positive_weights_sum = np.sum(Gx[Gx > 0])
Gx = torch.Tensor(Gx)
Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum)
Gx = Gx.unsqueeze(0).unsqueeze(0).cuda()
# Detect features
featuresX = nn.functional.conv2d(
nn.functional.pad(img_y, (radius, radius, radius, radius), mode="replicate"),
Gx,
padding=0,
)
featuresY = nn.functional.conv2d(
nn.functional.pad(img_y, (radius, radius, radius, radius), mode="replicate"),
torch.transpose(Gx, 2, 3),
padding=0,
)
return torch.cat((featuresX, featuresY), dim=1)
def color_space_transform(input_color, fromSpace2toSpace):
"""
Transforms inputs between different color spaces
:param input_color: tensor of colors to transform (with NxCxHxW layout)
:param fromSpace2toSpace: string describing transform
:return: transformed tensor (with NxCxHxW layout)
"""
dim = input_color.size()
# Assume D65 standard illuminant
reference_illuminant = torch.tensor(
[[[0.950428545]], [[1.000000000]], [[1.088900371]]]
).cuda()
inv_reference_illuminant = torch.tensor(
[[[1.052156925]], [[1.000000000]], [[0.918357670]]]
).cuda()
if fromSpace2toSpace == "srgb2linrgb":
limit = 0.04045
transformed_color = torch.where(
input_color > limit,
torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
input_color / 12.92,
) # clamp to stabilize training
elif fromSpace2toSpace == "linrgb2srgb":
limit = 0.0031308
transformed_color = torch.where(
input_color > limit,
1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
12.92 * input_color,
)
elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
# Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
# Assumes D65 standard illuminant
if fromSpace2toSpace == "linrgb2xyz":
a11 = 10135552 / 24577794
a12 = 8788810 / 24577794
a13 = 4435075 / 24577794
a21 = 2613072 / 12288897
a22 = 8788810 / 12288897
a23 = 887015 / 12288897
a31 = 1425312 / 73733382
a32 = 8788810 / 73733382
a33 = 70074185 / 73733382
else:
# Constants found by taking the inverse of the matrix
# defined by the constants for linrgb2xyz
a11 = 3.241003275
a12 = -1.537398934
a13 = -0.498615861
a21 = -0.969224334
a22 = 1.875930071
a23 = 0.041554224
a31 = 0.055639423
a32 = -0.204011202
a33 = 1.057148933
A = torch.Tensor([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])
input_color = input_color.view(dim[0], dim[1], dim[2] * dim[3]).cuda() # NC(HW)
transformed_color = torch.matmul(A.cuda(), input_color)
transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])
elif fromSpace2toSpace == "xyz2ycxcz":
input_color = torch.mul(input_color, inv_reference_illuminant)
y = 116 * input_color[:, 1:2, :, :] - 16
cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
transformed_color = torch.cat((y, cx, cz), 1)
elif fromSpace2toSpace == "ycxcz2xyz":
y = (input_color[:, 0:1, :, :] + 16) / 116
cx = input_color[:, 1:2, :, :] / 500
cz = input_color[:, 2:3, :, :] / 200
x = y + cx
z = y - cz
transformed_color = torch.cat((x, y, z), 1)
transformed_color = torch.mul(transformed_color, reference_illuminant)
elif fromSpace2toSpace == "xyz2lab":
input_color = torch.mul(input_color, inv_reference_illuminant)
delta = 6 / 29
delta_square = delta * delta
delta_cube = delta * delta_square
factor = 1 / (3 * delta_square)
clamped_term = torch.pow(
torch.clamp(input_color, min=delta_cube), 1.0 / 3.0
).to(dtype=input_color.dtype)
div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
input_color = torch.where(
input_color > delta_cube, clamped_term, div
) # clamp to stabilize training
L = 116 * input_color[:, 1:2, :, :] - 16
a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
transformed_color = torch.cat((L, a, b), 1)
elif fromSpace2toSpace == "lab2xyz":
y = (input_color[:, 0:1, :, :] + 16) / 116
a = input_color[:, 1:2, :, :] / 500
b = input_color[:, 2:3, :, :] / 200
x = y + a
z = y - b
xyz = torch.cat((x, y, z), 1)
delta = 6 / 29
delta_square = delta * delta
factor = 3 * delta_square
xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))
transformed_color = torch.mul(xyz, reference_illuminant)
elif fromSpace2toSpace == "srgb2xyz":
transformed_color = color_space_transform(input_color, "srgb2linrgb")
transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
elif fromSpace2toSpace == "srgb2ycxcz":
transformed_color = color_space_transform(input_color, "srgb2linrgb")
transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
elif fromSpace2toSpace == "linrgb2ycxcz":
transformed_color = color_space_transform(input_color, "linrgb2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2ycxcz")
elif fromSpace2toSpace == "srgb2lab":
transformed_color = color_space_transform(input_color, "srgb2linrgb")
transformed_color = color_space_transform(transformed_color, "linrgb2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2lab")
elif fromSpace2toSpace == "linrgb2lab":
transformed_color = color_space_transform(input_color, "linrgb2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2lab")
elif fromSpace2toSpace == "ycxcz2linrgb":
transformed_color = color_space_transform(input_color, "ycxcz2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
elif fromSpace2toSpace == "lab2srgb":
transformed_color = color_space_transform(input_color, "lab2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2linrgb")
transformed_color = color_space_transform(transformed_color, "linrgb2srgb")
elif fromSpace2toSpace == "ycxcz2lab":
transformed_color = color_space_transform(input_color, "ycxcz2xyz")
transformed_color = color_space_transform(transformed_color, "xyz2lab")
else:
sys.exit("Error: The color transform %s is not defined!" % fromSpace2toSpace)
return transformed_color
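# A minimal usage sketch (not part of the original FLIP code): the random
# tensors, shapes, and noise levels below are illustrative assumptions only.
# Both losses expect NxCxHxW tensors and, as written above, require a CUDA
# device because intermediate tensors are allocated with .cuda().
if __name__ == "__main__":
    torch.manual_seed(0)
    # LDR-FLIP: sRGB images with values in [0, 1]
    ldr_reference = torch.rand(1, 3, 64, 64).cuda()
    ldr_test = torch.clamp(
        ldr_reference + 0.1 * torch.randn_like(ldr_reference), 0.0, 1.0
    )
    print("Mean LDR-FLIP:", LDRFLIPLoss()(ldr_test, ldr_reference).item())
    # HDR-FLIP: linear RGB images with nonnegative values; start/stop exposures
    # are derived from the reference when not supplied
    hdr_reference = 4.0 * torch.rand(1, 3, 64, 64).cuda()
    hdr_test = torch.clamp(
        hdr_reference + 0.2 * torch.randn_like(hdr_reference), min=0.0
    )
    print("Mean HDR-FLIP:", HDRFLIPLoss()(hdr_test, hdr_reference).item())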