|
|
import gradio as gr
|
|
|
from rembg import remove
|
|
|
from PIL import Image, ImageOps, ImageEnhance, ImageStat
|
|
|
import torch
|
|
|
import torchvision.transforms.functional as tf
|
|
|
from torchvision import transforms
|
|
|
import numpy as np
|
|
|
from src import model
|
|
|
|
|
|
|
|
|
def load_harmonization_model(pretrained_path):
    """Load the pretrained Harmonizer network in inference mode.

    Args:
        pretrained_path: Path to the saved state-dict (.pth) file.

    Returns:
        A ``model.Harmonizer`` in eval mode, on GPU when available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    harmonizer = model.Harmonizer().to(device)
    # map_location ensures a checkpoint saved on a CUDA machine still loads
    # on a CPU-only host (torch.load would otherwise raise).
    state = torch.load(pretrained_path, map_location=device)
    harmonizer.load_state_dict(state, strict=True)
    harmonizer.eval()
    return harmonizer
|
|
|
|
|
|
|
|
|
def load_enhancement_model(pretrained_path):
    """Load the pretrained Enhancer network in inference mode.

    Args:
        pretrained_path: Path to the saved state-dict (.pth) file.

    Returns:
        A ``model.Enhancer`` in eval mode, on GPU when available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enhancer = model.Enhancer().to(device)
    # map_location ensures a checkpoint saved on a CUDA machine still loads
    # on a CPU-only host (torch.load would otherwise raise).
    state = torch.load(pretrained_path, map_location=device)
    enhancer.load_state_dict(state, strict=True)
    enhancer.eval()
    return enhancer
|
|
|
|
|
|
|
|
|
def unify_image(combined_img, harmonizer):
    """Harmonize the composite image with the Harmonizer network.

    Args:
        combined_img: PIL image of the pasted-together composite.
        harmonizer: Loaded ``model.Harmonizer`` (see load_harmonization_model).

    Returns:
        RGBA PIL image of the harmonized composite, same size as the input.
    """
    original_size = combined_img.size

    # All-white mask: the whole frame is treated as the harmonization region.
    # (The previous `mask.point(lambda p: p > 0 and 255)` was a no-op on an
    # already all-255 mask and has been removed.)
    mask = Image.new("L", original_size, 255)

    to_tensor = transforms.ToTensor()
    comp = to_tensor(combined_img.convert("RGB")).unsqueeze(0)
    mask_t = to_tensor(mask).unsqueeze(0)

    if torch.cuda.is_available():
        comp = comp.cuda()
        mask_t = mask_t.cuda()

    with torch.no_grad():
        arguments = harmonizer.predict_arguments(comp, mask_t)
        # restore_image returns a sequence of progressively restored images;
        # the last element is the final result.
        harmonized = harmonizer.restore_image(comp, mask_t, arguments)[-1]

    # CHW float in [0, 1] -> HWC uint8.
    harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
    harmonized_img = Image.fromarray(harmonized.astype(np.uint8)).convert("RGBA")
    harmonized_img = harmonized_img.resize(original_size)
    return harmonized_img
|
|
|
|
|
|
|
|
|
def enhance_unified_image(harmonized_img, enhancer):
    """Run the Enhancer network over the full harmonized composite.

    Args:
        harmonized_img: PIL image produced by the harmonization step.
        enhancer: Loaded ``model.Enhancer`` (see load_enhancement_model).

    Returns:
        RGBA PIL image of the enhanced composite, same size as the input.
    """
    original_size = harmonized_img.size

    tensorize = transforms.Compose([transforms.ToTensor()])
    source = tensorize(harmonized_img.convert("RGB")).unsqueeze(0)
    # Full-coverage mask: enhance every pixel.
    coverage = torch.ones_like(source)

    if torch.cuda.is_available():
        source = source.cuda()
        coverage = coverage.cuda()

    with torch.no_grad():
        args = enhancer.predict_arguments(source, coverage)
        # The last element of restore_image's output is the final image.
        result = enhancer.restore_image(source, coverage, args)[-1]

    # CHW float in [0, 1] -> HWC uint8.
    pixels = np.transpose(result[0].cpu().numpy(), (1, 2, 0)) * 255
    out = Image.fromarray(pixels.astype(np.uint8)).convert("RGBA")
    return out.resize(original_size)
|
|
|
|
|
|
def embed_person_on_background(person_img, background_img, position_x, position_y, scale):
    """Scale the cut-out person and paste it onto the background.

    Default placement is horizontally centered and bottom-aligned;
    position_x/position_y are pixel offsets from that default.

    Args:
        person_img: RGBA PIL image of the person (alpha used as paste mask).
        background_img: PIL background image.
        position_x: Horizontal offset in pixels (positive = right).
        position_y: Vertical offset in pixels (positive = down).
        scale: Multiplier applied to the person's dimensions.

    Returns:
        RGBA PIL composite the same size as the background.
    """
    person_width, person_height = person_img.size
    # Clamp to >= 1 px: int(w * scale) can truncate to 0 for small images or
    # scales, and Image.resize raises on a zero dimension.
    new_width = max(1, int(person_width * scale))
    new_height = max(1, int(person_height * scale))
    person_img = person_img.resize((new_width, new_height), Image.LANCZOS)

    background_width, background_height = background_img.size

    # Default placement: centered horizontally, flush with the bottom edge.
    default_x = (background_width - new_width) // 2
    default_y = background_height - new_height

    paste_x = default_x + int(position_x)
    paste_y = default_y + int(position_y)

    combined_img = Image.new("RGBA", background_img.size)
    combined_img.paste(background_img, (0, 0))
    # Pass person_img as its own mask so its alpha channel is respected.
    combined_img.paste(person_img, (paste_x, paste_y), person_img)
    return combined_img
|
|
|
|
|
|
def auto_match_enhancers(person_img, background_img):
    """Derive enhancement factors from the background and apply them to the person.

    Heuristic: a darker background channel mean triggers a stronger factor.
    NOTE(review): the R, G and B channel means drive contrast, brightness and
    color respectively — presumably intentional, but an overall-luminance mean
    may have been meant; confirm with the author.

    Args:
        person_img: PIL image of the cut-out person.
        background_img: PIL background image used to pick the factors.

    Returns:
        New PIL image with contrast, brightness and color adjustments applied.
    """
    mean = ImageStat.Stat(background_img).mean[:3]

    contrast = 1.5 if mean[0] < 128 else 1.2
    brightness = 1.2 if mean[1] < 128 else 1.1
    color = 1.3 if mean[2] < 128 else 1.0

    # BUG FIX: previously all three enhancers wrapped the ORIGINAL image, so
    # each enhance() call discarded the prior result and only the final Color
    # adjustment survived. Chain each step on the previous step's output.
    enhanced = ImageEnhance.Contrast(person_img).enhance(contrast)
    enhanced = ImageEnhance.Brightness(enhanced).enhance(brightness)
    enhanced = ImageEnhance.Color(enhanced).enhance(color)
    return enhanced
|
|
|
|
|
|
def enhance_image(image, contrast, brightness, color):
    """Apply user-chosen enhancement factors to an image (1.0 = unchanged).

    Args:
        image: PIL image to adjust.
        contrast: Contrast factor.
        brightness: Brightness factor.
        color: Color (saturation) factor.

    Returns:
        New PIL image with all three adjustments applied in sequence.
    """
    # BUG FIX: the old loop built every enhancer from the original image, so
    # earlier adjustments were thrown away and only the Color step took
    # effect. Chain each enhancer on the previous result instead.
    enhanced = ImageEnhance.Contrast(image).enhance(contrast)
    enhanced = ImageEnhance.Brightness(enhanced).enhance(brightness)
    enhanced = ImageEnhance.Color(enhanced).enhance(color)
    return enhanced
|
|
|
|
|
|
def process_images(person_img, background_img, enhance, auto_match, contrast, brightness, color, unify, position_x, position_y, scale):
    """Full pipeline: remove background, optionally enhance, composite, and
    optionally harmonize + enhance the result with the neural models.

    Args:
        person_img: Uploaded person image (PIL).
        background_img: Uploaded background image (PIL).
        enhance: Whether to apply PIL enhancement to the cut-out person.
        auto_match: When enhancing, derive factors from the background.
        contrast, brightness, color: Manual enhancement factors.
        unify: Whether to run the harmonizer/enhancer networks.
        position_x, position_y: Pixel offsets from the default placement.
        scale: Scale factor for the person.

    Returns:
        The final composite as a PIL image.
    """
    # rembg strips the background, returning an RGBA cut-out.
    person_no_bg = remove(person_img)

    if enhance:
        if auto_match:
            print("Auto-matching enhancers based on the background color...")
            person_no_bg = auto_match_enhancers(person_no_bg, background_img)
        else:
            print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
            person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(person_no_bg, background_img, position_x, position_y, scale)

    if unify:
        print("Unifying image with AI...")
        # Harmonize colors first, then run the enhancement network on top.
        combined_img = unify_image(combined_img, load_harmonization_model('pretrained/harmonizer.pth'))
        combined_img = enhance_unified_image(combined_img, load_enhancement_model('pretrained/enhancer.pth'))

    return combined_img
|
|
|
|
|
|
def gradio_interface(person_img, background_img, enhance, auto_match, contrast, brightness, color, unify, position_x, position_y, scale):
    """Gradio entry point wrapping process_images with error reporting.

    Returns the composite PIL image on success. On failure, raises
    ``gr.Error`` so gradio shows the message to the user — the previous
    ``return str(e)`` fed a plain string into the Image output component,
    which triggered a second, confusing failure in the UI.
    """
    try:
        return process_images(person_img, background_img, enhance, auto_match, contrast, brightness, color, unify, position_x, position_y, scale)
    except Exception as e:
        raise gr.Error(str(e)) from e
|
|
|
|
|
|
def update_enhancement_controls(auto_match):
    """Toggle the manual sliders: disabled while auto-match is on.

    Returns a gradio update dict keyed by the module-level slider components.
    """
    interactive = not auto_match
    sliders = (contrast_slider, brightness_slider, color_slider)
    return {slider: gr.update(interactive=interactive) for slider in sliders}
|
|
|
|
|
|
|
|
|
# UI layout: inputs, enhancement controls, placement sliders, and the run
# button wired to gradio_interface.
with gr.Blocks(css='#output_image {max-width: 800px !important; width: auto !important; height: auto !important;}') as interface:
    # The two source images side by side.
    with gr.Row():
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")

    # Enhancement controls: manual sliders plus an auto-match toggle.
    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)
    contrast_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Contrast")
    brightness_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Brightness")
    color_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Color")

    # Grey out the manual sliders while auto-match is enabled.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match, outputs=[contrast_slider, brightness_slider, color_slider])

    unify = gr.Checkbox(label="Unify Image with AI", value=True)
    # Placement offsets are relative to the default (centered, bottom-aligned)
    # position computed in embed_person_on_background.
    position_x = gr.Slider(minimum=-500, maximum=500, step=1, value=0, label="Horizontal Position (pixels)")
    position_y = gr.Slider(minimum=-500, maximum=500, step=1, value=0, label="Vertical Position (pixels)")
    scale = gr.Slider(minimum=0.1, maximum=3.0, step=0.1, value=1.0, label="Scale")

    output = gr.Image(type="pil", label="Generated Image", elem_id="output_image")
    run_button = gr.Button("Run")

    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, enhance, auto_match, contrast_slider, brightness_slider, color_slider, unify, position_x, position_y, scale],
        outputs=output
    )


if __name__ == "__main__":
    interface.launch()
|
|
|
|