| | import warnings |
| | from typing import Optional |
| |
|
| | import torch |
| | from jaxtyping import Float |
| | from lxml import etree |
| |
|
| |
|
def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
    """
    Loads ASC CDL parameters from an XML file.

    Elements are matched by their local name. This supports both plain files
    and files that declare a default XML namespace (e.g.
    ``xmlns="urn:ASC:CDL:v1.01"``), in which every tag is namespace-qualified.
    The first ``SOPNode`` and the first ``SatNode``/``SATNode`` in document
    order are used. In a file holding several color corrections, only the
    first one is read.

    Parameters:
        cdl_path (str): Path to the ASC CDL XML file
        device (torch.device): Device on which the returned tensors are created

    Returns:
        Dict:
            "slope", "offset", "power": 1-D tensors with one value per number
            listed in the corresponding element;
            "saturation": 0-dim tensor

    Raises:
        ValueError: If the file cannot be parsed, a required node or element
            is missing or empty, or a value is not a valid number.
    """

    def _find(parent, *names):
        # First element in document order (``parent`` included) whose local
        # tag name is one of ``names``. lxml spells namespaced tags as
        # "{uri}Name", so only the part after "}" is compared. Comments and
        # processing instructions have non-string tags and are skipped.
        for elem in parent.iter():
            if isinstance(elem.tag, str) and elem.tag.rsplit("}", 1)[-1] in names:
                return elem
        return None

    def _text(parent, name) -> str:
        # Text of a required element; fail with a clear ValueError instead
        # of an AttributeError on ``None``.
        elem = _find(parent, name)
        if elem is None or elem.text is None or not elem.text.strip():
            raise ValueError(
                f"ASC CDL {cdl_path} is missing a non-empty <{name}> element"
            )
        return elem.text

    def _vector(parent, name) -> torch.Tensor:
        # Whitespace-separated numbers -> 1-D tensor on ``device``.
        tokens = _text(parent, name).split()
        try:
            values = [float(x) for x in tokens]
        except ValueError as e:
            raise ValueError(
                f"Invalid <{name}> value in ASC CDL {cdl_path}: {e}"
            ) from e
        return torch.tensor(values, device=device)

    try:
        tree = etree.parse(cdl_path)
        root = tree.getroot()
    except Exception as e:
        raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}") from e

    sop_node = _find(root, "SOPNode")
    if sop_node is None:
        raise ValueError(f"ASC CDL {cdl_path} has no <SOPNode> element")
    slope = _vector(sop_node, "Slope")
    offset = _vector(sop_node, "Offset")
    power = _vector(sop_node, "Power")

    # Also accept the upper-case "SATNode" spelling (reportedly emitted by
    # some CDL writers).
    sat_node = _find(root, "SatNode", "SATNode")
    if sat_node is None:
        raise ValueError(f"ASC CDL {cdl_path} has no <SatNode> element")
    sat_text = _text(sat_node, "Saturation")
    try:
        saturation = torch.tensor(float(sat_text), device=device)
    except ValueError as e:
        raise ValueError(
            f"Invalid <Saturation> value in ASC CDL {cdl_path}: {e}"
        ) from e

    return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}
| |
|
| |
|
def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str] = None) -> Optional[str]:
    """
    Saves ASC CDL parameters to an XML file, or returns them as an XML string.

    Parameters:
        cdl_dict (dict): Dictionary containing "slope", "offset", "power" and
            "saturation" values. Values may be tensors or anything
            ``torch.as_tensor`` accepts. Each value is flattened before
            writing, so a leading singleton batch dimension (e.g. a slope of
            shape (1, 3)) or a 1-element saturation tensor is written as plain
            space-separated numbers.
        cdl_path (Optional[str]): Destination file. If None (the default),
            nothing is written and the XML document is returned instead.

    Returns:
        Optional[str]: The XML document (decoded from UTF-8) when
        ``cdl_path`` is None, otherwise None.

    Raises:
        ValueError: If writing to ``cdl_path`` fails.
    """

    def _to_text(value) -> str:
        # Flatten first: without it numpy would stringify a (1, 3) or (1,)
        # array as "[[a b c]]" / "[s]", which float() cannot parse back.
        flat = torch.as_tensor(value).detach().cpu().numpy().reshape(-1)
        return " ".join(str(x) for x in flat)

    root = etree.Element("ASC_CDL")
    sop_node = etree.SubElement(root, "SOPNode")
    for tag, key in (("Slope", "slope"), ("Offset", "offset"), ("Power", "power")):
        etree.SubElement(sop_node, tag).text = _to_text(cdl_dict[key])
    sat_node = etree.SubElement(root, "SatNode")
    etree.SubElement(sat_node, "Saturation").text = _to_text(cdl_dict["saturation"])

    if cdl_path is None:
        return etree.tostring(
            root, pretty_print=True, xml_declaration=True, encoding="utf-8"
        ).decode("utf-8")

    tree = etree.ElementTree(root)
    try:
        tree.write(cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8")
    except Exception as e:
        raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}") from e
    return None
| |
|
| |
|
def apply_sop(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Apply the Slope/Offset/Power stage of the ASC CDL.

    Computes ``img * slope + offset`` per channel and optionally clamps the
    result to [0, 1]. Values above 1e-7 are then raised to ``power``; values
    at or below 1e-7 pass through unchanged.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        slope (torch.Tensor): Per-channel slope (*B, C)
        offset (torch.Tensor): Per-channel offset (*B, C)
        power (torch.Tensor): Per-channel power (*B, C)
        clamp (bool): Clamp the slope/offset result to [0, 1] before the power.

    Returns:
        torch.Tensor: Image after the SOP adjustment (*B, C, H, W).
    """
    # Two trailing singleton axes let the per-channel parameters broadcast
    # over H and W.
    slope_px = slope[..., None, None]
    offset_px = offset[..., None, None]
    power_px = power[..., None, None]

    linear = img * slope_px + offset_px
    if clamp:
        linear = linear.clamp(0.0, 1.0)

    # The base is floored at 1e-7 inside pow so the branch torch.where
    # discards never evaluates pow on zero/negative values (which would
    # produce NaN gradients).
    powered = torch.pow(linear.clamp(min=1e-7), power_px)
    return torch.where(linear > 1e-7, powered, linear)
| |
|
| |
|
def apply_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Scale each pixel's distance from its luma by ``saturation``.

    Luma is computed from the first three channels with Rec. 709 weights
    (0.2126, 0.7152, 0.0722). The result is ``luma + (img - luma) * saturation``.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor per batch item (*B)

    Returns:
        torch.Tensor: Image after the saturation adjustment (*B, C, H, W).
    """
    red, green, blue = img[..., 0, :, :], img[..., 1, :, :], img[..., 2, :, :]
    # Re-insert the channel axis so luma broadcasts against every channel.
    luma = (0.2126 * red + 0.7152 * green + 0.0722 * blue)[..., None, :, :]
    # (*B) -> (*B, 1, 1, 1) to broadcast over C, H and W.
    sat = saturation[..., None, None, None]
    return luma + (img - luma) * sat
| |
|
| |
|
def asc_cdl_forward(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies ASC CDL transformation in Fwd or FwdNoClamp mode.

    Order of operations: slope/offset/power, then saturation, then (if
    ``clamp``) a final clamp to [0, 1].

    Unbatched parameters are broadcast over the batch dimensions of ``img``.
    A slope/offset/power of shape (C,) and a 0-dim saturation are accepted
    for any number of leading batch dimensions, including none; for example,
    the tensors returned by ``load_asc_cdl`` can be applied to a (C, H, W)
    image. Non-tensor values (Python floats or lists) are converted to tensors
    with ``img``'s dtype and device.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C) or (C,)
        offset (torch.Tensor): Offset per channel (*B, C) or (C,)
        power (torch.Tensor): Power per channel (*B, C) or (C,)
        saturation (torch.Tensor): Saturation factor (*B) or 0-dim
        clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
            If False, no clamping (FwdNoClamp mode).

    Returns:
        torch.Tensor: Transformed image tensor.

    Raises:
        ValueError: If a parameter's rank does not match the batch
            dimensions of ``img``.

    Warns:
        UserWarning: If saturation, slope or power contains negative values.
    """

    def _as_tensor(value):
        # Tensors pass through untouched (dtype, device and autograd graph
        # are kept); only plain Python numbers/sequences are wrapped.
        if isinstance(value, torch.Tensor):
            return value
        return torch.as_tensor(value, dtype=img.dtype, device=img.device)

    slope, offset, power, saturation = (
        _as_tensor(v) for v in (slope, offset, power, saturation)
    )

    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    n_batch = len(img.shape[:-3])

    # Prepend singleton batch dims to unbatched parameters. The target shape
    # is passed as a tuple: for an unbatched image, a 0-dim saturation then
    # becomes reshape(()) rather than an argument-less view() call, which
    # PyTorch rejects.
    lead = (1,) * n_batch
    if slope.ndim == 1:
        slope = slope.reshape(lead + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(lead + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(lead + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(lead)

    # Explicit checks (unlike `assert`, these survive `python -O`).
    for name, param, expected_ndim in (
        ("slope", slope, n_batch + 1),
        ("offset", offset, n_batch + 1),
        ("power", power, n_batch + 1),
        ("saturation", saturation, n_batch),
    ):
        if param.ndim != expected_ndim:
            raise ValueError(
                f"{name} must have {expected_ndim} dims to match img of shape "
                f"{tuple(img.shape)}, got shape {tuple(param.shape)}"
            )

    img = apply_sop(img, slope, offset, power, clamp=clamp)
    img = apply_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
| |
|
| |
|
def inverse_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Undo a ``luma + (x - luma) * saturation`` adjustment.

    The Rec. 709 weights (0.2126, 0.7152, 0.0722) sum to 1, so the forward
    adjustment leaves luma unchanged. The luma of the saturated image can
    therefore be used directly, and the chroma offset is divided by
    ``saturation``.

    Parameters:
        img (torch.Tensor): Saturated image tensor (*B, C, H, W)
        saturation (torch.Tensor): Saturation factor per batch item (*B)

    Returns:
        torch.Tensor: Image with the saturation adjustment reverted.
    """
    r, g, b = img[..., 0, :, :], img[..., 1, :, :], img[..., 2, :, :]
    luma = (0.2126 * r + 0.7152 * g + 0.0722 * b).unsqueeze(-3)
    chroma = img - luma
    # (*B) -> (*B, 1, 1, 1) so the factor broadcasts over C, H and W.
    return luma + chroma / saturation[..., None, None, None]
| |
|
| |
|
def asc_cdl_reverse(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies reverse ASC CDL transformation.

    Order of operations: inverse saturation, optional clamp to [0, 1],
    inverse power (``x ** (1 / power)`` for values above 1e-7), inverse
    slope/offset (``(x - offset) / slope``), then an optional final clamp.

    Unbatched parameters (slope/offset/power of shape (C,), 0-dim
    saturation) are broadcast over any number of leading batch dimensions of
    ``img``, including none. Non-tensor values (Python floats or lists) are
    converted to tensors with ``img``'s dtype and device.

    Parameters:
        img (torch.Tensor): Transformed image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C) or (C,)
        offset (torch.Tensor): Offset per channel (*B, C) or (C,)
        power (torch.Tensor): Power per channel (*B, C) or (C,)
        saturation (torch.Tensor): Saturation factor (*B) or 0-dim
        clamp (bool): If True, clamps output to [0, 1].

    Returns:
        torch.Tensor: Recovered input image tensor.

    Raises:
        ValueError: If a parameter's rank does not match the batch
            dimensions of ``img``.

    Warns:
        UserWarning: If saturation, slope or power contains negative values.
    """

    def _as_tensor(value):
        # Tensors pass through untouched (dtype, device and autograd graph
        # are kept); only plain Python numbers/sequences are wrapped.
        if isinstance(value, torch.Tensor):
            return value
        return torch.as_tensor(value, dtype=img.dtype, device=img.device)

    slope, offset, power, saturation = (
        _as_tensor(v) for v in (slope, offset, power, saturation)
    )

    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    n_batch = len(img.shape[:-3])

    # Prepend singleton batch dims to unbatched parameters. The target shape
    # is passed as a tuple: for an unbatched image, a 0-dim saturation then
    # becomes reshape(()) rather than an argument-less view() call, which
    # PyTorch rejects.
    lead = (1,) * n_batch
    if slope.ndim == 1:
        slope = slope.reshape(lead + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(lead + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(lead + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(lead)

    # Explicit checks (unlike `assert`, these survive `python -O`).
    for name, param, expected_ndim in (
        ("slope", slope, n_batch + 1),
        ("offset", offset, n_batch + 1),
        ("power", power, n_batch + 1),
        ("saturation", saturation, n_batch),
    ):
        if param.ndim != expected_ndim:
            raise ValueError(
                f"{name} must have {expected_ndim} dims to match img of shape "
                f"{tuple(img.shape)}, got shape {tuple(param.shape)}"
            )

    img = inverse_saturation(img, saturation)

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)

    # Floor the base at 1e-7 inside pow, as the forward SOP does. The branch
    # torch.where discards is still differentiated, and pow at zero/negative
    # bases has NaN/inf derivatives that would make every gradient NaN (e.g.
    # any black pixel). Selected values are unchanged because the floor only
    # affects elements at or below 1e-7.
    inv_power = 1 / power[..., None, None]
    img = torch.where(img > 1e-7, torch.pow(img.clamp(min=1e-7), inv_power), img)
    img = (img - offset[..., None, None]) / slope[..., None, None]

    if clamp:
        img = torch.clamp(img, 0.0, 1.0)
    return img
| |
|