|
|
import gradio as gr
|
|
|
from rembg import remove
|
|
|
from PIL import Image, ImageOps, ImageEnhance, ImageStat
|
|
|
import torch
|
|
|
from torchvision import transforms
|
|
|
from torchvision.models import vgg19, VGG19_Weights
|
|
|
|
|
|
|
|
|
def _load_vgg_features():
    # Lazily build and memoize the VGG19 feature extractor so repeated
    # calls to unify_image() do not re-download / rebuild ~550 MB of
    # weights on every invocation.
    cached = getattr(_load_vgg_features, "_cache", None)
    if cached is None:
        weights = VGG19_Weights.IMAGENET1K_V1
        model = vgg19(weights=weights).features.eval()
        cached = (weights, model)
        _load_vgg_features._cache = cached
    return cached


def unify_image(combined_img):
    """Run the composite through VGG19 features and render the result as an
    RGBA image with the same size as the input.

    NOTE(review): the original code de-normalized the raw feature map with
    3-channel ImageNet statistics, but ``vgg19().features`` emits a
    512-channel map, so ``transforms.Normalize`` raised and the function
    could never succeed. We now visualize the first three feature channels
    (min-max scaled) instead; this is a coarse "AI unification" effect, not
    a faithful reconstruction of the input.

    Args:
        combined_img: PIL image (any mode) to process.

    Returns:
        An RGBA PIL image of ``combined_img.size``.
    """
    weights, model = _load_vgg_features()

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Use the normalization stats shipped with the weights rather than
        # hard-coded constants so they always match the checkpoint.
        transforms.Normalize(mean=weights.meta["mean"], std=weights.meta["std"]),
    ])

    input_tensor = preprocess(combined_img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        features = model(input_tensor).squeeze(0)  # shape: (512, h, w)

    # Keep the first three channels and min-max scale them into [0, 1] so
    # ToPILImage can render them as RGB (the +1e-8 guards a flat map).
    rgb = features[:3]
    lo, hi = rgb.min(), rgb.max()
    rgb = (rgb - lo) / (hi - lo + 1e-8)

    unified_img = transforms.ToPILImage()(rgb.cpu())
    # Restore the caller's resolution and alpha channel.
    unified_img = unified_img.resize(combined_img.size, Image.LANCZOS)
    return unified_img.convert("RGBA")
|
|
|
|
|
|
def embed_person_on_background(person_img, background_img):
    """Scale the cut-out person to fit inside the background and paste it
    (anchored at the top-left corner) onto a copy of the background.

    Args:
        person_img: PIL image of the person; its alpha channel is used as
            the paste mask (converted to RGBA if it has none).
        background_img: PIL image used as the backdrop; its size defines
            the output size.

    Returns:
        A new RGBA image of ``background_img.size``.
    """
    # Resize so the person fits entirely within the background while
    # preserving aspect ratio (ImageOps.contain never crops).
    person_img = ImageOps.contain(person_img, background_img.size, method=Image.LANCZOS)

    # Pasting with a mask requires an alpha channel; rembg output is RGBA
    # already, but convert defensively so an RGB input does not raise.
    if person_img.mode != "RGBA":
        person_img = person_img.convert("RGBA")

    combined_img = Image.new("RGBA", background_img.size)
    combined_img.paste(background_img, (0, 0))
    # The person's own alpha masks the paste so the background shows
    # through the transparent pixels.
    combined_img.paste(person_img, (0, 0), person_img)

    return combined_img
|
|
|
|
|
|
def auto_match_enhancers(person_img, background_img):
    """Derive enhancement factors from the background's color statistics and
    apply them to the person image.

    Darker backgrounds (per-channel mean below 128) get stronger factors.
    NOTE(review): contrast keys off the red mean, brightness off the green
    mean, and color off the blue mean — preserved as-is, but worth
    confirming that channel mapping is intentional.

    Args:
        person_img: PIL image to enhance.
        background_img: PIL image whose channel means drive the factors.

    Returns:
        The enhanced PIL image.
    """
    stat = ImageStat.Stat(background_img)
    mean = stat.mean[:3]  # (R, G, B) channel means

    contrast = 1.5 if mean[0] < 128 else 1.2
    brightness = 1.2 if mean[1] < 128 else 1.1
    color = 1.3 if mean[2] < 128 else 1.0

    # Bug fix: the original built every enhancer from the *input* image, so
    # each loop pass discarded the previous result and only the last
    # (color) adjustment survived. Build each enhancer from the running
    # result so all three adjustments compose.
    enhanced_image = person_img
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        enhanced_image = enhancer_cls(enhanced_image).enhance(factor)
    return enhanced_image
|
|
|
|
|
|
def enhance_image(image, contrast, brightness, color):
    """Apply contrast, brightness, and color enhancements in sequence.

    Args:
        image: PIL image to enhance.
        contrast: factor for ImageEnhance.Contrast (1.0 = unchanged).
        brightness: factor for ImageEnhance.Brightness (1.0 = unchanged).
        color: factor for ImageEnhance.Color (1.0 = unchanged).

    Returns:
        The enhanced PIL image.
    """
    # Bug fix: the original constructed all three enhancers from the raw
    # input image, so each loop iteration overwrote the previous result and
    # only the color factor took effect. Build each enhancer from the
    # running result so the adjustments compose.
    enhanced_image = image
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        enhanced_image = enhancer_cls(enhanced_image).enhance(factor)
    return enhanced_image
|
|
|
|
|
|
def process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Full pipeline: cut the person out, optionally enhance, composite onto
    the background, optionally "unify", and replicate the result.

    Args:
        person_img: PIL image containing the person.
        background_img: PIL background image.
        num_images: requested number of output copies (1-5).
        enhance: whether to apply any enhancement at all.
        auto_match: with ``enhance``, derive factors from the background.
        contrast / brightness / color: manual enhancement factors, used
            when ``enhance`` is set and ``auto_match`` is not.
        unify: run the composite through the VGG-based unifier.

    Returns:
        A list of ``num_images`` identical PIL images.

    Raises:
        ValueError: if ``num_images`` falls outside 1..5.
    """
    # Gradio sliders deliver floats; normalize to int before validating and
    # before using the value as a list-replication count (a float there
    # would raise TypeError).
    num_images = int(num_images)
    if not (1 <= num_images <= 5):
        raise ValueError("Number of Output Images must be between 1 and 5")

    # Strip the background from the person photo (rembg returns RGBA).
    person_no_bg = remove(person_img)

    if enhance and auto_match:
        print("Auto-matching enhancers based on the background color...")
        person_no_bg = auto_match_enhancers(person_no_bg, background_img)
    elif enhance:
        print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
        person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(person_no_bg, background_img)

    if unify:
        print("Unifying image with AI...")
        combined_img = unify_image(combined_img)

    # The same image object is repeated; downstream code only displays the
    # results, so sharing one object is safe.
    return [combined_img] * num_images
|
|
|
|
|
|
def gradio_interface(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Gradio click handler: run the pipeline and pad the result list out to
    the five fixed ``gr.Image`` output slots.

    Returns:
        A list of exactly five items: the generated images followed by
        ``None`` placeholders for the unused slots.
    """
    try:
        results = process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify)
        return results + [None] * (5 - len(results))
    except Exception as e:
        # Bug fix: the original returned str(e) in the first gr.Image slot,
        # which Gradio cannot render as an image. Raise gr.Error so the
        # message surfaces to the user as a proper UI alert instead.
        raise gr.Error(str(e)) from e
|
|
|
|
|
|
def update_enhancement_controls(auto_match):
    """Enable the manual enhancement sliders only while auto-match is off.

    Returns a component->update mapping consumed by the ``.change`` event;
    the slider variables are module globals defined in the Blocks layout.
    """
    slider_state = gr.update(interactive=not auto_match)
    return {
        contrast_slider: slider_state,
        brightness_slider: slider_state,
        color_slider: slider_state,
    }
|
|
|
|
|
|
|
|
|
# Build the Gradio Blocks UI. The component variables created here are
# module globals that update_enhancement_controls() references by name.
with gr.Blocks() as interface:
    with gr.Row():
        # The two source photos, delivered to the callbacks as PIL images.
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")

    # How many (identical) output copies to produce; validated again in
    # process_images().
    num_images = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Output Images")
    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)

    # Manual enhancement factors; interactively disabled while auto-match
    # is checked (see the change handler below).
    contrast_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Contrast")
    brightness_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Brightness")
    color_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Color")

    # Toggle the manual sliders whenever the auto-match checkbox changes.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match, outputs=[contrast_slider, brightness_slider, color_slider])

    unify = gr.Checkbox(label="Unify Image with AI", value=False)

    # Five fixed output slots; gradio_interface fills unused ones with None.
    outputs = [gr.Image(type="pil", label="Generated Image") for _ in range(5)]
    run_button = gr.Button("Run")

    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, num_images, enhance, auto_match, contrast_slider, brightness_slider, color_slider, unify],
        outputs=outputs
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # share=True publishes a temporary public *.gradio.live URL in
    # addition to the local server.
    interface.launch(share=True)
|
|
|
|