Env_mixer / image_styler.py
Inmental's picture
Upload folder using huggingface_hub
4c62147 verified
import functools

import gradio as gr
from rembg import remove
from PIL import Image, ImageOps, ImageEnhance, ImageStat
import torch
from torchvision import transforms
from torchvision.models import vgg19, VGG19_Weights
# ImageNet normalisation statistics used to train torchvision's VGG19 weights.
# ``VGG19_Weights.IMAGENET1K_V1.meta`` has no "mean"/"std" entries, so they are
# spelled out here.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


@functools.lru_cache(maxsize=1)
def _vgg19_content_extractor():
    """Load VGG19 once and return its frozen feature layers up to relu4_2."""
    # Index 22 of ``vgg19().features`` is the ReLU after conv4_2; deeper layers
    # are not needed for the content representation used below.
    extractor = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:23].eval()
    for param in extractor.parameters():
        param.requires_grad_(False)
    return extractor


def unify_image(combined_img, size=256, steps=30, lr=0.01, smoothness_weight=1.0):
    """Soften a composite's seams while preserving its VGG19 content features.

    VGG19 is an encoder only: its 512-channel feature maps cannot be turned back
    into a picture directly. Instead, the image pixels are optimised so that
    they keep the original's VGG19 (relu4_2) features while the squared
    differences between neighbouring pixels shrink. Hard seams, such as the
    edge of a pasted cut-out, carry most of that penalty.

    Args:
        combined_img: PIL image to unify. It is converted to RGB, so any alpha
            channel is dropped.
        size: Side length of the square working resolution.
        steps: Number of Adam optimisation steps. 0 returns a resampled copy.
        lr: Adam learning rate, in pixel units where pixel values lie in [0, 1].
        smoothness_weight: Weight of the neighbour-difference penalty relative
            to the content loss.

    Returns:
        An RGBA PIL image with the same size as ``combined_img``.
    """
    # NOTE(review): steps/lr/smoothness_weight are conservative defaults; tune
    # them visually on real composites.
    extractor = _vgg19_content_extractor()
    normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    to_tensor = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])

    target = to_tensor(combined_img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        target_features = extractor(normalize(target))

    # Start from the composite itself so the optimisation only makes local edits.
    pixels = target.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([pixels], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        content_loss = torch.nn.functional.mse_loss(
            extractor(normalize(pixels)), target_features
        )
        smoothness_loss = (
            (pixels[..., 1:, :] - pixels[..., :-1, :]).pow(2).mean()
            + (pixels[..., :, 1:] - pixels[..., :, :-1]).pow(2).mean()
        )
        (content_loss + smoothness_weight * smoothness_loss).backward()
        optimizer.step()
        with torch.no_grad():
            # Keep the pixels in ToPILImage's valid float range.
            pixels.clamp_(0.0, 1.0)

    unified = transforms.ToPILImage()(pixels.detach().squeeze(0))
    # Return at the composite's own resolution rather than the working size.
    return unified.resize(combined_img.size, Image.LANCZOS).convert("RGBA")
def embed_person_on_background(person_img, background_img, position=(0, 0)):
    """Paste a (transparent-background) person image onto a background.

    The person image is scaled with ``ImageOps.contain`` so it fits entirely
    inside the background while keeping its aspect ratio. It is then pasted
    using its own alpha channel as the mask.

    Args:
        person_img: PIL image of the person. It is converted to RGBA first so it
            can always act as its own paste mask. Pillow rejects an RGB mask
            with "bad transparency mask".
        background_img: PIL image of the background; its size is the output size.
        position: (x, y) of the scaled person's top-left corner. The default,
            (0, 0), is the background's top-left corner.

    Returns:
        A new RGBA image the size of ``background_img``.
    """
    canvas_size = background_img.size
    person_rgba = ImageOps.contain(
        person_img.convert("RGBA"), canvas_size, method=Image.LANCZOS
    )
    combined_img = Image.new("RGBA", canvas_size)
    combined_img.paste(background_img, (0, 0))
    combined_img.paste(person_rgba, position, person_rgba)
    return combined_img
def auto_match_enhancers(person_img, background_img):
    """Enhance the person image using factors picked from the background's mean colour.

    Each channel of the background's mean RGB colour chooses one factor:
      * red   < 128 -> contrast 1.5,   otherwise 1.2
      * green < 128 -> brightness 1.2, otherwise 1.1
      * blue  < 128 -> colour 1.3,     otherwise 1.0

    The enhancements are applied in sequence (contrast, then brightness, then
    colour). Each one wraps the previous result, so all three take effect.

    Args:
        person_img: PIL image to enhance (typically the background-removed person).
        background_img: PIL image whose mean colour drives the factors.

    Returns:
        The enhanced PIL image.
    """
    # Convert first so grayscale or palette backgrounds still give three channel means.
    red, green, blue = ImageStat.Stat(background_img.convert("RGB")).mean
    steps = (
        (ImageEnhance.Contrast, 1.5 if red < 128 else 1.2),
        (ImageEnhance.Brightness, 1.2 if green < 128 else 1.1),
        (ImageEnhance.Color, 1.3 if blue < 128 else 1.0),
    )
    enhanced_image = person_img
    for enhancer_cls, factor in steps:
        # Build each enhancer from the previous output, not the original image.
        enhanced_image = enhancer_cls(enhanced_image).enhance(factor)
    return enhanced_image
def enhance_image(image, contrast, brightness, color):
    """Apply contrast, brightness and colour enhancement to *image*, in that order.

    Each enhancer is built from the previous step's output, so all three factors
    are applied. A factor of 1.0 leaves that aspect unchanged.

    Args:
        image: PIL image to enhance.
        contrast: ``ImageEnhance.Contrast`` factor.
        brightness: ``ImageEnhance.Brightness`` factor.
        color: ``ImageEnhance.Color`` (saturation) factor.

    Returns:
        The enhanced PIL image.
    """
    enhanced_image = image
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, color),
    ):
        enhanced_image = enhancer_cls(enhanced_image).enhance(factor)
    return enhanced_image
def process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Run the full pipeline: cut out the person, optionally enhance, composite, optionally unify.

    Args:
        person_img: PIL image containing the person; its background is removed with rembg.
        background_img: PIL image to place the person on.
        num_images: How many copies of the result to return (1-5).
        enhance: Whether to enhance the cut-out person.
        auto_match: With ``enhance``, derive the factors from the background
            instead of using the slider values.
        contrast, brightness, color: Manual enhancement factors (used when
            ``enhance`` is true and ``auto_match`` is false).
        unify: Whether to run ``unify_image`` on the composite.

    Returns:
        A list of ``num_images`` references to the same composite image.

    Raises:
        ValueError: If ``num_images`` is outside 1..5 or either image is missing.
    """
    if not (1 <= num_images <= 5):
        raise ValueError("Number of Output Images must be between 1 and 5")
    # Gradio passes None for an image input nobody uploaded; fail with a clear message.
    if person_img is None or background_img is None:
        raise ValueError("Please upload both a person image and a background image")
    # Slider values may arrive as floats; list repetition below needs an int.
    num_images = int(num_images)

    # Remove the background from the person image.
    person_no_bg = remove(person_img)
    if enhance and auto_match:
        print("Auto-matching enhancers based on the background color...")
        person_no_bg = auto_match_enhancers(person_no_bg, background_img)
    elif enhance:
        print(f"Applying enhancement with contrast={contrast}, brightness={brightness}, color={color}...")
        person_no_bg = enhance_image(person_no_bg, contrast, brightness, color)

    combined_img = embed_person_on_background(person_no_bg, background_img)
    if unify:
        print("Unifying image with AI...")
        combined_img = unify_image(combined_img)

    return [combined_img] * num_images
def gradio_interface(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify):
    """Gradio click handler: run ``process_images`` and fill all five output slots.

    Slots beyond ``num_images`` are filled with ``None`` so the returned list
    always matches the five ``gr.Image`` outputs.

    Raises:
        gr.Error: Wraps any pipeline failure so Gradio shows the message in the
            UI. The message is not returned into an Image output, because
            ``gr.Image`` treats a string value as an image file path.
    """
    try:
        results = process_images(person_img, background_img, num_images, enhance, auto_match, contrast, brightness, color, unify)
    except Exception as e:
        # Top-level UI boundary: surface the message and keep the original cause chained.
        raise gr.Error(str(e)) from e
    return results + [None] * (5 - len(results))
def update_enhancement_controls(auto_match):
    """Toggle the manual enhancement sliders when "Auto-Match Enhancers" changes.

    Returns a dict mapping each of the contrast, brightness and color slider
    components to ``gr.update(interactive=...)``. The sliders are locked while
    auto-matching is on and editable otherwise.
    """
    manual_sliders = (contrast_slider, brightness_slider, color_slider)
    # One fresh gr.update(...) per slider, in contrast/brightness/color order.
    return {slider: gr.update(interactive=not auto_match) for slider in manual_sliders}
# Build the Gradio UI: two uploads side by side, pipeline controls, five output slots.
with gr.Blocks() as interface:
    with gr.Row():
        person_img = gr.Image(type="pil", label="Upload Person Image")
        background_img = gr.Image(type="pil", label="Upload Background Image")
    num_images = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Output Images")
    enhance = gr.Checkbox(label="Enhance Image", value=False)
    auto_match = gr.Checkbox(label="Auto-Match Enhancers", value=False)
    # The three manual enhancement sliders share one range; only the label differs.
    contrast_slider, brightness_slider, color_slider = (
        gr.Slider(minimum=0.5, maximum=3.0, step=0.1, value=1.0, label=slider_label)
        for slider_label in ("Contrast", "Brightness", "Color")
    )
    enhancement_sliders = [contrast_slider, brightness_slider, color_slider]
    # Lock the manual sliders while auto-matching is enabled.
    auto_match.change(fn=update_enhancement_controls, inputs=auto_match, outputs=enhancement_sliders)
    unify = gr.Checkbox(label="Unify Image with AI", value=False)
    outputs = [gr.Image(type="pil", label="Generated Image") for _ in range(5)]
    run_button = gr.Button("Run")
    run_button.click(
        fn=gradio_interface,
        inputs=[person_img, background_img, num_images, enhance, auto_match, *enhancement_sliders, unify],
        outputs=outputs,
    )

if __name__ == "__main__":
    interface.launch(share=True)