File size: 3,876 Bytes
9f71d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import cv2
import numpy as np
import gradio as gr
import tempfile

def color_transfer(source, target, eps=1e-6):
    """Match the colour statistics of ``source`` to those of ``target``.

    Both images are converted to OpenCV's LAB space. Each LAB channel of
    ``source`` is then shifted to zero mean, rescaled by
    ``std_target / std_source`` and shifted to the target channel's mean.

    Args:
        source: BGR image whose colours are adjusted. Assumed to be 8-bit
            3-channel (the LAB result is clipped to 0..255 and converted back
            from 8-bit LAB) - TODO confirm for other dtypes.
        target: BGR image that supplies the colour statistics. It is resized
            to the width/height of ``source`` first.
        eps: Source channels whose standard deviation is <= ``eps`` are
            treated as constant and use a scale of 1. This avoids a division
            by zero that would otherwise produce inf/NaN pixels.

    Returns:
        A tuple ``(result_bgr, resized_target)``. ``result_bgr`` is a uint8
        BGR image the size of ``source``. ``resized_target`` is ``target``
        after resizing.
    """
    height, width = source.shape[:2]
    target = cv2.resize(target, (width, height))

    source_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)

    # cv2.meanStdDev returns (channels, 1) arrays; flatten them to
    # per-channel vectors so they broadcast over the (H, W, C) image.
    src_mean, src_std = cv2.meanStdDev(source_lab)
    tar_mean, tar_std = cv2.meanStdDev(target_lab)
    src_mean, src_std = src_mean.ravel(), src_std.ravel()
    tar_mean, tar_std = tar_mean.ravel(), tar_std.ravel()

    # A constant source channel has std 0. Dividing by it yields inf/NaN, and
    # casting NaN to uint8 is undefined. Such a channel is all zeros after
    # mean subtraction, so any finite scale maps it onto the target mean.
    is_flat = src_std <= eps
    scale = np.divide(tar_std, src_std, out=np.ones_like(tar_std), where=~is_flat)

    transfer_lab = (source_lab - src_mean) * scale + tar_mean

    # Back to 8-bit LAB (truncating, as before), then to BGR.
    transfer_lab = np.clip(transfer_lab, 0, 255).astype(np.uint8)
    result_bgr = cv2.cvtColor(transfer_lab, cv2.COLOR_LAB2BGR)

    return result_bgr, target  # Return resized target too

def _scale_channels(image_bgr, blue_scale, green_scale, red_scale):
    """Multiply each BGR channel by its factor and saturate (truncate) to uint8."""
    factors = np.array([blue_scale, green_scale, red_scale], dtype=np.float64)
    return np.clip(image_bgr * factors, 0, 255).astype(np.uint8)


def _adjust_brightness_contrast(image_bgr, brightness, contrast):
    """Return ``contrast * image + brightness``, rounded and clamped to uint8.

    cv2.convertScaleAbs is deliberately not used: it takes the absolute value
    before saturating. With a negative brightness, that reflects dark pixels
    upward (10 - 100 -> 90) instead of clamping them to 0.
    """
    adjusted = image_bgr.astype(np.float32) * contrast + brightness
    return np.clip(np.rint(adjusted), 0, 255).astype(np.uint8)


def _save_png_to_temp(image_bgr):
    """Write ``image_bgr`` to a new temporary ``.png`` file and return its path.

    The file is kept on disk (``delete=False``) so its path can be handed to
    the download component.

    Raises:
        gr.Error: If OpenCV fails to write the file.
    """
    # Close our handle immediately instead of leaking one descriptor per
    # call; cv2.imwrite opens the path by name on its own.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        output_path = temp_file.name
    if not cv2.imwrite(output_path, image_bgr):
        raise gr.Error("Failed to save the output image for download.")
    return output_path


def process_images(source_img, ref_img, red_scale, green_scale, blue_scale, brightness, contrast):
    """Run colour transfer plus user adjustments for the "Convert" button.

    Args:
        source_img: RGB numpy image to recolour, or None if nothing was uploaded.
        ref_img: RGB numpy image providing the colour style, or None.
        red_scale, green_scale, blue_scale: Per-channel multipliers applied
            after the colour transfer.
        brightness: Offset added to every channel after contrast scaling.
        contrast: Multiplier applied to every channel before the offset.

    Returns:
        ``(source_rgb, resized_ref_rgb, output_rgb, output_path)``, where
        ``output_path`` is a temporary PNG of the result.

    Raises:
        gr.Error: If either image is missing or the output cannot be saved.
    """
    if source_img is None or ref_img is None:
        raise gr.Error("Please upload both the source image and the reference image before clicking Convert.")

    # OpenCV works in BGR order.
    source_bgr = cv2.cvtColor(source_img, cv2.COLOR_RGB2BGR)
    ref_bgr = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)

    output_bgr, resized_ref_bgr = color_transfer(source_bgr, ref_bgr)

    output_bgr = _scale_channels(output_bgr, blue_scale, green_scale, red_scale)
    output_bgr = _adjust_brightness_contrast(output_bgr, brightness, contrast)

    # Convert all images back to RGB for display.
    source_rgb = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB)
    ref_rgb = cv2.cvtColor(resized_ref_bgr, cv2.COLOR_BGR2RGB)
    output_rgb = cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB)

    output_path = _save_png_to_temp(output_bgr)

    return source_rgb, ref_rgb, output_rgb, output_path


# Gradio UI: two image inputs, per-channel scale sliders, brightness/contrast
# sliders, a Convert button, three previews and a downloadable result file.
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Color Transfer: Style Match Between Images")

    with gr.Row():
        source_input = gr.Image(type="numpy", label="Upload Source Image")
        ref_input = gr.Image(type="numpy", label="Upload Reference Image")

    with gr.Row():
        red_slider = gr.Slider(0.0, 2.0, value=1.0, label="Red Scale")
        green_slider = gr.Slider(0.0, 2.0, value=1.0, label="Green Scale")
        blue_slider = gr.Slider(0.0, 2.0, value=1.0, label="Blue Scale")

    with gr.Row():
        brightness_slider = gr.Slider(-100, 100, value=0, label="Brightness")
        contrast_slider = gr.Slider(0.1, 3.0, value=1.0, label="Contrast")

    convert_button = gr.Button("Convert")

    with gr.Row():
        source_display = gr.Image(label="Source Image")
        ref_display = gr.Image(label="Resized Reference Image")
        output_display = gr.Image(label="Output Image")

    download_output = gr.File(label="Download Result Image")

    # Input order must match the parameter order of process_images.
    convert_button.click(
        fn=process_images,
        inputs=[
            source_input, ref_input,
            red_slider, green_slider, blue_slider,
            brightness_slider, contrast_slider,
        ],
        outputs=[source_display, ref_display, output_display, download_output],
    )

# Launch only when executed as a script, so importing this module (for
# instance to reuse `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()