# Hugging Face Spaces metadata (page scrape residue): this module runs on a ZeroGPU Space.
# Taken and adapted from https://github.com/SuperBeastsAI/ComfyUI-SuperBeasts
import functools
import logging

import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageCms
from torchvision.transforms.functional import to_pil_image, to_tensor as tv_to_tensor
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
def _check_lcms_available():
    """Probe whether Pillow's LCMS2 bindings can run profile transforms.

    Some Pillow builds ship without liblcms2, in which case
    ``ImageCms.profileToProfile`` raises at call time.  We detect that once
    at import with a 1x1 throwaway image and cache the verdict so the HDR
    code can fall back to a plain-RGB path instead of failing repeatedly.

    Returns:
        bool: True when the sRGB -> LAB transform succeeds, False otherwise.
    """
    try:
        probe = Image.new('RGB', (1, 1))
        ImageCms.profileToProfile(probe, ImageCms.createProfile('sRGB'), ImageCms.createProfile('LAB'), outputMode='LAB')
    except Exception as exc:
        # Keep this non-fatal: AutoHDR degrades gracefully to the RGB fallback.
        logger.warning("AutoHDR: LCMS profile transform not available; AutoHDR will use RGB fallback. Error: %s", exc)
        logger.debug("AutoHDR LCMS detection traceback", exc_info=True)
        return False
    else:
        return True
# ICC profiles used by the sRGB <-> LAB round-trip in HDREffects.apply_hdr2.
sRGB_profile = ImageCms.createProfile("sRGB")
Lab_profile = ImageCms.createProfile("LAB")
# Cached capability flag; computed once at import.  NOTE: apply_hdr2 also
# flips this to False at runtime if a transform fails mid-session.
_HAVE_LCMS = _check_lcms_available()
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert an image tensor (optionally batched) to a PIL image.

    Size-1 leading dims (e.g. a batch of one) are squeezed away; a
    channels-last (H, W, C) layout with 1/3/4 channels is permuted to the
    channels-first layout torchvision's ``to_pil_image`` expects.
    """
    squeezed = image.squeeze()
    if squeezed.dim() == 3 and squeezed.shape[-1] in [1, 3, 4]:
        # (H, W, C) -> (C, H, W)
        squeezed = squeezed.permute(2, 0, 1)
    # Values are clamped to [0, 1] before conversion to 8-bit.
    return to_pil_image(torch.clamp(squeezed, 0, 1))
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float tensor shaped [1, H, W, C]."""
    chw = tv_to_tensor(image)            # (C, H, W), values in [0, 1]
    batched = chw.unsqueeze(0)           # (1, C, H, W)
    return batched.permute(0, 2, 3, 1)   # (1, H, W, C)
def adjust_shadows_non_linear(luminance, shadow_intensity, max_shadow_adjustment=1.5):
    """Lift shadows with a power curve.

    The luminance is normalised to [0, 1] and raised to an exponent below 1
    (the exponent shrinks as ``shadow_intensity`` grows), which brightens
    dark values while leaving 0 and 1 fixed.

    Args:
        luminance: array-like of 8-bit luminance values.
        shadow_intensity: strength of the lift; 0 leaves the input unchanged.
        max_shadow_adjustment: scale cap applied to the intensity.

    Returns:
        np.ndarray of dtype uint8 with the adjusted luminance.
    """
    normalized = np.asarray(luminance, dtype=np.float32) / 255.0
    exponent = 1 / (1 + shadow_intensity * max_shadow_adjustment)
    lifted = normalized ** exponent
    return np.clip(lifted * 255, 0, 255).astype(np.uint8)
def adjust_highlights_non_linear(luminance, highlight_intensity, max_highlight_adjustment=1.5):
    """Boost highlights with an inverted power curve.

    Works on the complement of the normalised luminance: ``1 - (1 - L)**e``
    with exponent e >= 1, which pushes bright values toward white while
    leaving 0 and 1 fixed.

    Args:
        luminance: array-like of 8-bit luminance values.
        highlight_intensity: strength of the boost; 0 leaves the input unchanged.
        max_highlight_adjustment: scale cap applied to the intensity.

    Returns:
        np.ndarray of dtype uint8 with the adjusted luminance.
    """
    normalized = np.asarray(luminance, dtype=np.float32) / 255.0
    exponent = 1 + highlight_intensity * max_highlight_adjustment
    boosted = 1 - (1 - normalized) ** exponent
    return np.clip(boosted * 255, 0, 255).astype(np.uint8)
def merge_adjustments_with_blend_modes(luminance, shadows, highlights, hdr_intensity, shadow_intensity, highlight_intensity):
    """Blend shadow- and highlight-weighted adjustments into a luminance plane.

    Darkens shadows and lifts highlights using quadratic masks derived from
    the luminance itself, then mixes the adjusted plane back with the
    original by ``hdr_intensity``.

    NOTE(review): the ``shadows`` and ``highlights`` parameters are accepted
    for interface compatibility but are not read by this implementation.

    Returns:
        PIL.Image.Image: an 8-bit greyscale image of the merged luminance.
    """
    lum = np.asarray(luminance, dtype=np.float32)
    # Intensity-squared gains, both scaled by the overall HDR strength.
    shadow_gain = shadow_intensity ** 2 * hdr_intensity
    highlight_gain = highlight_intensity ** 2 * hdr_intensity
    # Quadratic masks: strongest in the darkest / brightest regions respectively.
    dark_mask = np.clip((1 - (lum / 255)) ** 2, 0, 1)
    bright_mask = np.clip((lum / 255) ** 2, 0, 1)
    darkened = np.clip(lum * (1 - dark_mask * shadow_gain), 0, 255)
    lifted = np.clip(lum + (255 - lum) * bright_mask * highlight_gain, 0, 255)
    # Additive combine of both adjustments relative to the base plane.
    combined = np.clip(darkened + lifted - lum, 0, 255)
    blended = np.clip(lum * (1 - hdr_intensity) + combined * hdr_intensity, 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def apply_gamma_correction(lum_array, gamma):
    """Apply a gamma curve to a luminance array.

    A ``gamma`` of 0 is treated as "no correction" (the input is only
    clipped and cast).  Otherwise the exponent is ``1 / (1.1 - gamma)``,
    so larger gamma values produce a stronger curve.

    Args:
        lum_array: array-like of luminance values in [0, 255].
        gamma: correction strength; 0 disables the curve.

    Returns:
        np.ndarray of dtype uint8.
    """
    if gamma == 0:
        return np.clip(lum_array, 0, 255).astype(np.uint8)
    exponent = 1 / (1.1 - gamma)
    curved = 255 * ((lum_array / 255) ** exponent)
    return np.clip(curved, 0, 255).astype(np.uint8)
def apply_to_batch(func):
    """Decorator to apply ``func`` to each image in a batch.

    Handles the common input shapes gracefully:
    - 4D torch.Tensor: treated as a batch (B, H, W, C)
    - 3D torch.Tensor: treated as a single image (H, W, C) and wrapped to a batch
    - list/tuple of images: processed element-wise
    - anything else (e.g. a PIL image or a tensor of unexpected rank):
      delegated to ``func`` directly, with a tensor result wrapped to a batch

    Returns a single-element tuple containing a batched tensor
    ``(batch_tensor,)`` to preserve the previous calling contract used
    elsewhere in the codebase.

    Fixes over the previous version: uses ``functools.wraps`` so the wrapped
    method keeps its ``__name__``/``__doc__`` (the old wrapper clobbered
    them, breaking introspection), and folds the duplicated batch/fallback
    branch bodies into one path each without changing behavior.
    """
    @functools.wraps(func)
    def wrapper(self, image, *args, **kwargs):
        if isinstance(image, torch.Tensor):
            if image.ndim == 4:
                frames = image                      # already batched
            elif image.ndim == 3:
                frames = image.unsqueeze(0)         # single image -> batch of 1
            else:
                # Unexpected tensor rank: delegate to func and wrap if possible.
                res = func(self, image, *args, **kwargs)
                return (res.unsqueeze(0),) if isinstance(res, torch.Tensor) else res
            results = [func(self, frame, *args, **kwargs) for frame in frames]
            return (torch.cat(results, dim=0),)
        if isinstance(image, (list, tuple)):
            # Process each element; each result is expected to be batched [1,H,W,C].
            results = [func(self, item, *args, **kwargs) for item in image]
            return (torch.cat(results, dim=0),)
        # Fallback for other types (e.g. PIL Image): wrap a tensor result as a batch.
        res = func(self, image, *args, **kwargs)
        return (res.unsqueeze(0),) if isinstance(res, torch.Tensor) else res
    return wrapper
class HDREffects:
    """HDR-style tone adjustment applied to a single image tensor.

    The preferred path converts the image to LAB via Pillow ICC profile
    transforms, adjusts only the luminance channel (shadows, highlights,
    gamma), then converts back to sRGB.  When LCMS is unavailable (or a
    transform fails at runtime) a simpler RGB enhance-based fallback is used.
    """

    def apply_hdr2(self, image, hdr_intensity=0.75, shadow_intensity=0.25, highlight_intensity=0.5,
                   gamma_intensity=0.25, contrast=0.1, enhance_color=0.25):
        """Apply the HDR effect to ``image`` and return a tensor.

        Args:
            image: image tensor accepted by tensor2pil (assumes a single
                image, e.g. [1, H, W, C] or [H, W, C] — TODO confirm callers).
            hdr_intensity: overall effect strength; also drives the fallback
                brightness boost.
            shadow_intensity / highlight_intensity: shadow lift / highlight
                boost strengths for the luminance adjustments.
            gamma_intensity: strength for apply_gamma_correction (0 = off).
            contrast / enhance_color: post-adjustment enhance factors.

        Returns:
            torch.Tensor shaped [1, H, W, C] (via pil2tensor).
        """
        # Mutated below: a runtime LCMS failure permanently switches this
        # module to the RGB fallback path.
        global _HAVE_LCMS
        img = tensor2pil(image)
        # Handle possible alpha channel by separating it out. ICC transforms expect RGB/LAB without alpha.
        alpha = None
        if 'A' in img.getbands():
            alpha = img.getchannel('A')
            img_rgb = img.convert('RGB')
        else:
            img_rgb = img.convert('RGB') if img.mode != 'RGB' else img
        # If LCMS is not available, skip ICC-based path and do the RGB fallback immediately.
        if not _HAVE_LCMS:
            img_adjusted = ImageEnhance.Contrast(img_rgb).enhance(1 + contrast)
            img_adjusted = ImageEnhance.Color(img_adjusted).enhance(1 + enhance_color * 0.2)
            # Brightness boost only exists in the fallback path (no LAB luminance work here).
            img_adjusted = ImageEnhance.Brightness(img_adjusted).enhance(1 + hdr_intensity * 0.1)
            # NOTE(review): `if alpha:` relies on a non-None PIL Image being
            # truthy; `if alpha is not None:` would be more explicit.
            if alpha:
                img_adjusted = img_adjusted.convert('RGBA')
                img_adjusted.putalpha(alpha)
            return pil2tensor(img_adjusted)
        try:
            # Preferred path using ICC profiles (Lab transform) on RGB data
            img_lab = ImageCms.profileToProfile(img_rgb, sRGB_profile, Lab_profile, outputMode='LAB')
            luminance, a, b = img_lab.split()
            lum_array = np.asarray(luminance, dtype=np.float32)
            # Shadow/highlight curves are computed independently, then merged.
            shadows_adj = adjust_shadows_non_linear(luminance, shadow_intensity)
            highlights_adj = adjust_highlights_non_linear(luminance, highlight_intensity)
            merged = merge_adjustments_with_blend_modes(lum_array, shadows_adj, highlights_adj,
                                                        hdr_intensity, shadow_intensity, highlight_intensity)
            # resize(a.size) keeps the L plane aligned with a/b; presumably they
            # already share a size, so this is defensive — TODO confirm.
            gamma_corr = Image.fromarray(apply_gamma_correction(np.asarray(merged), gamma_intensity)).resize(a.size)
            adjusted_lab = Image.merge('LAB', (gamma_corr, a, b))
            img_adjusted = ImageCms.profileToProfile(adjusted_lab, Lab_profile, sRGB_profile, outputMode='RGB')
            # Re-attach alpha channel if present
            if alpha:
                img_adjusted = img_adjusted.convert('RGBA')
                img_adjusted.putalpha(alpha)
            # Contrast/Color enhance run after alpha reattach in this path.
            img_adjusted = ImageEnhance.Contrast(img_adjusted).enhance(1 + contrast)
            img_adjusted = ImageEnhance.Color(img_adjusted).enhance(1 + enhance_color * 0.2)
            return pil2tensor(img_adjusted)
        except Exception as e:
            logger.exception("AutoHDR: profile transform failed; using RGB fallback")
            # Disable LCMS after a runtime failure to avoid repeated exceptions
            _HAVE_LCMS = False
            img_adjusted = ImageEnhance.Contrast(img_rgb).enhance(1 + contrast)
            img_adjusted = ImageEnhance.Color(img_adjusted).enhance(1 + enhance_color * 0.2)
            img_adjusted = ImageEnhance.Brightness(img_adjusted).enhance(1 + hdr_intensity * 0.1)
            if alpha:
                img_adjusted = img_adjusted.convert('RGBA')
                img_adjusted.putalpha(alpha)
            return pil2tensor(img_adjusted)