import gradio as gr
from rembg import remove
from PIL import Image, ImageOps, ImageEnhance, ImageStat
import torch
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights

# Standard ImageNet normalization constants.  These are exactly the values
# whose inverse (-2.118 = -0.485/0.229, 4.367 = 1/0.229, ...) the original
# de-normalization step used, so they are the constants this file intended.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def unify_image(combined_img):
    """Lightly "unify" the composite by resampling it through a fixed
    256x256 normalized-tensor roundtrip and back to its original size.

    Returns an RGBA image with the same dimensions as ``combined_img``.

    NOTE(review): the previous implementation pushed the image through
    ``vgg19(...).features``, which yields a 512-channel feature map that the
    3-channel inverse ``Normalize`` + ``ToPILImage`` cannot convert back to
    RGB (shape-mismatch at runtime), and it read ``weights.meta["mean"]`` --
    a key torchvision weight metadata does not define (KeyError).  A faithful
    neural "unification" needs a trained encoder-decoder; until one is wired
    in, the broken feature pass is skipped and only the (mathematically
    near-identity) normalize/de-normalize resampling path is kept so the
    function returns a valid image instead of crashing.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    tensor = preprocess(combined_img.convert("RGB"))

    # Invert the normalization exactly: x = y * std + mean, then clamp into
    # the displayable [0, 1] range before converting back to PIL.
    mean = torch.tensor(_IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(_IMAGENET_STD).view(3, 1, 1)
    restored = (tensor * std + mean).clamp(0.0, 1.0)
    unified = transforms.ToPILImage()(restored)

    # Resize back so the caller gets an image of the original composite size.
    return unified.resize(combined_img.size, Image.LANCZOS).convert("RGBA")


def embed_person_on_background(person_img, background_img):
    """Paste the (background-removed) person onto the background.

    The person image is scaled with preserved aspect ratio so it fits inside
    the background; it is anchored at the top-left corner (0, 0), matching
    the app's existing placement behavior.  Returns an RGBA composite the
    size of ``background_img``.
    """
    person_img = ImageOps.contain(person_img, background_img.size, method=Image.LANCZOS)

    combined_img = Image.new("RGBA", background_img.size)
    combined_img.paste(background_img, (0, 0))
    # Use the person image itself as the mask so its alpha channel (from
    # rembg's background removal) controls transparency.
    combined_img.paste(person_img, (0, 0), person_img)
    return combined_img


def _apply_enhancers(image, factor_pairs):
    """Apply ``(EnhancerClass, factor)`` pairs in sequence, chaining results.

    Each enhancer wraps the PREVIOUS step's output -- the original code
    wrapped the untouched source image in every enhancer, so only the last
    enhancement actually survived.
    """
    for enhancer_cls, factor in factor_pairs:
        image = enhancer_cls(image).enhance(factor)
    return image


def auto_match_enhancers(person_img, background_img):
    """Derive enhancement factors from the background and apply them.

    Heuristic: a darker background channel mean pushes the corresponding
    enhancement harder (R drives contrast, G brightness, B color).  The
    thresholds are arbitrary tuning values, not color science.
    """
    mean = ImageStat.Stat(background_img).mean[:3]  # mean R, G, B of background

    contrast = 1.5 if mean[0] < 128 else 1.2
    brightness = 1.2 if mean[1] < 128 else 1.1
    color = 1.3 if mean[2] < 128 else 1.0

    return _apply_enhancers(person_img, [
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ])


def enhance_image(image, contrast, brightness, color):
    """Apply user-chosen contrast, brightness and color factors in sequence.

    Fixed to chain the enhancers (previously each enhancer wrapped the
    original image, so only the color factor took effect).
    """
    return _apply_enhancers(image, [
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ])


def process_images(person_img, background_img, num_images, enhance, auto_match,
                   contrast, brightness, color, unify):
    """Full pipeline: remove background, optionally enhance, composite,
    optionally unify; returns ``num_images`` copies of the result.

    Raises ValueError if ``num_images`` is outside [1, 5].
    """
    # Gradio sliders can deliver floats; normalize before validating/replicating.
    num_images = int(num_images)
    if not (1 <= num_images <= 5):
        raise ValueError("Number of Output Images must be between 1 and 5")

    # Remove background from the person image (rembg returns RGBA with alpha).
    person_no_bg = remove(person_img)

    if enhance and auto_match:
        print("Auto-matching enhancers based on the background color...")
        person_no_bg = auto_match_enhancers(person_no_bg, background_img)
    elif enhance:
        print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
        person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(person_no_bg, background_img)

    if unify:
        print("Unifying image with AI...")
        combined_img = unify_image(combined_img)

    return [combined_img] * num_images


def gradio_interface(person_img, background_img, num_images, enhance, auto_match,
                     contrast, brightness, color, unify):
    """Gradio callback: run the pipeline and pad the result list to the five
    fixed image outputs.

    Errors are surfaced with ``gr.Error`` so the UI shows a proper message;
    the previous code returned ``str(e)`` into a ``gr.Image`` output, which
    gradio interprets as a file path and fails on a second time.
    """
    try:
        results = process_images(person_img, background_img, num_images, enhance,
                                 auto_match, contrast, brightness, color, unify)
    except Exception as e:
        raise gr.Error(str(e))
    # Pad so the number of returned values matches the five output components.
    return results + [None] * (5 - len(results))


def update_enhancement_controls(auto_match):
    """Disable the manual enhancement sliders while auto-match is checked."""
    return {
        contrast_slider: gr.update(interactive=not auto_match),
        brightness_slider: gr.update(interactive=not auto_match),
        color_slider: gr.update(interactive=not auto_match),
    }


# Build the Gradio UI.  Component construction must run at import time so the
# slider references used by update_enhancement_controls exist at module level.
with gr.Blocks() as interface:
    with gr.Row():
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")

    num_images = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Output Images")
    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)
    contrast_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Contrast")
    brightness_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Brightness")
    color_slider = gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label="Color")

    # Grey out the manual sliders whenever auto-match takes over.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match,
                      outputs=[contrast_slider, brightness_slider, color_slider])

    unify = gr.Checkbox(label="Unify Image with AI", value=False)
    outputs = [gr.Image(type="pil", label="Generated Image") for _ in range(5)]

    run_button = gr.Button("Run")
    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, num_images, enhance, auto_match,
                contrast_slider, brightness_slider, color_slider, unify],
        outputs=outputs,
    )

if __name__ == "__main__":
    interface.launch(share=True)