import functools

import gradio as gr
import numpy as np
import torch
import torchvision.transforms.functional as tf
from PIL import Image, ImageOps, ImageEnhance, ImageStat
from rembg import remove
from torchvision import transforms

from src import model

# Map checkpoints onto whatever device is actually available, so weights saved
# on a GPU machine still load on CPU-only hosts.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@functools.lru_cache(maxsize=None)
def load_harmonization_model(pretrained_path):
    """Load the pretrained Harmonizer in eval mode (cached per path)."""
    harmonizer = model.Harmonizer()
    if torch.cuda.is_available():
        harmonizer = harmonizer.cuda()
    harmonizer.load_state_dict(
        torch.load(pretrained_path, map_location=_DEVICE), strict=True
    )
    harmonizer.eval()
    return harmonizer


@functools.lru_cache(maxsize=None)
def load_enhancement_model(pretrained_path):
    """Load the pretrained Enhancer in eval mode (cached per path)."""
    enhancer = model.Enhancer()
    if torch.cuda.is_available():
        enhancer = enhancer.cuda()
    enhancer.load_state_dict(
        torch.load(pretrained_path, map_location=_DEVICE), strict=True
    )
    enhancer.eval()
    return enhancer


def _tensor_to_rgba_image(output, size):
    """Convert a 1xCxHxW float tensor in ~[0, 1] to an RGBA PIL image of `size`.

    Clips before the uint8 cast: model outputs can slightly over/undershoot
    the [0, 1] range, and an unclipped cast would wrap instead of saturating.
    """
    array = np.transpose(output[0].cpu().numpy(), (1, 2, 0)) * 255
    img = Image.fromarray(np.clip(array, 0, 255).astype(np.uint8)).convert("RGBA")
    return img.resize(size)


def unify_image(combined_img, harmonizer):
    """Harmonize the composite image so foreground/background colors agree.

    Args:
        combined_img: the pasted-together RGBA composite.
        harmonizer: a loaded `model.Harmonizer`.

    Returns:
        The harmonized image as RGBA, at the composite's original size.
    """
    original_size = combined_img.size

    # NOTE(review): this mask is uniformly 255 (point() maps 255 -> 255), so
    # the entire frame is treated as foreground. If the harmonizer expects a
    # person-only mask, the rembg alpha channel should be passed here instead
    # — confirm against the model's contract.
    mask = Image.new("L", original_size, 255)
    mask = mask.point(lambda p: p > 0 and 255)

    preprocess = transforms.Compose([transforms.ToTensor()])
    comp = preprocess(combined_img.convert("RGB")).unsqueeze(0)
    mask = preprocess(mask).unsqueeze(0)
    if torch.cuda.is_available():
        comp = comp.cuda()
        mask = mask.cuda()

    # Harmonization (inference only).
    with torch.no_grad():
        arguments = harmonizer.predict_arguments(comp, mask)
        harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]

    return _tensor_to_rgba_image(harmonized, original_size)


def enhance_unified_image(harmonized_img, enhancer):
    """Run the AI enhancement model over the (already harmonized) composite."""
    original_size = harmonized_img.size
    preprocess = transforms.Compose([transforms.ToTensor()])
    original = preprocess(harmonized_img.convert("RGB")).unsqueeze(0)
    # The enhancer API requires a mask; enhancement applies to the whole
    # frame, so the mask is all ones.
    mask = torch.ones_like(original)
    if torch.cuda.is_available():
        original = original.cuda()
        mask = mask.cuda()

    # Enhancement (inference only).
    with torch.no_grad():
        arguments = enhancer.predict_arguments(original, mask)
        enhanced = enhancer.restore_image(original, mask, arguments)[-1]

    return _tensor_to_rgba_image(enhanced, original_size)


def embed_person_on_background(person_img, background_img, position_x,
                               position_y, scale):
    """Paste the scaled person onto the background, anchored bottom-center.

    `position_x`/`position_y` are pixel offsets relative to the bottom-center
    resting position; `scale` multiplies the person's dimensions uniformly.
    """
    person_width, person_height = person_img.size
    new_width = int(person_width * scale)
    new_height = int(person_height * scale)
    person_img = person_img.resize((new_width, new_height), Image.LANCZOS)

    background_width, background_height = background_img.size
    # Default resting position: bottom-center of the background.
    default_x = (background_width - new_width) // 2
    default_y = background_height - new_height
    position_x = default_x + int(position_x)
    position_y = default_y + int(position_y)

    combined_img = Image.new("RGBA", background_img.size)
    combined_img.paste(background_img, (0, 0))
    # Pass the person image as its own paste mask so its alpha is respected.
    combined_img.paste(person_img, (position_x, position_y), person_img)
    return combined_img


def _apply_enhancements(image, contrast, brightness, color):
    """Apply contrast, brightness and color factors to `image`, in that order.

    BUGFIX: the original code built every ImageEnhance object from the *input*
    image and then applied them in a loop, so each enhance() discarded the
    previous step and only the last factor (color) ever took effect. Building
    each enhancer from the running result chains the adjustments correctly.
    """
    enhanced = image
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        enhanced = enhancer_cls(enhanced).enhance(factor)
    return enhanced


def auto_match_enhancers(person_img, background_img):
    """Derive enhancement factors from the background's mean color and apply them."""
    stat = ImageStat.Stat(background_img)
    mean = stat.mean[:3]  # Mean R, G, B of the background.
    # Heuristic: boost more aggressively against darker backgrounds.
    contrast = 1.5 if mean[0] < 128 else 1.2
    brightness = 1.2 if mean[1] < 128 else 1.1
    color = 1.3 if mean[2] < 128 else 1.0
    return _apply_enhancements(person_img, contrast, brightness, color)


def enhance_image(image, contrast, brightness, color):
    """Apply user-chosen contrast/brightness/color factors to `image`."""
    return _apply_enhancements(image, contrast, brightness, color)


def process_images(person_img, background_img, enhance, auto_match, contrast,
                   brightness, color, unify, position_x, position_y, scale):
    """Full pipeline: background removal -> enhancement -> compositing -> AI unify.

    Returns the final composited (and optionally harmonized/enhanced) image.
    """
    # Remove the background from the person image (result keeps alpha).
    person_no_bg = remove(person_img)

    if enhance and auto_match:
        print("Auto-matching enhancers based on the background color...")
        person_no_bg = auto_match_enhancers(person_no_bg, background_img)
    elif enhance:
        print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
        person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(
        person_no_bg, background_img, position_x, position_y, scale
    )

    if unify:
        print("Unifying image with AI...")
        harmonizer = load_harmonization_model('pretrained/harmonizer.pth')
        combined_img = unify_image(combined_img, harmonizer)
        enhancer = load_enhancement_model('pretrained/enhancer.pth')
        combined_img = enhance_unified_image(combined_img, enhancer)

    return combined_img


def gradio_interface(person_img, background_img, enhance, auto_match, contrast,
                     brightness, color, unify, position_x, position_y, scale):
    """Gradio entry point: run the pipeline, returning the resulting image."""
    try:
        return process_images(person_img, background_img, enhance, auto_match,
                              contrast, brightness, color, unify,
                              position_x, position_y, scale)
    except Exception as e:
        # NOTE(review): the output component is a gr.Image, which cannot render
        # a plain string; this preserves the original behavior, but
        # `raise gr.Error(str(e))` would surface the error properly in the UI.
        return str(e)


def update_enhancement_controls(auto_match):
    """Disable the manual enhancement sliders while auto-match is checked."""
    return {
        contrast_slider: gr.update(interactive=not auto_match),
        brightness_slider: gr.update(interactive=not auto_match),
        color_slider: gr.update(interactive=not auto_match),
    }


# Build the Gradio interface.
with gr.Blocks(css='#output_image {max-width: 800px !important; width: auto !important; height: auto !important;}') as interface:
    with gr.Row():
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")

    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)
    contrast_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Contrast")
    brightness_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Brightness")
    color_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Color")

    # Grey out the manual sliders whenever auto-match takes over.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match,
                      outputs=[contrast_slider, brightness_slider, color_slider])

    unify = gr.Checkbox(label="Unify Image with AI", value=True)
    position_x = gr.Slider(minimum=-500, maximum=500, step=1, value=0, label="Horizontal Position (pixels)")
    position_y = gr.Slider(minimum=-500, maximum=500, step=1, value=0, label="Vertical Position (pixels)")
    scale = gr.Slider(minimum=0.1, maximum=3.0, step=0.1, value=1.0, label="Scale")

    output = gr.Image(type="pil", label="Generated Image", elem_id="output_image")
    run_button = gr.Button("Run")
    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, enhance, auto_match,
                contrast_slider, brightness_slider, color_slider,
                unify, position_x, position_y, scale],
        outputs=output
    )

if __name__ == "__main__":
    interface.launch()