File size: 6,196 Bytes
f67bbda
3ea2f51
f67bbda
 
 
41dda1b
 
bb8764b
f67bbda
 
 
 
 
 
 
41dda1b
 
 
 
 
dde689b
41dda1b
 
 
 
387b47b
dde689b
41dda1b
f67bbda
 
 
 
 
3ea2f51
 
f67bbda
 
bb8764b
f67bbda
 
 
 
bb8764b
f67bbda
 
 
3ea2f51
 
 
fd96015
 
41dda1b
 
 
f67bbda
 
bb8764b
 
41dda1b
fd96015
 
 
 
 
 
 
 
bb8764b
41dda1b
 
f67bbda
 
bb8764b
f67bbda
 
41dda1b
fd96015
 
 
0ddc74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd96015
 
 
 
 
 
 
0ddc74f
 
fd96015
 
f67bbda
 
 
 
 
bb8764b
 
f67bbda
 
 
bb8764b
f67bbda
 
 
dde689b
 
f67bbda
dde689b
 
f67bbda
 
 
41dda1b
3ea2f51
 
41dda1b
 
 
3ea2f51
 
41dda1b
 
3ea2f51
41dda1b
3ea2f51
 
 
 
41dda1b
 
3ea2f51
41dda1b
 
 
 
 
3ea2f51
 
f67bbda
fd96015
 
 
 
f67bbda
 
 
 
 
 
bb8764b
 
 
 
3ea2f51
 
4a55925
 
 
3ea2f51
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from interactive_pipe import interactive_pipeline, interactive
from interactive_pipe import CircularControl
from interactive_pipe.data_objects.curves import SingleCurve, Curve
from interactive_pipe.data_objects.image import Image
from global_tone_mapping import apply_s_curve_tone_mapping
from color_conversions import rgb_to_hsv, hsv_to_rgb
from synthetic_charts import generate_color_wheel
import argparse
import numpy as np


def histogram(img: np.ndarray) -> Curve:
    """Return the per-channel intensity histogram of ``img`` as a ``Curve``.

    For every index ``ch`` along the last axis of ``img`` a 128-bin histogram
    restricted to the range [0, 1] is computed. Each histogram is plotted
    against the left edge of its bins, using the style letter ``'rgb'[ch]``
    and a solid line.

    Args:
        img: Array whose last axis enumerates the channels.

    Returns:
        A ``Curve`` labelled "Intensity" / "Frequency" holding one
        ``SingleCurve`` per channel.
    """
    channel_styles = 'rgb'

    def _channel_curve(ch: int) -> SingleCurve:
        # np.histogram returns len(counts) + 1 edges; drop the last edge so
        # abscissa and ordinate have the same length.
        counts, edges = np.histogram(
            img[..., ch].flatten(), bins=128, range=(0, 1)
        )
        return SingleCurve(
            edges[:-1], counts, style=channel_styles[ch], linestyle="-"
        )

    return Curve(
        [_channel_curve(ch) for ch in range(img.shape[-1])],
        xlabel="Intensity",
        ylabel="Frequency",
    )


@interactive()
def set_tone_mapping_params(
    # Plain-slider alternative for shadow_boost: (0., [-1., 1.])
    shadow_boost: float = CircularControl(0., [-1., 1.]),
    highlight_boost: float = (0., [-1., 1.]),
    exposure: float = (0., [-1., 1.]),
    contrast: float = (0., [-1., 1.]),
    global_params: dict = {}
) -> None:
    """Store the four tone-mapping settings in ``global_params``.

    The values are written under the keys "shadow_boost", "highlight_boost",
    "contrast" and "exposure", which are the keys ``tone_map`` and
    ``s_curve_visualization`` read back.

    NOTE(review): the tuple / ``CircularControl`` defaults are presumably
    turned into GUI controls (default value, [min, max]) by the
    ``@interactive`` decorator, and ``global_params`` is presumably injected
    by interactive_pipe -- confirm against the library documentation.
    """
    settings = (
        ("shadow_boost", shadow_boost),
        ("highlight_boost", highlight_boost),
        ("contrast", contrast),
        ("exposure", exposure),
    )
    for key, value in settings:
        global_params[key] = value


@interactive(
    apply_tone_curve_on_luma=(True,)
)
def tone_map(
    hsv: np.ndarray,
    apply_tone_curve_on_luma: bool = True,
    global_params={}
) -> np.ndarray:
    """Apply the S-curve tone mapping to an HSV image.

    The curve parameters ("shadow_boost", "highlight_boost", "contrast",
    "exposure") are read from ``global_params``; each defaults to 0 when
    the key is absent.

    Args:
        hsv: HSV image; the last channel is the one tone mapped in luma
            mode. The array is never modified.
        apply_tone_curve_on_luma: When true, only the last channel of
            ``hsv`` goes through the curve. When false, the image is
            converted to RGB, the curve is applied to the RGB data, and the
            result is converted back to HSV.
        global_params: Shared parameter dictionary (read only here).

    Returns:
        A new HSV array holding the tone-mapped image.
    """
    curve_args = (
        global_params.get("shadow_boost", 0.),
        global_params.get("highlight_boost", 0.),
        global_params.get("contrast", 0.),
        global_params.get("exposure", 0.),
    )
    if not apply_tone_curve_on_luma:
        return rgb_to_hsv(
            apply_s_curve_tone_mapping(hsv_to_rgb(hsv), *curve_args)
        )
    # Work on a copy: writing into ``hsv`` would alter the caller's buffer
    # (possibly a cached upstream result), so the curve would compound on
    # every re-run. This also makes both branches return a fresh array.
    hsv_out = hsv.copy()
    hsv_out[..., -1] = apply_s_curve_tone_mapping(
        hsv_out[..., -1], *curve_args
    )
    return hsv_out


@interactive(
    vibrance=(0, [0, 100]),
    # vibrance_mask_smoothness=(3., [-0.9, 4.])
)
def modify_vibrance(
    hsv_in: np.ndarray,
    vibrance: float = 0.,
    vibrance_mask_smoothness: float = 3.
) -> np.ndarray:
    """Boost the saturation of an HSV image with a selective gain.

    The saturation gain is ``1 + luma_mask * clip(sat_mask + hue_mask, 0, 1)
    * vibrance / 100`` where:

    * ``luma_mask`` is 0 for values up to 0.5, then ramps linearly to 1 at 1;
    * ``sat_mask`` is 0 for saturations up to 0.25, then ramps linearly to 1
      at 1, raised to the power ``1 + vibrance_mask_smoothness``;
    * ``hue_mask`` is 0 on [0, 55] and [350, 360], 1 on [65, 340], with
      linear ramps in between, raised to the same power.

    Args:
        hsv_in: HSV image with hue in channel 0, saturation in channel 1 and
            value in channel 2. The hue breakpoints assume hue is expressed
            in degrees over [0, 360] -- TODO confirm against ``rgb_to_hsv``.
            The array is never modified.
        vibrance: Strength of the effect, as a percentage.
        vibrance_mask_smoothness: Exponent offset applied to the saturation
            and hue masks.

    Returns:
        A new HSV array whose saturation channel is the boosted saturation
        clipped to [0, 1]; hue and value are copied unchanged.
    """
    mask_exponent = 1 + vibrance_mask_smoothness
    gain_luma = np.interp(
        hsv_in[..., 2],
        [0., 0.5, 1.],
        [0., 0., 1.]
    )
    gain_sat = np.interp(
        hsv_in[..., 1],
        [0., 0.25, 1.],
        [0., 0., 1.]
    )**mask_exponent
    red_thresh = 55.
    pink_thresh = 10.
    overlap = 10.

    gain_hue = np.interp(
        hsv_in[..., 0],
        [0., red_thresh, red_thresh+overlap,
         360-pink_thresh-overlap, 360-pink_thresh, 360],
        [0., 0., 1., 1., 0., 0.]
    )**mask_exponent
    gain = 1. + gain_luma * (gain_sat + gain_hue).clip(0., 1.)*(vibrance / 100)
    # Write into a copy: mutating ``hsv_in`` would alter the caller's buffer
    # (possibly a cached upstream result), so the gain would compound each
    # time this step is re-run.
    hsv_out = hsv_in.copy()
    hsv_out[..., 1] = np.clip(gain*hsv_in[..., 1], 0, 1)
    return hsv_out


def s_curve_visualization(global_params={}) -> Curve:
    """Plot the current tone curve against the identity line.

    The four curve parameters are read from ``global_params`` (0 when a key
    is missing). The curve is evaluated on 256 evenly spaced inputs in
    [0, 1].

    Returns:
        A ``Curve`` with the tone curve (solid magenta), the identity
        (dashed black), and a title listing the four parameter values.
    """
    shadow_boost, highlight_boost, contrast, exposure = (
        global_params.get(key, 0.)
        for key in ("shadow_boost", "highlight_boost", "contrast", "exposure")
    )
    ramp = np.linspace(0, 1, 256)
    mapped = apply_s_curve_tone_mapping(
        ramp, shadow_boost, highlight_boost, contrast, exposure
    )
    tone_trace = SingleCurve(ramp, mapped, style='m', linestyle='-')
    identity_trace = SingleCurve(ramp, ramp, style='k', linestyle='--')
    return Curve(
        [tone_trace, identity_trace],
        xlabel="Input", ylabel="Output",
        title=f"Tone Curve {shadow_boost:.2f}, {highlight_boost:.2f}, {contrast:.2f}, {exposure:.2f}"
    )


@interactive(
    input_image=("sample image", ["sample image",
                 "alternative image", "color wheel"]),
)
def pick_image(
    default_image: np.ndarray,
    second_image: np.ndarray,
    input_image: str = "sample image",
    global_params={}
) -> np.ndarray:
    """Return the image designated by ``input_image``.

    Args:
        default_image: Returned for "sample image" and for any unrecognised
            choice.
        second_image: Returned for "alternative image".
        input_image: One of "sample image", "alternative image" or
            "color wheel".
        global_params: Shared dictionary; the generated color wheel is kept
            under the key "color wheel" so it is only built once.

    Returns:
        The selected image.
    """
    if input_image == "alternative image":
        return second_image
    if input_image == "color wheel":
        # Build the synthetic chart only if no wheel is stored yet.
        if global_params.get("color wheel", None) is None:
            global_params["color wheel"] = generate_color_wheel(resolution=496)
        return global_params["color wheel"]
    # "sample image" and every unknown value fall back to the default image.
    return default_image


# Processing chain handed to ``interactive_pipeline`` in the ``__main__``
# block: pick an image, tone map it in HSV, adjust vibrance, convert back to
# RGB. Returns the edited RGB image, its histogram and the tone-curve plot.
# NOTE(review): documented with ``#`` comments rather than a docstring, and
# variable names left untouched, because interactive_pipe presumably inspects
# this function body to build its graph -- confirm before restructuring.
def image_editing_pipeline(sample_image, second_image):
    # Select which of the inputs (or the color wheel) is edited.
    input_image = pick_image(sample_image, second_image)
    # Called for its side effect only: stores the tone-curve settings in the
    # shared ``global_params`` read by tone_map and s_curve_visualization.
    set_tone_mapping_params()
    hsv_in = rgb_to_hsv(input_image)
    hsv_tc = tone_map(hsv_in)
    hsv_out = modify_vibrance(hsv_tc)
    tc_image = hsv_to_rgb(hsv_out)
    # Diagnostic plots: the tone curve itself and the histogram of the result.
    tone_curve = s_curve_visualization()
    histogram_curve = histogram(tc_image)
    return tc_image, histogram_curve, tone_curve


if __name__ == "__main__":
    # Command line: the only option is the GUI backend.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b", "--backend",
        type=str,
        choices=["gradio", "qt", "mpl"],
        default="gradio",
    )
    args = parser.parse_args()

    # Both sample pictures are looked up relative to the working directory.
    img = Image.load_image("image_sample_2.jpg")
    img2 = Image.load_image("image_sample.jpg")

    # Markdown shown with the GUI: title, parameter ranges, and a remark on
    # luminance vs. per-channel tone mapping, each followed by a newline.
    markdown_description = "\n".join([
        r"# Tone mapping: S-curve",
        r"$$\text{shadow boost} \in [-1, 1]$$  $$\text{highlight boost} \in [-1, 1]$$  $$\text{contrast} \in [-1, 1]$$  $$\text{exposure} \in [-1, 1]$$",
        "If Tone curves are applied to RGB colors separately, you'll get serious color shifts. It is much better when tone curve is applied to the luminance channel.",
    ]) + "\n"

    # Available layouts, kept as a list so switching is a one-index change.
    layout_choices = ["compact", "group", "vertical", "collapsible", "smart"]
    launch = interactive_pipeline(
        gui=args.backend,
        markdown_description=markdown_description,
        sliders_layout=layout_choices[2],
        sliders_per_row_layout=1
    )
    launch(image_editing_pipeline)(img, img2)