File size: 8,191 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Taken and adapted from https://github.com/SuperBeastsAI/ComfyUI-SuperBeasts

import functools
import logging

import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageCms
from torchvision.transforms.functional import to_pil_image, to_tensor as tv_to_tensor

logger = logging.getLogger(__name__)

# Probe once, at module import, whether this Pillow build can run ICC profile
# transforms. Some Pillow builds may not have liblcms2 available, which makes
# ImageCms.profileToProfile fail; callers then use an RGB-only fallback.
def _check_lcms_available():
    """Return True if an sRGB -> LAB profile transform works on a 1x1 image.

    Any exception raised while building the probe image, creating the profiles
    or running the transform is logged (message at WARNING, traceback at DEBUG)
    and reported as False instead of being propagated.
    """
    try:
        ImageCms.profileToProfile(
            Image.new('RGB', (1, 1)),
            ImageCms.createProfile('sRGB'),
            ImageCms.createProfile('LAB'),
            outputMode='LAB',
        )
    except Exception as e:
        logger.warning("AutoHDR: LCMS profile transform not available; AutoHDR will use RGB fallback. Error: %s", e)
        logger.debug("AutoHDR LCMS detection traceback", exc_info=True)
        return False
    return True


# ICC profiles used by the Lab-based HDR path, plus the cached LCMS capability
# flag read by HDREffects.apply_hdr2.
#
# Profile creation is guarded because it can itself fail on Pillow builds
# without LCMS support (Pillow may defer the missing-extension error until the
# first ImageCms call). Letting that escape would break the module import and
# make the RGB fallback unreachable, so the failure disables the LCMS path
# instead. The profiles are only used when _HAVE_LCMS is True.
try:
    sRGB_profile = ImageCms.createProfile("sRGB")
    Lab_profile = ImageCms.createProfile("LAB")
except Exception as e:
    logger.warning("AutoHDR: could not create ICC profiles; AutoHDR will use RGB fallback. Error: %s", e)
    logger.debug("AutoHDR ICC profile creation traceback", exc_info=True)
    sRGB_profile = None
    Lab_profile = None
    _HAVE_LCMS = False
else:
    _HAVE_LCMS = _check_lcms_available()


def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert an image tensor to a PIL image.

    Accepts a channel-last image ``(H, W, C)``, optionally preceded by
    singleton batch dimensions such as ``(1, H, W, C)`` (the layout produced
    by ``pil2tensor``), or a 2D ``(H, W)`` tensor. A trailing channel axis of
    size 1, 3 or 4 is moved to the front, because ``to_pil_image`` treats 3D
    tensors as channel-first. Values are clamped to ``[0, 1]`` first.

    Only *leading* singleton dimensions are dropped. A blanket
    ``Tensor.squeeze()`` would also collapse a height or width of 1. That
    turns a 1-pixel-tall/wide image into a tensor of the wrong rank, which is
    then mis-converted.
    """
    t = image
    # Strip batch dims of size 1 only; spatial and channel dims must survive.
    while t.dim() > 3 and t.shape[0] == 1:
        t = t.squeeze(0)
    if t.dim() == 3 and t.shape[-1] in (1, 3, 4):
        t = t.permute(2, 0, 1)
    return to_pil_image(torch.clamp(t, 0, 1))


def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a channel-last tensor shaped ``[1, H, W, C]``."""
    # torchvision returns channel-first (C, H, W); move channels last, then
    # prepend the batch axis.
    channels_first = tv_to_tensor(image)
    channels_last = channels_first.permute(1, 2, 0)
    return channels_last.unsqueeze(0)

def adjust_shadows_non_linear(luminance, shadow_intensity, max_shadow_adjustment=1.5):
    """Apply a power curve that lifts dark tones of an 8-bit luminance channel.

    Computes ``255 * (L / 255) ** (1 / (1 + shadow_intensity * max_shadow_adjustment))``.
    For positive intensities the exponent is below 1, which brightens mid and
    dark values while leaving 0 and 255 fixed.

    Args:
        luminance: Anything ``np.asarray`` accepts (array, PIL image, ...).
        shadow_intensity: Strength of the lift.
        max_shadow_adjustment: Scale applied to ``shadow_intensity``.

    Returns:
        ``np.uint8`` array clipped to [0, 255] (fractions truncated).
    """
    normalized = np.asarray(luminance, dtype=np.float32) / 255.0
    exponent = 1 / (1 + shadow_intensity * max_shadow_adjustment)
    lifted = np.power(normalized, exponent)
    return np.clip(lifted * 255, 0, 255).astype(np.uint8)


def adjust_highlights_non_linear(luminance, highlight_intensity, max_highlight_adjustment=1.5):
    """Apply an inverted power curve to an 8-bit luminance channel.

    Computes ``255 * (1 - (1 - L / 255) ** (1 + highlight_intensity * max_highlight_adjustment))``.
    For positive intensities the exponent exceeds 1, which pushes values
    toward 255 while leaving 0 and 255 fixed.

    Args:
        luminance: Anything ``np.asarray`` accepts (array, PIL image, ...).
        highlight_intensity: Strength of the curve.
        max_highlight_adjustment: Scale applied to ``highlight_intensity``.

    Returns:
        ``np.uint8`` array clipped to [0, 255] (fractions truncated).
    """
    normalized = np.asarray(luminance, dtype=np.float32) / 255.0
    exponent = 1 + highlight_intensity * max_highlight_adjustment
    curved = 1 - np.power(1 - normalized, exponent)
    return np.clip(curved * 255, 0, 255).astype(np.uint8)


def merge_adjustments_with_blend_modes(luminance, shadows, highlights, hdr_intensity, shadow_intensity, highlight_intensity):
    """Tone-map an 8-bit luminance channel and return it as a PIL image.

    Shadows are darkened and highlights brightened using luminance-derived
    masks, and the result is blended with the input by ``hdr_intensity``:

    * shadow mask ``(1 - L/255)**2``, strength ``shadow_intensity**2 * hdr_intensity``
    * highlight mask ``(L/255)**2``, strength ``highlight_intensity**2 * hdr_intensity``
    * output ``L * (1 - hdr_intensity) + adjusted * hdr_intensity``, clipped to
      [0, 255] and truncated to ``uint8``.

    NOTE: ``shadows`` and ``highlights`` are accepted but never read. Only the
    raw ``luminance`` drives the result. They are kept for call compatibility.
    """
    lum = np.asarray(luminance, dtype=np.float32)
    shadow_strength = shadow_intensity ** 2 * hdr_intensity
    highlight_strength = highlight_intensity ** 2 * hdr_intensity

    normalized = lum / 255
    shadow_weight = np.clip((1 - normalized) ** 2, 0, 1)
    highlight_weight = np.clip(normalized ** 2, 0, 1)

    darkened = np.clip(lum * (1 - shadow_weight * shadow_strength), 0, 255)
    brightened = np.clip(lum + (255 - lum) * highlight_weight * highlight_strength, 0, 255)
    combined = np.clip(darkened + brightened - lum, 0, 255)

    blended = lum * (1 - hdr_intensity) + combined * hdr_intensity
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))


def apply_gamma_correction(lum_array, gamma):
    """Apply a power-curve gamma adjustment to a numeric luminance array.

    ``gamma == 0`` means "no curve": values are only clipped to [0, 255] and
    cast. Otherwise each value becomes ``255 * (v / 255) ** (1 / (1.1 - gamma))``.
    NOTE(review): ``gamma == 1.1`` divides by zero, and larger values give a
    negative exponent. Callers presumably keep gamma below 1.1 — confirm.

    Returns:
        ``np.uint8`` array clipped to [0, 255] (fractions truncated).
    """
    if gamma != 0:
        exponent = 1 / (1.1 - gamma)
        curved = 255 * ((lum_array / 255) ** exponent)
    else:
        curved = lum_array
    return np.clip(curved, 0, 255).astype(np.uint8)


def apply_to_batch(func):
    """Decorator to apply ``func`` to each image in a batch.

    Input handling:
    - 4D torch.Tensor: treated as a batch (B, H, W, C); ``func`` runs per image.
    - 3D torch.Tensor: treated as a single image (H, W, C), i.e. a batch of one.
    - list/tuple of images: processed element-wise.
    - anything else (other tensor ranks, PIL images, ...): passed to ``func`` as-is.

    Per-image results are expected to already carry a leading batch dimension,
    e.g. the ``[1, H, W, C]`` tensors returned by ``pil2tensor``. They are
    concatenated along dim 0. In the pass-through case, a tensor result gets a
    batch dimension only if it does not already look batched (``ndim < 4``).
    This avoids turning an already-batched result into a 5D tensor.

    Returns a single-element tuple ``(batch_tensor,)`` to preserve the calling
    contract used elsewhere in the codebase. A non-tensor result from the
    pass-through case is returned unchanged.
    """
    def _as_batch(res):
        # Wrap a pass-through result without double-batching it.
        if isinstance(res, torch.Tensor):
            return (res if res.ndim >= 4 else res.unsqueeze(0),)
        return res

    @functools.wraps(func)
    def wrapper(self, image, *args, **kwargs):
        if isinstance(image, torch.Tensor) and image.ndim == 3:
            # A single (H, W, C) image: process it as a batch of one.
            image = image.unsqueeze(0)

        is_batch_tensor = isinstance(image, torch.Tensor) and image.ndim == 4
        if is_batch_tensor or isinstance(image, (list, tuple)):
            results = [func(self, img, *args, **kwargs) for img in image]
            return (torch.cat(results, dim=0),)

        return _as_batch(func(self, image, *args, **kwargs))

    return wrapper


def _split_alpha(img):
    """Return ``(rgb_image, alpha_or_None)``; ICC transforms expect RGB/LAB without alpha."""
    if 'A' in img.getbands():
        return img.convert('RGB'), img.getchannel('A')
    return (img if img.mode == 'RGB' else img.convert('RGB')), None


def _attach_alpha(img, alpha):
    """Re-attach a previously separated alpha channel; no-op when ``alpha`` is None."""
    if alpha is None:
        return img
    img = img.convert('RGBA')
    img.putalpha(alpha)
    return img


def _rgb_fallback(img_rgb, contrast, enhance_color, hdr_intensity):
    """Approximate the HDR look with plain RGB contrast/colour/brightness enhancers."""
    out = ImageEnhance.Contrast(img_rgb).enhance(1 + contrast)
    out = ImageEnhance.Color(out).enhance(1 + enhance_color * 0.2)
    return ImageEnhance.Brightness(out).enhance(1 + hdr_intensity * 0.1)


def _lab_tone_map(img_rgb, hdr_intensity, shadow_intensity, highlight_intensity, gamma_intensity):
    """Tone-map the Lab lightness channel of an RGB image and convert back to RGB.

    Exceptions (e.g. from the ImageCms profile transforms) propagate to the caller.
    """
    img_lab = ImageCms.profileToProfile(img_rgb, sRGB_profile, Lab_profile, outputMode='LAB')
    luminance, a, b = img_lab.split()
    lum_array = np.asarray(luminance, dtype=np.float32)

    shadows_adj = adjust_shadows_non_linear(luminance, shadow_intensity)
    highlights_adj = adjust_highlights_non_linear(luminance, highlight_intensity)
    merged = merge_adjustments_with_blend_modes(lum_array, shadows_adj, highlights_adj,
                                                 hdr_intensity, shadow_intensity, highlight_intensity)
    gamma_corr = Image.fromarray(apply_gamma_correction(np.asarray(merged), gamma_intensity)).resize(a.size)

    adjusted_lab = Image.merge('LAB', (gamma_corr, a, b))
    return ImageCms.profileToProfile(adjusted_lab, Lab_profile, sRGB_profile, outputMode='RGB')


class HDREffects:
    """Pseudo-HDR tone adjustment for image tensors.

    When ICC (LCMS) transforms are available, the effect is computed on the
    lightness channel of a Lab conversion. Otherwise a cheaper
    contrast/colour/brightness approximation is applied directly in RGB.
    """

    @apply_to_batch
    def apply_hdr2(self, image, hdr_intensity=0.75, shadow_intensity=0.25, highlight_intensity=0.5,
                   gamma_intensity=0.25, contrast=0.1, enhance_color=0.25):
        """Apply the HDR effect to one image (``apply_to_batch`` handles batching).

        Args:
            image: Image tensor accepted by ``tensor2pil``. Presumably (H, W, C)
                after ``apply_to_batch`` iterates the batch; values are clamped to [0, 1].
            hdr_intensity: Blend weight between original and tone-mapped lightness.
                In the RGB fallback it instead drives a brightness factor
                ``1 + 0.1 * hdr_intensity``.
            shadow_intensity: Strength of the shadow darkening mask.
            highlight_intensity: Strength of the highlight brightening mask.
            gamma_intensity: Gamma curve control; 0 disables the curve.
            contrast: Contrast enhance factor is ``1 + contrast``.
            enhance_color: Colour enhance factor is ``1 + 0.2 * enhance_color``.

        Returns:
            The ``pil2tensor`` result for the adjusted image. Any alpha channel
            is preserved unmodified. ``apply_to_batch`` wraps results into
            ``(batch_tensor,)``.
        """
        global _HAVE_LCMS
        img_rgb, alpha = _split_alpha(tensor2pil(image))

        if _HAVE_LCMS:
            try:
                tone_mapped = _lab_tone_map(img_rgb, hdr_intensity, shadow_intensity,
                                            highlight_intensity, gamma_intensity)
            except Exception:
                logger.exception("AutoHDR: profile transform failed; using RGB fallback")
                # Disable LCMS after a runtime failure to avoid repeated exceptions.
                # Later images (including the rest of this batch) take the RGB path.
                _HAVE_LCMS = False
            else:
                # Enhance the RGB data and re-attach alpha last (as the fallback
                # does), so the enhancers never see or alter the alpha channel.
                img_adjusted = ImageEnhance.Contrast(tone_mapped).enhance(1 + contrast)
                img_adjusted = ImageEnhance.Color(img_adjusted).enhance(1 + enhance_color * 0.2)
                return pil2tensor(_attach_alpha(img_adjusted, alpha))

        img_adjusted = _rgb_fallback(img_rgb, contrast, enhance_color, hdr_intensity)
        return pil2tensor(_attach_alpha(img_adjusted, alpha))