# balthou's picture
# update interactive pipe version
# 3ea2f51
from interactive_pipe import interactive_pipeline, interactive
from interactive_pipe import CircularControl
from interactive_pipe.data_objects.curves import SingleCurve, Curve
from interactive_pipe.data_objects.image import Image
from global_tone_mapping import apply_s_curve_tone_mapping
from color_conversions import rgb_to_hsv, hsv_to_rgb
from synthetic_charts import generate_color_wheel
import argparse
import numpy as np
def histogram(img: np.ndarray) -> Curve:
    """Build a per-channel intensity histogram of ``img`` as a plottable curve.

    Each channel along the last axis is binned into 128 equal-width bins over
    [0, 1]; ``np.histogram`` ignores values outside that range. Channel ``i``
    is drawn with the ``i``-th letter of ``'rgb'`` as its color, so at most
    three channels are supported (a fourth channel raises ``IndexError``).

    Args:
        img: image array whose last axis indexes channels; intensities are
            expected to lie in [0, 1].

    Returns:
        A ``Curve`` with one ``SingleCurve`` per channel: left bin edges on
        the x axis, bin counts on the y axis.
    """
    channel_colors = 'rgb'
    per_channel = []
    for channel_index in range(img.shape[-1]):
        counts, edges = np.histogram(
            img[..., channel_index].flatten(), bins=128, range=(0, 1)
        )
        # Use left bin edges as x so x and counts have the same length.
        per_channel.append(
            SingleCurve(
                edges[:-1],
                counts,
                style=channel_colors[channel_index],
                linestyle="-",
            )
        )
    return Curve(per_channel, xlabel="Intensity", ylabel="Frequency")
@interactive()
def set_tone_mapping_params(
    shadow_boost: float = CircularControl(0., [-1., 1.]),
    highlight_boost: float = (0., [-1., 1.]),
    exposure: float = (0., [-1., 1.]),
    contrast: float = (0., [-1., 1.]),
    global_params: dict = {}
) -> None:
    """Publish the S-curve tone-mapping controls into ``global_params``.

    The parameter defaults declare the GUI controls: ``shadow_boost`` uses a
    ``CircularControl``, the others use ``(default, [min, max])`` tuples, and
    all span [-1, 1]. Nothing is returned; ``tone_map`` and
    ``s_curve_visualization`` read these values back from ``global_params``
    under the same keys.

    NOTE(review): ``global_params`` is assumed to be the shared state dict
    supplied by interactive_pipe (the framework is not shown here). Calling
    this function directly without passing a dict writes into the shared
    ``{}`` default.
    """
    published = (
        ("shadow_boost", shadow_boost),
        ("highlight_boost", highlight_boost),
        ("contrast", contrast),
        ("exposure", exposure),
    )
    for key, value in published:
        global_params[key] = value
@interactive(
    apply_tone_curve_on_luma=(True,)
)
def tone_map(
    hsv: np.ndarray,
    apply_tone_curve_on_luma: bool = True,
    global_params={}
) -> np.ndarray:
    """Apply the S-curve tone mapping to an HSV image.

    The curve parameters (``shadow_boost``, ``highlight_boost``, ``contrast``,
    ``exposure``) are read from ``global_params``; each defaults to 0 when
    absent.

    Args:
        hsv: HSV image whose last axis holds the channels. In luma mode only
            the last channel is remapped (assumed to be V; confirm against
            ``rgb_to_hsv``).
        apply_tone_curve_on_luma: if True, remap only the last (V) channel.
            If False, convert to RGB, apply the curve to every RGB channel and
            convert back, which can cause noticeable color shifts.
        global_params: shared parameter dict holding the curve settings.

    Returns:
        A new HSV array. This function does not write into ``hsv``.
    """
    shadow_boost = global_params.get("shadow_boost", 0.)
    highlight_boost = global_params.get("highlight_boost", 0.)
    contrast = global_params.get("contrast", 0.)
    exposure = global_params.get("exposure", 0.)
    curve_args = (shadow_boost, highlight_boost, contrast, exposure)

    if not apply_tone_curve_on_luma:
        # Per-channel RGB mapping (kept for comparison with the luma mode).
        return rgb_to_hsv(
            apply_s_curve_tone_mapping(hsv_to_rgb(hsv), *curve_args)
        )

    # Write into a copy: the caller may still hold `hsv` (e.g. an upstream
    # result that is reused across GUI updates). Mutating it in place would
    # make the curve compound on every re-run.
    hsv_out = hsv.copy()
    hsv_out[..., -1] = apply_s_curve_tone_mapping(hsv[..., -1], *curve_args)
    return hsv_out
@interactive(
    vibrance=(0, [0, 100]),
    # To expose the mask smoothness as a control, add:
    # vibrance_mask_smoothness=(3., [-0.9, 4.])
)
def modify_vibrance(
    hsv_in: np.ndarray,
    vibrance: float = 0.,
    vibrance_mask_smoothness: float = 3.
) -> np.ndarray:
    """Selectively boost saturation ("vibrance") of an HSV image.

    Saturation is multiplied by ``1 + m * vibrance / 100`` and clipped to
    [0, 1]. The per-pixel mask is
    ``m = luma_mask * clip(sat_mask + hue_mask, 0, 1)``, where:

    * ``luma_mask`` is 0 for V <= 0.5 and ramps linearly to 1 at V = 1;
    * ``sat_mask`` is 0 for S <= 0.25 and ramps linearly to 1 at S = 1;
    * ``hue_mask`` is 0 on [0, 55] and [350, 360] and 1 on [65, 340], with
      linear 10-degree transitions. The threshold names suggest red/orange
      and pink hues are the protected ranges.

    Both ``sat_mask`` and ``hue_mask`` are raised to the power
    ``1 + vibrance_mask_smoothness``.

    Args:
        hsv_in: HSV image with channels (H, S, V) on the last axis. The 360
            knot implies H is in degrees; confirm against ``rgb_to_hsv``.
        vibrance: boost strength, 0 (no change) to 100 (gain up to 2x).
        vibrance_mask_smoothness: exponent offset that sharpens the masks.

    Returns:
        A new HSV array with modified saturation; ``hsv_in`` is not modified.
    """
    hue = hsv_in[..., 0]
    sat = hsv_in[..., 1]
    val = hsv_in[..., 2]
    exponent = 1 + vibrance_mask_smoothness

    gain_luma = np.interp(val, [0., 0.5, 1.], [0., 0., 1.])
    gain_sat = np.interp(sat, [0., 0.25, 1.], [0., 0., 1.]) ** exponent

    red_thresh = 55.
    pink_thresh = 10.
    overlap = 10.
    gain_hue = np.interp(
        hue,
        [0., red_thresh, red_thresh + overlap,
         360 - pink_thresh - overlap, 360 - pink_thresh, 360],
        [0., 0., 1., 1., 0., 0.]
    ) ** exponent

    mask = (gain_sat + gain_hue).clip(0., 1.)
    gain = 1. + gain_luma * mask * (vibrance / 100)

    # Copy before writing so the caller's array (possibly a reused upstream
    # result) is not altered. Otherwise repeated vibrance edits accumulate.
    hsv_out = hsv_in.copy()
    hsv_out[..., 1] = np.clip(gain * sat, 0, 1)
    return hsv_out
def s_curve_visualization(global_params={}) -> Curve:
    """Plot the current S tone curve against the identity line.

    The curve parameters are read from ``global_params``, each defaulting to
    0. The curve is sampled at 256 evenly spaced inputs in [0, 1].

    Returns:
        A ``Curve`` with the tone curve (magenta, solid) and the identity
        ``y = x`` (black, dashed). The title shows the four parameter values
        to two decimals.
    """
    param_keys = ("shadow_boost", "highlight_boost", "contrast", "exposure")
    shadow_boost, highlight_boost, contrast, exposure = (
        global_params.get(key, 0.) for key in param_keys
    )

    inputs = np.linspace(0, 1, 256)
    outputs = apply_s_curve_tone_mapping(
        inputs, shadow_boost, highlight_boost, contrast, exposure
    )

    tone_curve = SingleCurve(inputs, outputs, style='m', linestyle='-')
    identity = SingleCurve(inputs, inputs, style='k', linestyle='--')
    title = (
        f"Tone Curve {shadow_boost:.2f}, {highlight_boost:.2f}, "
        f"{contrast:.2f}, {exposure:.2f}"
    )
    return Curve(
        [tone_curve, identity],
        xlabel="Input",
        ylabel="Output",
        title=title,
    )
@interactive(
    input_image=("sample image", ["sample image",
                                  "alternative image", "color wheel"]),
)
def pick_image(
    default_image: np.ndarray,
    second_image: np.ndarray,
    input_image: str = "sample image",
    global_params={}
) -> np.ndarray:
    """Return the image selected by the ``input_image`` dropdown.

    - "alternative image" returns ``second_image``.
    - "color wheel" returns a synthetic color wheel (resolution 496). It is
      generated on first use and then cached in ``global_params`` under the
      "color wheel" key.
    - "sample image", or any other value, returns ``default_image``.
    """
    if input_image == "alternative image":
        return second_image
    if input_image == "color wheel":
        wheel = global_params.get("color wheel", None)
        if wheel is None:
            # Generate lazily and cache so later runs reuse the same array.
            wheel = generate_color_wheel(resolution=496)
            global_params["color wheel"] = wheel
        return wheel
    return default_image
def image_editing_pipeline(sample_image, second_image):
    # Pipeline handed to interactive_pipeline(). Steps:
    #   1. select the source image;
    #   2. publish the tone-curve parameters into global_params;
    #   3. tone-map and adjust vibrance in HSV space;
    #   4. convert back to RGB and build the plots.
    # Returns (processed RGB image, histogram Curve, tone-curve Curve).
    # NOTE(review): kept as a flat sequence of call/assign statements with no
    # docstring. interactive_pipe presumably parses this body to build its
    # filter graph (assumption; confirm before restructuring).
    input_image = pick_image(sample_image, second_image)
    set_tone_mapping_params()  # no output; writes the curve params into global_params
    hsv_in = rgb_to_hsv(input_image)
    hsv_tc = tone_map(hsv_in)
    hsv_out = modify_vibrance(hsv_tc)
    tc_image = hsv_to_rgb(hsv_out)
    tone_curve = s_curve_visualization()  # reads the params set above
    histogram_curve = histogram(tc_image)
    return tc_image, histogram_curve, tone_curve
if __name__ == "__main__":
    # CLI entry point: pick the GUI backend and the two input images, then
    # launch the interactive tone-mapping / vibrance editor.
    parser = argparse.ArgumentParser(
        description="Interactive S-curve tone mapping and vibrance demo."
    )
    parser.add_argument("-b", "--backend", type=str,
                        choices=["gradio", "qt", "mpl"], default="gradio",
                        help="GUI backend used by interactive_pipeline.")
    # Image paths were previously hard-coded; the defaults keep the old behavior.
    parser.add_argument("--image", type=str, default="image_sample_2.jpg",
                        help="Image shown for the 'sample image' choice.")
    parser.add_argument("--second-image", type=str, default="image_sample.jpg",
                        help="Image shown for the 'alternative image' choice.")
    args = parser.parse_args()

    img = Image.load_image(args.image)
    img2 = Image.load_image(args.second_image)

    markdown_description = "\n".join([
        r"# Tone mapping: S-curve",
        r"$$\text{shadow boost} \in [-1, 1]$$ $$\text{highlight boost} \in [-1, 1]$$ $$\text{contrast} \in [-1, 1]$$ $$\text{exposure} \in [-1, 1]$$",
        "If Tone curves are applied to RGB colors separately, you'll get serious color shifts. It is much better when tone curve is applied to the luminance channel.",
    ]) + "\n"

    interactive_pipeline(
        gui=args.backend,
        markdown_description=markdown_description,
        # Other layouts: "compact", "group", "collapsible", "smart".
        sliders_layout="vertical",
        sliders_per_row_layout=1
    )(image_editing_pipeline)(img, img2)