import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
import math
from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageEditPlusPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from briarmbg import BriaRMBG  # local briarmbg.py providing the RMBG-1.4 architecture
import os
import tempfile

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
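
# Flow-matching scheduler tuned for the Lightning few-step LoRA. Since
# base_shift == max_shift == ln(3), the dynamic shifting resolves to a
# constant exponential time shift of exp(ln 3) = 3 at every resolution.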
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}

scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)
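
# Load both LoRAs, fuse them into the base weights, and drop the adapter
# modules so sampling carries no per-step LoRA overhead.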
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-4steps-V2.0.safetensors",
    adapter_name="fast",
)
pipe.load_lora_weights(
    "dx8152/Qwen-Image-Edit-2509-Fusion",
    weight_name="溶图.safetensors",  # "溶图" means "image fusion"
    adapter_name="fusion",
)
pipe.set_adapters(["fast"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["fast"])
pipe.fuse_lora(adapter_names=["fusion"])
pipe.unload_lora_weights()

# Background-removal model, run in float32
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device, dtype=torch.float32)

MAX_SEED = np.iinfo(np.int32).max

def remove_alpha_channel(image: Image.Image) -> Image.Image:
    """
    Remove the alpha channel from a PIL image if it exists.

    Args:
        image (Image.Image): Input PIL image

    Returns:
        Image.Image: Image with alpha removed, composited onto white (RGB)
    """
    if image.mode in ('RGBA', 'LA'):
        # Composite onto white using the alpha channel as the paste mask
        background = Image.new('RGB', image.size, (255, 255, 255))
        if image.mode == 'RGBA':
            background.paste(image, mask=image.split()[-1])
        else:
            background.paste(image.convert('RGB'), mask=image.split()[-1])
        return background
    elif image.mode == 'P':
        # Palette images may carry transparency in their metadata
        if 'transparency' in image.info:
            image = image.convert('RGBA')
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[-1])
            return background
        else:
            return image.convert('RGB')
    elif image.mode != 'RGB':
        # Grayscale, CMYK, etc.
        return image.convert('RGB')
    else:
        return image

def numpy2pytorch(imgs):
    # uint8 [0, 255] -> float in roughly [-1, 1]; dividing by 127 puts the
    # mid-gray value 127 exactly at 0.0 (IC-Light convention)
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0
    h = h.movedim(-1, 1)  # BHWC -> BCHW
    return h

def pytorch2numpy(imgs, quant=True):
    results = []
    for x in imgs:
        y = x.movedim(0, -1)  # CHW -> HWC

        if quant:
            # [-1, 1] -> uint8 [0, 255]
            y = y * 127.5 + 127.5
            y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            # [-1, 1] -> float32 [0, 1]
            y = y * 0.5 + 0.5
            y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)

        results.append(y)
    return results

def resize_without_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)
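
# RMBG is fed a ~1-megapixel copy of the input (sides rounded to multiples
# of 64); the predicted alpha matte is upsampled back to full resolution.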
@spaces.GPU()
def run_rmbg(img, sigma=0.0):
    """
    Remove the background from an image using the BriaRMBG model.

    Args:
        img (np.ndarray): Input image as a numpy array (H, W, C)
        sigma (float): Optional offset applied to the foreground before blending

    Returns:
        tuple: (result_image, alpha_mask), where result_image has its
        background blended toward mid-gray
    """
    H, W, C = img.shape
    assert C == 3
    k = (256.0 / float(H * W)) ** 0.5
    feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
    feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
    alpha = rmbg(feed)[0][0]
    alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
    alpha = alpha.movedim(1, -1)[0]
    alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
    result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
    return result.clip(0, 255).astype(np.uint8), alpha

def remove_background_from_image(image: Image.Image) -> Image.Image:
    """
    Remove the background from a PIL image using the RMBG model.

    Args:
        image (Image.Image): Input PIL image

    Returns:
        Image.Image: RGBA image with a transparent background
    """
    img_array = np.array(image)

    result_array, alpha_mask = run_rmbg(img_array)

    result_image = Image.fromarray(result_array)
    if result_image.mode != 'RGBA':
        result_image = result_image.convert('RGBA')

    # Collapse the mask to 2D before building the alpha channel
    alpha_mask_2d = np.squeeze(alpha_mask)
    if alpha_mask_2d.ndim > 2:
        alpha_mask_2d = alpha_mask_2d[:, :, 0]

    alpha_array = (alpha_mask_2d * 255).astype(np.uint8)
    alpha_pil = Image.fromarray(alpha_array, 'L')
    result_image.putalpha(alpha_pil)

    return result_image

def calculate_dimensions(image):
    """Calculate output dimensions from the background image, capping the longest side at 1024."""
    if image is None:
        return 1024, 1024

    original_width, original_height = image.size

    if original_width > original_height:
        new_width = 1024
        aspect_ratio = original_height / original_width
        new_height = int(new_width * aspect_ratio)
    else:
        new_height = 1024
        aspect_ratio = original_width / original_height
        new_width = int(new_height * aspect_ratio)

    # Round down to multiples of 8 to keep the latent grid aligned
    new_width = (new_width // 8) * 8
    new_height = (new_height // 8) * 8

    return new_width, new_height
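
# Example: a 4000x3000 background yields output dimensions (1024, 768).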

@spaces.GPU
def infer(
    product_image,
    image_background,
    prompt="",
    seed=42,
    randomize_seed=True,
    true_guidance_scale=1,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    processed_subjects = []
    if product_image is not None:
        # Cut out the product, then flatten the transparent cutout onto white
        image = remove_background_from_image(product_image)
        image = remove_alpha_channel(image)
        processed_subjects.append(image)

    all_inputs = processed_subjects
    if image_background is not None:
        all_inputs.append(image_background)

    width, height = calculate_dimensions(image_background)

    if not all_inputs:
        raise gr.Error("Please upload a product image or a background image.")

    # With two inputs, append the fusion instruction expected by the LoRA
    if len(all_inputs) > 1:
        prompt = prompt + ". Integrate the product from Image 1 onto Image 2 as the background, ensuring seamless blending with appropriate lighting and shadows"

    result = pipe(
        image=all_inputs,
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=1,
    ).images[0]

    # Pair (background, result) for the before/after ImageSlider
    return [image_background, result], seed

css = '''#col-container { max-width: 1100px; margin: 0 auto; }
.dark .progress-text{color: white !important}
#examples{max-width: 1100px; margin: 0 auto; }'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## Qwen Image Edit — Product Fusion")
        gr.Markdown(
            "Seamlessly blend products onto backgrounds with Qwen Image Edit 2509 ✨ "
            "Using [dx8152's Qwen-Image-Edit-2509 Fusion LoRA](https://huggingface.co/dx8152/Qwen-Image-Edit-2509-Fusion) "
            "and the [lightx2v Qwen-Image-Lightning LoRA](https://huggingface.co/lightx2v/Qwen-Image-Lightning) for 4-step inference 💨"
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    product_image = gr.Image(
                        label="Product image (background auto removed)", type="pil"
                    )
                    image_background = gr.Image(label="Background Image", type="pil", visible=True)
                prompt = gr.Textbox(label="Prompt", placeholder="put the [product] on the [background]")
                run_button = gr.Button("Fuse Images", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)

            with gr.Column():
                result = gr.ImageSlider(label="Output Image", interactive=False)

        gr.Examples(
            examples=[
                ["product.png", "wednesday.png", "put the product in her hand"],
                [None, "fusion_car.png", ""],
                ["product_2.png", "background_2.png", "put the product on the chair"],
                [None, "fusion_milkshake.png", ""],
                [None, "fusion_shoes.png", "put the shoes on the feet"],
                ["product_3.png", "background_3.jpg", "put the product on the background"],
            ],
            inputs=[product_image, image_background, prompt],
            outputs=[result, seed],
            fn=infer,
            cache_examples="lazy",
            elem_id="examples",
        )

    inputs = [product_image, image_background, prompt, seed, randomize_seed, true_guidance_scale, num_inference_steps]
    outputs = [result, seed]

    run_button.click(fn=infer, inputs=inputs, outputs=outputs)

demo.launch(share=True)