Spaces:
Runtime error
Runtime error
| import warnings | |
| from typing import Optional | |
| import torch | |
| from jaxtyping import Float | |
| from lxml import etree | |
def load_asc_cdl(cdl_path: str, device: torch.device = torch.device("cpu")) -> dict:
    """
    Loads ASC CDL parameters from an XML file.

    Parameters:
        cdl_path (str): Path to the ASC CDL XML file
        device (torch.device): Device on which the returned tensors are placed

    Returns:
        dict: slope, offset, power, and saturation values as torch tensors
            (slope/offset/power are 1-D per-channel tensors, saturation is
            a 0-dim scalar tensor)

    Raises:
        ValueError: If the file cannot be parsed, or required CDL nodes
            (SOPNode, SatNode, or their children) are missing.
    """
    try:
        tree = etree.parse(cdl_path)
        root = tree.getroot()
    except Exception as e:
        raise ValueError(f"Error loading ASC CDL from {cdl_path}: {e}")

    # Extract SOP values. Validate node presence explicitly so a malformed
    # document raises an informative ValueError instead of an AttributeError.
    sop_node = root.find(".//SOPNode")
    if sop_node is None:
        raise ValueError(f"Missing SOPNode in ASC CDL file {cdl_path}")

    def _read_triplet(tag: str) -> torch.Tensor:
        # Each SOP child holds whitespace-separated per-channel floats (R G B).
        node = sop_node.find(tag)
        if node is None or node.text is None:
            raise ValueError(f"Missing {tag} in ASC CDL file {cdl_path}")
        return torch.tensor([float(x) for x in node.text.split()], device=device)

    slope = _read_triplet("Slope")
    offset = _read_triplet("Offset")
    power = _read_triplet("Power")

    # Extract the scalar Saturation value.
    sat_node = root.find(".//SatNode")
    if sat_node is None:
        raise ValueError(f"Missing SatNode in ASC CDL file {cdl_path}")
    sat_el = sat_node.find("Saturation")
    if sat_el is None or sat_el.text is None:
        raise ValueError(f"Missing Saturation in ASC CDL file {cdl_path}")
    saturation = torch.tensor(float(sat_el.text), device=device)

    return {"slope": slope, "offset": offset, "power": power, "saturation": saturation}
def save_asc_cdl(cdl_dict: dict, cdl_path: Optional[str] = None) -> Optional[str]:
    """
    Saves ASC CDL parameters to an XML file, or serializes them to a string.

    Parameters:
        cdl_dict (dict): Dictionary with "slope", "offset", "power", and
            "saturation" torch tensors (the layout produced by load_asc_cdl)
        cdl_path (Optional[str]): Destination path. If None (the default),
            nothing is written to disk and the XML document is returned.

    Returns:
        Optional[str]: The XML document as a string when cdl_path is None,
            otherwise None.

    Raises:
        ValueError: If writing to cdl_path fails.
    """

    def _channels_text(key: str) -> str:
        # Serialize a per-channel tensor as whitespace-separated floats.
        return " ".join(str(x) for x in cdl_dict[key].detach().cpu().numpy())

    root = etree.Element("ASC_CDL")
    sop_node = etree.SubElement(root, "SOPNode")
    etree.SubElement(sop_node, "Slope").text = _channels_text("slope")
    etree.SubElement(sop_node, "Offset").text = _channels_text("offset")
    etree.SubElement(sop_node, "Power").text = _channels_text("power")
    sat_node = etree.SubElement(root, "SatNode")
    etree.SubElement(sat_node, "Saturation").text = str(
        cdl_dict["saturation"].detach().cpu().numpy()
    )
    tree = etree.ElementTree(root)
    if cdl_path is not None:
        try:
            tree.write(
                cdl_path, pretty_print=True, xml_declaration=True, encoding="utf-8"
            )
        except Exception as e:
            raise ValueError(f"Error saving ASC CDL to {cdl_path}: {e}")
    else:
        return etree.tostring(
            root, pretty_print=True, xml_declaration=True, encoding="utf-8"
        ).decode("utf-8")
def apply_sop(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies the Slope, Offset, and Power stage of an ASC CDL grade.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        clamp (bool): If True, clamps the slope/offset result to [0, 1]
            before the power curve is applied.

    Returns:
        torch.Tensor: Image after SOP adjustments.
    """
    # Broadcast per-channel parameters across the spatial dimensions.
    scale = slope[..., None, None]
    shift = offset[..., None, None]
    gamma = power[..., None, None]
    graded = img * scale + shift
    if clamp:
        graded = torch.clamp(graded, min=0.0, max=1.0)
    # Apply the power curve only where values are safely positive; the inner
    # clamp keeps torch.pow away from non-positive bases (stable gradients).
    curved = torch.pow(graded.clamp(min=1e-7), gamma)
    return torch.where(graded > 1e-7, curved, graded)
def apply_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies saturation adjustment.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W); channels are
            interpreted as R, G, B in that order
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after saturation adjustment.
    """
    red = img[..., 0, :, :]
    green = img[..., 1, :, :]
    blue = img[..., 2, :, :]
    # Rec. 709 luma weights; keep a channel axis for broadcasting.
    lum = (0.2126 * red + 0.7152 * green + 0.0722 * blue).unsqueeze(-3)
    # Scale the chroma (deviation from luma) by the saturation factor.
    factor = saturation[..., None, None, None]
    return lum + (img - lum) * factor
def asc_cdl_forward(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies ASC CDL transformation in Fwd or FwdNoClamp mode.

    Parameters:
        img (torch.Tensor): Input image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        saturation (torch.Tensor): Saturation factor (*B)
        clamp (bool): If True, clamps output to [0, 1] (Fwd mode).
            If False, no clamping (FwdNoClamp mode).

    Returns:
        torch.Tensor: Transformed image tensor.
    """
    # Negative parameters are representable but usually unintended — warn.
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    lead = img.shape[:-3]
    pad = (1,) * len(lead)
    # Parameters given without batch dimensions are promoted to singleton
    # batch shape so they broadcast against the image's leading dimensions.
    if slope.ndim == 1:
        slope = slope.reshape(pad + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(pad + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(pad + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(pad)
    # Sanity-check that batch ranks now line up with the image.
    assert slope.ndim == len(lead) + 1
    assert offset.ndim == len(lead) + 1
    assert power.ndim == len(lead) + 1
    assert saturation.ndim == len(lead)

    # Slope/Offset/Power first, then saturation, per the CDL ordering.
    out = apply_sop(img, slope, offset, power, clamp=clamp)
    out = apply_saturation(out, saturation)
    # Final clamp only in Fwd mode.
    return torch.clamp(out, 0.0, 1.0) if clamp else out
def inverse_saturation(
    img: Float[torch.Tensor, "*B C H W"],
    saturation: Float[torch.Tensor, "*B"],
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Reverts saturation adjustment.

    Parameters:
        img (torch.Tensor): Image tensor (*B, C, H, W); channels are
            interpreted as R, G, B in that order
        saturation (torch.Tensor): Saturation factor (*B)

    Returns:
        torch.Tensor: Image after reversing saturation adjustment.
    """
    red = img[..., 0, :, :]
    green = img[..., 1, :, :]
    blue = img[..., 2, :, :]
    # Rec. 709 luma weights; keep a channel axis for broadcasting.
    lum = (0.2126 * red + 0.7152 * green + 0.0722 * blue).unsqueeze(-3)
    # Undo the forward chroma scaling by dividing instead of multiplying.
    factor = saturation[..., None, None, None]
    return lum + (img - lum) / factor
def asc_cdl_reverse(
    img: Float[torch.Tensor, "*B C H W"],
    slope: Float[torch.Tensor, "*B C"],
    offset: Float[torch.Tensor, "*B C"],
    power: Float[torch.Tensor, "*B C"],
    saturation: Float[torch.Tensor, "*B"],
    clamp: bool = True,
) -> Float[torch.Tensor, "*B C H W"]:
    """
    Applies reverse ASC CDL transformation.

    Parameters:
        img (torch.Tensor): Transformed image tensor (*B, C, H, W)
        slope (torch.Tensor): Slope per channel (*B, C)
        offset (torch.Tensor): Offset per channel (*B, C)
        power (torch.Tensor): Power per channel (*B, C)
        saturation (torch.Tensor): Saturation factor (*B)
        clamp (bool): If True, clamps output to [0, 1].

    Returns:
        torch.Tensor: Recovered input image tensor.
    """
    # Negative parameters are representable but usually unintended — warn.
    if (saturation < 0).any():
        warnings.warn("Saturation is below 0, this will result in a color shift.")
    if (slope < 0).any():
        warnings.warn("Slope is below 0, this will result in a color shift.")
    if (power < 0).any():
        warnings.warn("Power is below 0, this will result in a color shift.")

    lead = img.shape[:-3]
    pad = (1,) * len(lead)
    # Parameters given without batch dimensions are promoted to singleton
    # batch shape so they broadcast against the image's leading dimensions.
    if slope.ndim == 1:
        slope = slope.reshape(pad + tuple(slope.shape))
    if offset.ndim == 1:
        offset = offset.reshape(pad + tuple(offset.shape))
    if power.ndim == 1:
        power = power.reshape(pad + tuple(power.shape))
    if saturation.ndim == 0:
        saturation = saturation.reshape(pad)
    # Sanity-check that batch ranks now line up with the image.
    assert slope.ndim == len(lead) + 1
    assert offset.ndim == len(lead) + 1
    assert power.ndim == len(lead) + 1
    assert saturation.ndim == len(lead)

    # Invert the forward stages in reverse order: saturation first ...
    out = inverse_saturation(img, saturation)
    if clamp:
        out = torch.clamp(out, 0.0, 1.0)
    # ... then the power curve (only where values are safely positive) ...
    inv_gamma = 1 / power[..., None, None]
    out = torch.where(out > 1e-7, torch.pow(out, inv_gamma), out)
    # ... and finally undo offset and slope.
    out = (out - offset[..., None, None]) / slope[..., None, None]
    if clamp:
        out = torch.clamp(out, 0.0, 1.0)
    return out