Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import torch | |
| import spaces | |
| from PIL import Image | |
| from diffusers import FlowMatchEulerDiscreteScheduler | |
| from optimization import optimize_pipeline_ | |
| from diffusers import QwenImageEditPlusPipeline | |
| import math | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from PIL import Image | |
| import os | |
| import gradio as gr | |
| from gradio_client import Client, handle_file | |
| import tempfile | |
| from typing import Optional, Tuple, Any | |
# --- Model Loading ---
# Pipeline weights are loaded in bfloat16 (passed as torch_dtype below).
dtype = torch.bfloat16
# Falls back to CPU when CUDA is unavailable, although the attention backend and
# optimisation step further down presumably require a GPU — TODO confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Flow-matching Euler scheduler configuration. base_shift == max_shift == log(3),
# so even with use_dynamic_shifting=True the resolution-dependent shift presumably
# collapses to a constant mu = log(3) (assumes the pipeline interpolates linearly
# between base_shift and max_shift — confirm against diffusers).
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# Base Qwen-Image-Edit-2509 pipeline using the custom scheduler, moved to `device`.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)
# Lightning distillation LoRA (8-step variant, V2.0); enables the low step count
# used at inference time.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors",
    adapter_name="fast"
)
# Light-migration LoRA. The weight file name "参考色调" is Chinese for "reference
# tone"; the same phrase opens DEFAULT_PROMPT, presumably as the LoRA's trigger text.
# NOTE(review): the adapter name "angles" looks like a leftover from another Space;
# it is only a temporary handle, since all LoRA layers are unloaded after fusing below.
pipe.load_lora_weights(
    "dx8152/Qwen-Edit-2509-Light-Migration",
    weight_name="参考色调.safetensors",
    adapter_name="angles"
)
# Fuse each LoRA into the base weights at scale 1.0, one adapter at a time, then
# drop the LoRA layers; the fused weights remain baked into the transformer.
pipe.set_adapters(["angles"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.)
pipe.set_adapters(["fast"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["fast"], lora_scale=1.)
pipe.unload_lora_weights()
# Alternative ahead-of-time compiled transformer blocks, currently disabled:
#spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Qwen-Image", variant="fa3")
# FlashAttention-3 kernels fetched from the Hub; presumably needs a Hopper-class GPU
# — confirm before running on other hardware.
pipe.transformer.set_attention_backend("_flash_3_hub")
# Project-local optimisation/warm-up (optimization.py, not shown), run once at
# startup with two blank 1024x1024 RGB images and a placeholder prompt — the same
# two-image input shape that inference uses.
optimize_pipeline_(
    pipe,
    image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))],
    prompt="prompt"
)
# Inclusive upper bound for generation seeds: the largest signed 32-bit integer
# (2**31 - 1), shared by the Seed slider and the random seed draw.
MAX_SEED: int = int(np.iinfo(np.int32).max)
# Default light-migration instruction (Chinese). English: "Reference tone: remove
# the original lighting of image 1 and relight image 1 using the lighting and
# color tone of image 2."
DEFAULT_PROMPT: str = "参考色调,移除图1原有的光照并参考图2的光照和色调对图1重新照明"
def _to_rgb_image(value: Any, label: str) -> Image.Image:
    """Convert one uploaded image value into an RGB PIL image.

    Accepts a PIL image, a filesystem path (``str`` / ``os.PathLike``), or a
    file-like object exposing a ``.name`` path attribute.

    Args:
        value: The image value received from the UI or an API caller.
        label: Human-readable name of the input, used in the error message.

    Returns:
        A new RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If ``value`` is of an unsupported type. Previously such values
            were silently dropped, and the pipeline got fewer than two images.
    """
    if isinstance(value, Image.Image):
        return value.convert("RGB")
    # Check paths before `.name`: a pathlib.Path's `.name` is only the basename.
    if isinstance(value, (str, os.PathLike)):
        path = value
    elif hasattr(value, "name"):
        path = value.name
    else:
        raise gr.Error(f"Unsupported input type for {label}: {type(value).__name__}")
    # Close the file handle; convert() loads the pixels into a new in-memory image.
    with Image.open(path) as opened:
        return opened.convert("RGB")


@spaces.GPU
def infer_light_migration(
    image: Optional[Image.Image] = None,
    light_source: Optional[Image.Image] = None,
    prompt: str = DEFAULT_PROMPT,
    seed: int = 0,
    randomize_seed: bool = True,
    true_guidance_scale: float = 1.0,
    num_inference_steps: int = 8,
    height: Optional[int] = None,
    width: Optional[int] = None,
    progress: Optional[gr.Progress] = gr.Progress(track_tqdm=True)
) -> Tuple[Image.Image, int]:
    """
    Transfer lighting and color tones from a reference image to a source image
    using Qwen Image Edit 2509 with the Light Migration LoRA.

    Runs under ``@spaces.GPU`` so that a GPU is attached on ZeroGPU hardware.
    The decorator is a pass-through elsewhere.

    Args:
        image (PIL.Image.Image | None, optional):
            The source image to relight (Image 1). May also be a file path or
            a file-like object with a ``.name`` path. Defaults to None.
        light_source (PIL.Image.Image | None, optional):
            The reference image providing the lighting and color tones
            (Image 2). Same accepted types as ``image``. Defaults to None.
        prompt (str, optional):
            The prompt describing the lighting transfer operation. An empty or
            whitespace-only prompt falls back to DEFAULT_PROMPT.
            Defaults to the Chinese prompt for light migration.
        seed (int, optional):
            Random seed for the generation. Ignored if `randomize_seed=True`.
            Defaults to 0.
        randomize_seed (bool, optional):
            If True, a random seed (0..MAX_SEED) is chosen per call.
            Defaults to True.
        true_guidance_scale (float, optional):
            CFG / guidance scale controlling prompt adherence.
            Defaults to 1.0 for the distilled transformer.
        num_inference_steps (int, optional):
            Number of inference steps. Defaults to 8, which matches the fused
            8-step Lightning LoRA.
        height (int, optional):
            Output image height. Must typically be a multiple of 8.
            If set to 0 or None, the model will infer a size. Defaults to None.
        width (int, optional):
            Output image width. Must typically be a multiple of 8.
            If set to 0 or None, the model will infer a size. Defaults to None.
        progress (gr.Progress, optional):
            Progress tracker injected by Gradio; ``track_tqdm=True`` mirrors
            the pipeline's tqdm progress bar in the UI.

    Returns:
        Tuple[PIL.Image.Image, int]:
            - The relit output image.
            - The actual seed used for generation.

    Raises:
        gr.Error: If either image is missing or of an unsupported type.
    """
    if image is None:
        raise gr.Error("Please upload a source image (Image 1).")
    if light_source is None:
        raise gr.Error("Please upload a light source reference image (Image 2).")

    # An empty prompt would presumably drop the LoRA's trigger phrase; use the default.
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Coerce defensively in case a caller passes a float; manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Order matters: the prompt refers to Image 1 (source) and Image 2 (light reference).
    pil_images = [
        _to_rgb_image(image, "Image 1"),
        _to_rgb_image(light_source, "Image 2"),
    ]

    result = pipe(
        image=pil_images,
        prompt=prompt,
        # 0 or None lets the pipeline choose the output size.
        height=int(height) if height else None,
        width=int(width) if width else None,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=1,
    ).images[0]
    return result, seed
def update_dimensions_on_upload(
    image: Optional[Image.Image]
) -> Tuple[int, int]:
    """
    Compute recommended (width, height) for the output resolution when an
    image is uploaded while preserving the aspect ratio.

    The longer side is set to 1024 px and the shorter side is scaled to keep
    the aspect ratio. Both are then floored to a multiple of 8 and clamped to
    at least 256 px, the Height/Width slider minimum. Without the clamp, very
    elongated images would produce a 0 px dimension. Aspect ratios beyond 4:1
    are therefore capped.

    Args:
        image (PIL.Image.Image | None):
            The uploaded image. If `None`, or if it has a zero-sized side,
            defaults to (1024, 1024).

    Returns:
        Tuple[int, int]:
            The new (width, height).
    """
    long_side = 1024
    # Matches `minimum=256` on the Height/Width sliders in the UI.
    min_side = 256
    if image is None:
        return long_side, long_side
    original_width, original_height = image.size
    if original_width <= 0 or original_height <= 0:
        # Degenerate image: avoid dividing by zero below.
        return long_side, long_side
    if original_width > original_height:
        new_width = long_side
        new_height = int(new_width * (original_height / original_width))
    else:
        new_height = long_side
        new_width = int(new_height * (original_width / original_height))
    # Floor to a multiple of 8, then clamp so the short side never collapses to 0.
    new_width = max(min_side, (new_width // 8) * 8)
    new_height = max(min_side, (new_height // 8) * 8)
    return new_width, new_height
# --- UI ---
# Custom stylesheet passed to demo.launch(css=...). It centres the main column
# (#col-container) and the examples section (#examples) at a 1000px maximum width,
# forces white progress text in dark mode, and gives the two input images
# (elem_classes="image-container") a 300px minimum height.
css = '''
#col-container { max-width: 1000px; margin: 0 auto; }
.dark .progress-text { color: white !important }
#examples { max-width: 1000px; margin: 0 auto; }
.image-container { min-height: 300px; }
'''
# Gradio UI. Statement order inside each `with` context determines the on-screen
# layout, so components are created in display order.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## 💡 Qwen Image Edit — Light Migration")
        gr.Markdown("""
Transfer lighting and color tones from a reference image to your source image ✨
Using [dx8152's Qwen-Edit-2509-Light-Migration LoRA](https://huggingface.co/dx8152/Qwen-Edit-2509-Light-Migration)
and [lightx2v/Qwen-Image-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning/tree/main) for 8-step inference 💨
""")
        with gr.Row():
            # Left column: the two inputs, the run button and advanced settings.
            with gr.Column():
                with gr.Row():
                    image = gr.Image(
                        label="Image 1 (Source - to be relit)",
                        type="pil",
                        elem_classes="image-container"
                    )
                    light_source = gr.Image(
                        label="Image 2 (Light Reference)",
                        type="pil",
                        elem_classes="image-container"
                    )
                run_btn = gr.Button("✨ Transfer Lighting", variant="primary", size="lg")
                with gr.Accordion("Advanced Settings", open=False):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                        placeholder="Enter prompt for light migration...",
                        lines=2
                    )
                    # Also an output: the seed actually used is written back here.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True
                    )
                    true_guidance_scale = gr.Slider(
                        label="True Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.1,
                        value=1.0
                    )
                    # Default of 8 matches the fused 8-step Lightning LoRA.
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=40,
                        step=1,
                        value=8
                    )
                    # Height/width are overwritten by update_dimensions_on_upload
                    # whenever Image 1 is uploaded.
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=8,
                        value=1024
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=8,
                        value=1024
                    )
            # Right column: the generated image.
            with gr.Column():
                result = gr.Image(label="Output Image", interactive=False)
                # output_seed = gr.Number(label="Seed Used", interactive=False, visible=False)
    # Example pairs only supply the two images; every other argument of
    # infer_light_migration takes its default. With cache_mode="lazy", an example's
    # output is generated on first use and then served from the cache.
    gr.Examples(
        examples=[
            # Character 1 with 3 different lights
            ["character_1.png", "light_1.png"],
            ["character_1.png", "light_3.jpeg"],
            ["character_1.png", "light_5.png"],
            # Character 2 with 3 different lights
            ["character_2.png", "light_2.png"],
            ["character_2.png", "light_4.png"],
            ["character_2.png", "light_6.png"],
            # Place 1 with 3 different lights
            ["place_1.png", "light_1.png"],
            ["place_1.png", "light_4.png"],
            ["place_1.png", "light_6.png"],
        ],
        inputs=[
            image, light_source
        ],
        outputs=[result, seed],
        fn=infer_light_migration,
        cache_examples=True,
        cache_mode="lazy",
        elem_id="examples"
    )
    # Order must match infer_light_migration's positional parameters.
    inputs = [
        image, light_source, prompt,
        seed, randomize_seed, true_guidance_scale,
        num_inference_steps, height, width
    ]
    # The returned seed is written back to the Seed slider so a result can be reproduced.
    outputs = [result, seed]
    # Run button click
    run_btn.click(
        fn=infer_light_migration,
        inputs=inputs,
        outputs=outputs
    )
    # Image upload triggers dimension update. Only Image 1 (the source) drives the
    # output size; the handler returns (width, height), matching this output order.
    image.upload(
        fn=update_dimensions_on_upload,
        inputs=[image],
        outputs=[width, height]
    )
    # API endpoint
    # gr.api(infer_light_migration, api_name="infer_light_migration")
# mcp_server=True also exposes the app's endpoints as MCP tools. theme, css and
# footer_links are passed to launch() here (assumes a Gradio version whose launch()
# accepts them — confirm against the pinned Gradio version).
demo.launch(mcp_server=True, theme=gr.themes.Citrus(), css=css, footer_links=["api", "gradio", "settings"])