""" UniPic-3 DMD Multi-Image Composition Hugging Face Space - ZeroGPU 优化版本 V5 关键策略: 1. 全局只加载不需要 GPU 的组件(scheduler, tokenizer, processor) 2. 需要 GPU 的模型在 @spaces.GPU 内部加载,显式指定 device='cuda' 3. 不使用 device_map='auto',因为它可能在 ZeroGPU 外部被错误地分配 """ import gradio as gr import torch from PIL import Image import os import sys # Hugging Face Spaces GPU decorator try: import spaces HF_SPACES = True print("✅ Running in Hugging Face Spaces with ZeroGPU") except ImportError: HF_SPACES = False print("⚠️ Running locally (no ZeroGPU)") class spaces: @staticmethod def GPU(duration=60): def decorator(func): return func return decorator # Local pipeline import sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) # Model configuration MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD") TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer") dtype = torch.bfloat16 # ============================================================ # 全局加载轻量级组件(不需要 GPU) # ============================================================ print("🚀 Loading lightweight components (CPU)...") from diffusers import ( FlowMatchEulerDiscreteScheduler, QwenImageTransformer2DModel, AutoencoderKLQwenImage ) from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor try: from pipeline_qwenimage_edit import QwenImageEditPipeline except ImportError: from diffusers import QwenImageEditPipeline # 这些组件不需要 GPU,可以在全局加载 scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( MODEL_NAME, subfolder='scheduler' ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer') processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor') print("✅ Lightweight components loaded!") # ============================================================ # Pipeline 状态 # ============================================================ pipe = None _models_loaded = False # ============================================================ # GPU 推理函数 - 模型在这里加载 # ============================================================ @spaces.GPU(duration=180) def generate_image( images: list[Image.Image], prompt: str, true_cfg_scale: float, seed: int, num_steps: int ) -> Image.Image: """ GPU 推理函数 关键:所有需要 GPU 的模型都在这里加载,确保在真实 GPU 环境中 """ global pipe, _models_loaded print(f"🎨 Generating with {len(images)} image(s)...") print(f" Prompt: {prompt[:50]}...") print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}") # 在真实 GPU 环境中加载模型(首次调用时) if not _models_loaded: print(" [INIT] Loading models on real GPU...") device = 'cuda' # 加载 text_encoder 到 GPU print(" [INIT] Loading text_encoder...") text_encoder = AutoModel.from_pretrained( MODEL_NAME, subfolder='text_encoder', torch_dtype=dtype, ).to(device).eval() # 加载 transformer 到 GPU print(" [INIT] Loading transformer...") if os.path.exists(TRANSFORMER_PATH) and os.path.isdir(TRANSFORMER_PATH): config_path = os.path.join(TRANSFORMER_PATH, "config.json") if os.path.exists(config_path): transformer = QwenImageTransformer2DModel.from_pretrained( TRANSFORMER_PATH, torch_dtype=dtype, use_safetensors=False ).to(device).eval() else: transformer = QwenImageTransformer2DModel.from_pretrained( TRANSFORMER_PATH, subfolder='transformer', torch_dtype=dtype, use_safetensors=False ).to(device).eval() else: path_parts = TRANSFORMER_PATH.split('/') if len(path_parts) >= 3: repo_id = '/'.join(path_parts[:2]) subfolder = '/'.join(path_parts[2:]) transformer = QwenImageTransformer2DModel.from_pretrained( repo_id, subfolder=subfolder, torch_dtype=dtype, use_safetensors=False ).to(device).eval() else: transformer = QwenImageTransformer2DModel.from_pretrained( TRANSFORMER_PATH, subfolder='transformer', torch_dtype=dtype, use_safetensors=False ).to(device).eval() # 加载 VAE 到 GPU print(" [INIT] Loading VAE...") vae = AutoencoderKLQwenImage.from_pretrained( MODEL_NAME, subfolder='vae', torch_dtype=dtype, ).to(device).eval() # 创建 Pipeline print(" [INIT] Creating pipeline...") pipe = QwenImageEditPipeline( scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, processor=processor, transformer=transformer ) _models_loaded = True print(" [INIT] ✅ Models loaded successfully!") # 验证设备 print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}") print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}") print(f" [DEBUG] vae device: {next(pipe.vae.parameters()).device}") # Generate with torch.no_grad(): generator = torch.Generator(device='cuda').manual_seed(int(seed)) if len(images) == 1: result = pipe( images[0], prompt=prompt, height=1024, width=1024, negative_prompt=' ', num_inference_steps=num_steps, true_cfg_scale=true_cfg_scale, generator=generator ).images[0] else: result = pipe( images=images, prompt=prompt, height=1024, width=1024, negative_prompt=' ', num_inference_steps=num_steps, true_cfg_scale=true_cfg_scale, generator=generator ).images[0] print("✅ Generation complete!") return result # ============================================================ # UI 逻辑(CPU,始终可用) # ============================================================ def process_images( img1, img2, img3, img4, img5, img6, prompt: str, cfg_scale: float, seed: int, num_steps: int ): """处理图像 - 验证输入后调用 GPU 函数""" images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None] if len(images) == 0: return None, "❌ Please upload at least one image" if len(images) > 6: return None, f"❌ Maximum 6 images allowed (got {len(images)})" if not prompt or prompt.strip() == "": return None, "❌ Please enter an editing instruction" try: images = [img.convert("RGB") for img in images] result = generate_image( images=images, prompt=prompt, true_cfg_scale=cfg_scale, seed=seed, num_steps=num_steps ) return result, f"✅ Generated from {len(images)} image(s) in {num_steps} steps" except Exception as e: import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" def update_image_visibility(num): return [gr.update(visible=(i < num)) for i in range(6)] # ============================================================ # 自定义 CSS # ============================================================ CUSTOM_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap'); :root { --primary: #6366f1; --primary-dark: #4f46e5; --accent: #f472b6; --surface: #0f0f23; --surface-light: #1a1a3e; --surface-elevated: #252552; --text: #e2e8f0; --text-muted: #94a3b8; --border: #334155; --success: #10b981; --error: #ef4444; --gradient-1: linear-gradient(135deg, #667eea 0%, #764ba2 100%); --gradient-hero: linear-gradient(135deg, #0f0f23 0%, #1a1a3e 50%, #252552 100%); } .gradio-container { font-family: 'Outfit', sans-serif !important; background: var(--gradient-hero) !important; min-height: 100vh; } .main-header { text-align: center; padding: 2rem 1rem; background: linear-gradient(180deg, rgba(99, 102, 241, 0.1) 0%, transparent 100%); border-radius: 24px; margin-bottom: 2rem; border: 1px solid rgba(99, 102, 241, 0.2); } .main-header h1 { font-size: 2.5rem; font-weight: 700; background: linear-gradient(135deg, #fff 0%, #a5b4fc 50%, #f472b6 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; margin-bottom: 0.5rem; } .main-header p { color: var(--text-muted); font-size: 1.1rem; max-width: 600px; margin: 0 auto; } .feature-badges { display: flex; gap: 1rem; justify-content: center; flex-wrap: wrap; margin-top: 1.5rem; } .badge { display: inline-flex; align-items: center; gap: 0.5rem; padding: 0.5rem 1rem; background: rgba(99, 102, 241, 0.15); border: 1px solid rgba(99, 102, 241, 0.3); border-radius: 9999px; color: #a5b4fc; font-size: 0.875rem; font-weight: 500; } .section-header { display: flex; align-items: center; gap: 0.75rem; margin-bottom: 1rem; padding-bottom: 0.75rem; border-bottom: 1px solid var(--border); } .section-header h3 { font-size: 1.125rem; font-weight: 600; color: var(--text); margin: 0; } .generate-btn { background: var(--gradient-1) !important; border: none !important; border-radius: 12px !important; padding: 1rem 2rem !important; font-size: 1.1rem !important; font-weight: 600 !important; color: white !important; cursor: pointer !important; transition: all 0.3s ease !important; box-shadow: 0 4px 15px rgba(99, 102, 241, 0.4) !important; } .generate-btn:hover { transform: translateY(-2px) !important; box-shadow: 0 6px 20px rgba(99, 102, 241, 0.5) !important; } .output-image { border-radius: 16px; overflow: hidden; border: 2px solid transparent; background: linear-gradient(var(--surface-light), var(--surface-light)) padding-box, var(--gradient-1) border-box; } @media (max-width: 768px) { .main-header h1 { font-size: 1.75rem; } .feature-badges { flex-direction: column; align-items: center; } } """ # ============================================================ # 构建 Gradio 界面 # ============================================================ def create_demo(): with gr.Blocks( title="UniPic-3 DMD", theme=gr.themes.Base( primary_hue="indigo", secondary_hue="pink", neutral_hue="slate", font=("Outfit", "sans-serif"), ), css=CUSTOM_CSS ) as demo: gr.HTML("""

🎨 UniPic-3 DMD

Multi-Image Composition with Distribution-Matching Distillation

⚡ 8-Step Fast Inference 🖼️ Up to 6 Images 🚀 12.5× Speedup
""") with gr.Row(equal_height=True): with gr.Column(scale=1): gr.HTML('
📸

Upload Images

') num_images = gr.Slider(minimum=1, maximum=6, value=2, step=1, label="Number of Images", info="Select how many images to compose") with gr.Row(): img1 = gr.Image(type="pil", label="Image 1", visible=True) img2 = gr.Image(type="pil", label="Image 2", visible=True) with gr.Row(): img3 = gr.Image(type="pil", label="Image 3", visible=False) img4 = gr.Image(type="pil", label="Image 4", visible=False) with gr.Row(): img5 = gr.Image(type="pil", label="Image 5", visible=False) img6 = gr.Image(type="pil", label="Image 6", visible=False) image_inputs = [img1, img2, img3, img4, img5, img6] num_images.change(fn=update_image_visibility, inputs=num_images, outputs=image_inputs) gr.HTML('
✍️

Editing Instruction

') prompt_input = gr.Textbox( label="Prompt", placeholder="e.g., A man from Image1 standing on a surfboard from Image2...", lines=3, value="Combine the reference images to generate the final result." ) with gr.Accordion("⚙️ Advanced Settings", open=False): cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=4.0, step=0.5, label="CFG Scale", info="Higher = more prompt alignment") with gr.Row(): seed = gr.Number(value=42, label="Seed", info="For reproducibility", precision=0) num_steps = gr.Slider(minimum=1, maximum=8, value=8, step=1, label="Steps", info="8 recommended for DMD") generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg", elem_classes=["generate-btn"]) with gr.Column(scale=1): gr.HTML('
🎨

Generated Result

') output_image = gr.Image(type="pil", label="Output", elem_classes=["output-image"]) status_text = gr.Textbox( label="Status", value="✨ Ready! First run takes ~60s to load models.", interactive=False, ) gr.HTML("""

💡 Tips

""") generate_btn.click( fn=process_images, inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps], outputs=[output_image, status_text] ) gr.HTML('
📚

Example Prompts

') gr.Examples( examples=[ ["A person from Image1 wearing the outfit from Image2"], ["Combine Image1 and Image2 into a single cohesive scene"], ["The object from Image1 placed in the environment from Image2"], ], inputs=[prompt_input], label="" ) return demo demo = create_demo() if __name__ == "__main__": demo.launch()