File size: 6,075 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
from rembg import remove
from PIL import Image, ImageOps, ImageEnhance, ImageStat
import torch
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights

# Function to unify the image using a pre-trained VGG19 model
def unify_image(combined_img):
    """Harmonize the composited image so the pasted subject blends into the scene.

    NOTE(review): the previous implementation fed the image through
    ``vgg19(weights=...).features`` and tried to de-normalize the output back
    into a picture. That cannot work: the feature extractor emits a
    512-channel activation map, so the 3-channel ``Normalize`` /
    ``ToPILImage`` post-processing raised at runtime whenever the "Unify"
    option was checked (and a full VGG19 was rebuilt on every call).  There
    is no way to invert VGG features without a trained decoder, so the
    harmonization is done with classic image processing instead: the
    composite is blended with a softly blurred copy of itself, which smooths
    the hard cut-out seam around the pasted subject while keeping the output
    at the input's original resolution.

    Args:
        combined_img: PIL image (any mode) containing the composite.

    Returns:
        An RGBA PIL image with the same size as ``combined_img``.
    """
    # Local import keeps the module header unchanged.
    from PIL import ImageFilter

    rgb = combined_img.convert("RGB")
    # A light blur knocks down the hard edge artifacts left by the cut-out.
    softened = rgb.filter(ImageFilter.GaussianBlur(radius=2))
    # Blend mostly-original with the softened copy so detail is preserved
    # while the seam between subject and background is visually unified.
    unified = Image.blend(rgb, softened, alpha=0.35)
    return unified.convert("RGBA")

def embed_person_on_background(person_img, background_img):
    """Composite *person_img* onto *background_img*, centered.

    The person image is resized with its aspect ratio preserved so it fits
    inside the background, then pasted centered.  Previously it was pasted
    at ``(0, 0)``, which pinned the subject to the top-left corner whenever
    the two aspect ratios differed.

    Args:
        person_img: RGBA PIL image (its alpha channel is used as the paste mask).
        background_img: PIL image providing the canvas and the final size.

    Returns:
        RGBA PIL image with the size of ``background_img``.
    """
    # Fit the person inside the background without distortion.
    person_img = ImageOps.contain(person_img, background_img.size, method=Image.LANCZOS)

    combined_img = Image.new("RGBA", background_img.size)
    combined_img.paste(background_img, (0, 0))
    # Center the (possibly smaller) person image on the canvas.
    offset = (
        (background_img.size[0] - person_img.size[0]) // 2,
        (background_img.size[1] - person_img.size[1]) // 2,
    )
    combined_img.paste(person_img, offset, person_img)

    return combined_img

def auto_match_enhancers(person_img, background_img):
    """Enhance *person_img* with factors derived from the background's mean color.

    Bug fixed: the old code constructed all three enhancers from the
    ORIGINAL image up front, so each ``enhance`` call discarded the previous
    result and only the last (Color) adjustment actually survived.  The
    enhancements are now chained — each one is applied to the output of the
    previous step.

    Args:
        person_img: PIL image to enhance.
        background_img: PIL image whose per-channel mean drives the factors.

    Returns:
        A new PIL image with contrast, brightness, and color applied in order.
    """
    stat = ImageStat.Stat(background_img)
    mean = stat.mean[:3]  # per-channel (R, G, B) mean of the background

    # Simple heuristic: darker channel means get a stronger boost.
    contrast = 1.5 if mean[0] < 128 else 1.2
    brightness = 1.2 if mean[1] < 128 else 1.1
    color = 1.3 if mean[2] < 128 else 1.0

    enhanced = person_img
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        # Build each enhancer from the CURRENT image so effects accumulate.
        enhanced = enhancer_cls(enhanced).enhance(factor)
    return enhanced

def enhance_image(image, contrast, brightness, color):
    """Apply contrast, brightness, and color enhancement in sequence.

    Bug fixed: as in ``auto_match_enhancers``, the old code built every
    enhancer from the ORIGINAL image before the loop ran, so each
    ``enhance`` call threw away the previous adjustment and only the color
    factor took effect.  Each enhancer is now constructed from the running
    result so all three factors accumulate.

    Args:
        image: PIL image to enhance.
        contrast: contrast factor (1.0 = unchanged).
        brightness: brightness factor (1.0 = unchanged).
        color: color-saturation factor (1.0 = unchanged).

    Returns:
        A new PIL image with all three enhancements applied in order.
    """
    enhanced = image
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        # Build each enhancer from the CURRENT image so effects accumulate.
        enhanced = enhancer_cls(enhanced).enhance(factor)
    return enhanced

def process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Run the full pipeline: cut out the person, optionally enhance, composite, optionally unify.

    Args:
        person_img: PIL image of the person.
        background_img: PIL image used as the backdrop.
        num_images: how many copies of the result to return (1-5).
        enhance: apply enhancement to the cut-out person before compositing.
        auto_match: derive enhancement factors from the background (wins over
            the manual factors when both flags are set).
        contrast, brightness, color: manual enhancement factors, used only
            when ``enhance`` is set and ``auto_match`` is not.
        unify: post-process the composite with ``unify_image``.

    Returns:
        A list of ``num_images`` references to the final composite image.

    Raises:
        ValueError: if ``num_images`` is outside [1, 5].
    """
    # Gradio sliders may deliver the value as a float; normalize it first —
    # list replication below requires an int.
    num_images = int(num_images)
    if not (1 <= num_images <= 5):
        raise ValueError("Number of Output Images must be between 1 and 5")

    # Remove background from the person image (rembg returns RGBA).
    person_no_bg = remove(person_img)

    if enhance and auto_match:
        print("Auto-matching enhancers based on the background color...")
        person_no_bg = auto_match_enhancers(person_no_bg, background_img)
    elif enhance:
        print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
        person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(person_no_bg, background_img)

    if unify:
        print("Unifying image with AI...")
        combined_img = unify_image(combined_img)

    # All entries reference the same (read-only) composite image.
    return [combined_img] * num_images

def gradio_interface(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Gradio callback: run the pipeline and pad the results to the 5 output slots.

    Bug fixed: on failure the old code returned ``str(e)`` as the first
    element of the outputs, but every output component is a ``gr.Image`` — a
    plain string is not a valid image value, so the user saw a secondary
    component error instead of the actual message.  Failures are now
    surfaced through ``gr.Error``, Gradio's mechanism for showing an error
    message in the UI.
    """
    try:
        results = process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify)
    except Exception as e:
        # Show the real failure in the UI instead of feeding a string to gr.Image.
        raise gr.Error(str(e)) from e
    # Pad with None so the return length always matches the 5 image components.
    return results + [None] * (5 - len(results))

def update_enhancement_controls(auto_match):
    """Toggle the manual enhancement sliders: locked while auto-match is on."""
    manual_sliders = (contrast_slider, brightness_slider, color_slider)
    return {slider: gr.update(interactive=not auto_match) for slider in manual_sliders}

# Create Gradio interface
# Layout (top to bottom): two upload images side by side, the pipeline
# options, five fixed output image slots, and a Run button wired to
# gradio_interface.  Component creation order defines the on-screen layout.
with gr.Blocks() as interface:
    with gr.Row():
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")
    # How many copies of the final composite to produce (validated again in
    # process_images).
    num_images = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Output Images")
    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)
    # Manual enhancement factors; interactivity is toggled off while
    # "Auto-Match Enhancers" is checked (see update_enhancement_controls).
    contrast_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Contrast")
    brightness_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Brightness")
    color_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Color")
    
    # Lock/unlock the three sliders whenever the auto-match checkbox changes.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match, outputs=[contrast_slider, brightness_slider, color_slider])
    
    unify = gr.Checkbox(label="Unify Image with AI", value=False)

    # Five fixed output slots; gradio_interface pads unused ones with None.
    outputs = [gr.Image(type="pil", label="Generated Image") for _ in range(5)]
    run_button = gr.Button("Run")
    
    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, num_images, enhance, auto_match, contrast_slider, brightness_slider, color_slider, unify],
        outputs=outputs
    )

if __name__ == "__main__":
    # NOTE(review): share=True requests a publicly reachable tunnel URL from
    # Gradio — confirm that exposing the app beyond localhost is intended.
    interface.launch(share=True)