Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| UniPic-3 DMD Multi-Image Composition | |
| Hugging Face Space - ZeroGPU 优化版本 V5 | |
| 关键策略: | |
| 1. 全局只加载不需要 GPU 的组件(scheduler, tokenizer, processor) | |
| 2. 需要 GPU 的模型在 @spaces.GPU 内部加载,显式指定 device='cuda' | |
| 3. 不使用 device_map='auto',因为它可能在 ZeroGPU 外部被错误地分配 | |
| """ | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import os | |
| import sys | |
| # Hugging Face Spaces GPU decorator | |
| try: | |
| import spaces | |
| HF_SPACES = True | |
| print("✅ Running in Hugging Face Spaces with ZeroGPU") | |
| except ImportError: | |
| HF_SPACES = False | |
| print("⚠️ Running locally (no ZeroGPU)") | |
| class spaces: | |
| def GPU(duration=60): | |
| def decorator(func): | |
| return func | |
| return decorator | |
| # Local pipeline import | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
# Model configuration (overridable via environment variables)
MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")
# dtype used for every heavy model loaded later in generate_image()
dtype = torch.bfloat16

# ============================================================
# Globally load lightweight components (no GPU required)
# ============================================================
print("🚀 Loading lightweight components (CPU)...")
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    QwenImageTransformer2DModel,
    AutoencoderKLQwenImage
)
from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
try:
    # Prefer the repo-local pipeline implementation when it is present.
    from pipeline_qwenimage_edit import QwenImageEditPipeline
except ImportError:
    from diffusers import QwenImageEditPipeline

# These components do not need a GPU, so they are safe to load at import
# time (i.e. outside any @spaces.GPU context).
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_NAME, subfolder='scheduler'
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
print("✅ Lightweight components loaded!")

# ============================================================
# Pipeline state
# ============================================================
# Populated lazily on the first generate_image() call, where a real CUDA
# device is guaranteed to exist.
pipe = None
_models_loaded = False
# ============================================================
# GPU inference - the heavy models are loaded lazily in here
# ============================================================
# The module-level strategy notes require generate_image to run inside
# @spaces.GPU; on ZeroGPU an undecorated function never gets a CUDA device.
# Resolve the decorator here so this section is self-contained: the real
# one on Spaces, a no-op locally.
try:
    import spaces as _spaces
    _gpu = _spaces.GPU(duration=120)
except ImportError:
    def _gpu(func):
        return func


def _load_transformer(device):
    """Load the DMD transformer and move it to *device* in eval mode.

    TRANSFORMER_PATH may be a local directory (with or without a nested
    'transformer' subfolder) or a Hub reference of the form
    "<org>/<repo>/<subfolder...>".
    """
    common = dict(torch_dtype=dtype, use_safetensors=False)
    if os.path.isdir(TRANSFORMER_PATH):
        if os.path.exists(os.path.join(TRANSFORMER_PATH, "config.json")):
            # The directory itself holds the model config/weights.
            model = QwenImageTransformer2DModel.from_pretrained(
                TRANSFORMER_PATH, **common
            )
        else:
            model = QwenImageTransformer2DModel.from_pretrained(
                TRANSFORMER_PATH, subfolder='transformer', **common
            )
    else:
        parts = TRANSFORMER_PATH.split('/')
        if len(parts) >= 3:
            # "<org>/<repo>/<sub/folder>" -> repo_id + subfolder for the Hub.
            model = QwenImageTransformer2DModel.from_pretrained(
                '/'.join(parts[:2]), subfolder='/'.join(parts[2:]), **common
            )
        else:
            model = QwenImageTransformer2DModel.from_pretrained(
                TRANSFORMER_PATH, subfolder='transformer', **common
            )
    return model.to(device).eval()


@_gpu
def generate_image(
    images: list[Image.Image],
    prompt: str,
    true_cfg_scale: float,
    seed: int,
    num_steps: int
) -> Image.Image:
    """Run the edit/composition pipeline and return the generated image.

    All GPU-dependent models are loaded on the first call, inside this
    (decorated) function, so that under ZeroGPU they are created while a
    real CUDA device is attached.

    Args:
        images: 1-6 input reference images.
        prompt: editing/composition instruction.
        true_cfg_scale: classifier-free guidance scale.
        seed: RNG seed for reproducibility.
        num_steps: number of inference steps (8 recommended for DMD).
    """
    global pipe, _models_loaded
    print(f"🎨 Generating with {len(images)} image(s)...")
    print(f"  Prompt: {prompt[:50]}...")
    print(f"  Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")

    # Lazily load the heavy models on the first call.
    if not _models_loaded:
        print("  [INIT] Loading models on real GPU...")
        device = 'cuda'

        print("  [INIT] Loading text_encoder...")
        text_encoder = AutoModel.from_pretrained(
            MODEL_NAME,
            subfolder='text_encoder',
            torch_dtype=dtype,
        ).to(device).eval()

        print("  [INIT] Loading transformer...")
        transformer = _load_transformer(device)

        print("  [INIT] Loading VAE...")
        vae = AutoencoderKLQwenImage.from_pretrained(
            MODEL_NAME,
            subfolder='vae',
            torch_dtype=dtype,
        ).to(device).eval()

        print("  [INIT] Creating pipeline...")
        pipe = QwenImageEditPipeline(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            processor=processor,
            transformer=transformer
        )
        _models_loaded = True
        print("  [INIT] ✅ Models loaded successfully!")

    # Sanity-check device placement on every call.
    print(f"  [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
    print(f"  [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
    print(f"  [DEBUG] vae device: {next(pipe.vae.parameters()).device}")

    # Generate
    with torch.no_grad():
        generator = torch.Generator(device='cuda').manual_seed(int(seed))
        call_kwargs = dict(
            prompt=prompt,
            height=1024,
            width=1024,
            negative_prompt=' ',
            num_inference_steps=num_steps,
            true_cfg_scale=true_cfg_scale,
            generator=generator,
        )
        # Single image goes in positionally; multiple images use `images=`
        # (the multi-image entry point of the pipeline).
        if len(images) == 1:
            result = pipe(images[0], **call_kwargs).images[0]
        else:
            result = pipe(images=images, **call_kwargs).images[0]

    print("✅ Generation complete!")
    return result
# ============================================================
# UI logic (CPU, always available)
# ============================================================
def process_images(
    img1, img2, img3, img4, img5, img6,
    prompt: str,
    cfg_scale: float,
    seed: int,
    num_steps: int
):
    """Validate the UI inputs, then delegate to the GPU generation function.

    Returns a (image-or-None, status-message) pair for the output widgets.
    """
    uploads = [candidate for candidate in (img1, img2, img3, img4, img5, img6)
               if candidate is not None]

    # Guard clauses: reject bad input before touching the GPU path.
    if not uploads:
        return None, "❌ Please upload at least one image"
    if len(uploads) > 6:
        return None, f"❌ Maximum 6 images allowed (got {len(uploads)})"
    if not prompt or prompt.strip() == "":
        return None, "❌ Please enter an editing instruction"

    try:
        rgb_images = [picture.convert("RGB") for picture in uploads]
        generated = generate_image(
            images=rgb_images,
            prompt=prompt,
            true_cfg_scale=cfg_scale,
            seed=seed,
            num_steps=num_steps
        )
        return generated, f"✅ Generated from {len(rgb_images)} image(s) in {num_steps} steps"
    except Exception as err:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(err)}"
def update_image_visibility(num):
    """Show the first *num* image upload slots and hide the remaining ones."""
    updates = []
    for slot in range(6):
        updates.append(gr.update(visible=(slot < num)))
    return updates
# ============================================================
# Custom CSS
# ============================================================
# Dark, gradient-heavy theme injected through the `css=` argument of
# gr.Blocks in create_demo(). Class names (.main-header, .badge,
# .generate-btn, .output-image, ...) match elem_classes / gr.HTML markup
# used when building the UI.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
:root {
    --primary: #6366f1;
    --primary-dark: #4f46e5;
    --accent: #f472b6;
    --surface: #0f0f23;
    --surface-light: #1a1a3e;
    --surface-elevated: #252552;
    --text: #e2e8f0;
    --text-muted: #94a3b8;
    --border: #334155;
    --success: #10b981;
    --error: #ef4444;
    --gradient-1: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    --gradient-hero: linear-gradient(135deg, #0f0f23 0%, #1a1a3e 50%, #252552 100%);
}
.gradio-container {
    font-family: 'Outfit', sans-serif !important;
    background: var(--gradient-hero) !important;
    min-height: 100vh;
}
.main-header {
    text-align: center;
    padding: 2rem 1rem;
    background: linear-gradient(180deg, rgba(99, 102, 241, 0.1) 0%, transparent 100%);
    border-radius: 24px;
    margin-bottom: 2rem;
    border: 1px solid rgba(99, 102, 241, 0.2);
}
.main-header h1 {
    font-size: 2.5rem;
    font-weight: 700;
    background: linear-gradient(135deg, #fff 0%, #a5b4fc 50%, #f472b6 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    margin-bottom: 0.5rem;
}
.main-header p {
    color: var(--text-muted);
    font-size: 1.1rem;
    max-width: 600px;
    margin: 0 auto;
}
.feature-badges {
    display: flex;
    gap: 1rem;
    justify-content: center;
    flex-wrap: wrap;
    margin-top: 1.5rem;
}
.badge {
    display: inline-flex;
    align-items: center;
    gap: 0.5rem;
    padding: 0.5rem 1rem;
    background: rgba(99, 102, 241, 0.15);
    border: 1px solid rgba(99, 102, 241, 0.3);
    border-radius: 9999px;
    color: #a5b4fc;
    font-size: 0.875rem;
    font-weight: 500;
}
.section-header {
    display: flex;
    align-items: center;
    gap: 0.75rem;
    margin-bottom: 1rem;
    padding-bottom: 0.75rem;
    border-bottom: 1px solid var(--border);
}
.section-header h3 {
    font-size: 1.125rem;
    font-weight: 600;
    color: var(--text);
    margin: 0;
}
.generate-btn {
    background: var(--gradient-1) !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 1rem 2rem !important;
    font-size: 1.1rem !important;
    font-weight: 600 !important;
    color: white !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(99, 102, 241, 0.4) !important;
}
.generate-btn:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(99, 102, 241, 0.5) !important;
}
.output-image {
    border-radius: 16px;
    overflow: hidden;
    border: 2px solid transparent;
    background: linear-gradient(var(--surface-light), var(--surface-light)) padding-box,
                var(--gradient-1) border-box;
}
@media (max-width: 768px) {
    .main-header h1 { font-size: 1.75rem; }
    .feature-badges { flex-direction: column; align-items: center; }
}
"""
# ============================================================
# Build the Gradio interface
# ============================================================
def create_demo():
    """Build and return the Gradio Blocks UI.

    Layout: a header banner, a two-column row (left: up to six image
    uploads + prompt + advanced settings + generate button; right:
    output image + status), followed by example prompts.
    """
    with gr.Blocks(
        title="UniPic-3 DMD",
        theme=gr.themes.Base(
            primary_hue="indigo",
            secondary_hue="pink",
            neutral_hue="slate",
            font=("Outfit", "sans-serif"),
        ),
        css=CUSTOM_CSS
    ) as demo:
        gr.HTML("""
        <div class="main-header">
            <h1>🎨 UniPic-3 DMD</h1>
            <p>Multi-Image Composition with Distribution-Matching Distillation</p>
            <div class="feature-badges">
                <span class="badge">⚡ 8-Step Fast Inference</span>
                <span class="badge">🖼️ Up to 6 Images</span>
                <span class="badge">🚀 12.5× Speedup</span>
            </div>
        </div>
        """)
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
                num_images = gr.Slider(minimum=1, maximum=6, value=2, step=1,
                                       label="Number of Images", info="Select how many images to compose")
                # Six fixed slots; the slider below toggles their visibility.
                with gr.Row():
                    img1 = gr.Image(type="pil", label="Image 1", visible=True)
                    img2 = gr.Image(type="pil", label="Image 2", visible=True)
                with gr.Row():
                    img3 = gr.Image(type="pil", label="Image 3", visible=False)
                    img4 = gr.Image(type="pil", label="Image 4", visible=False)
                with gr.Row():
                    img5 = gr.Image(type="pil", label="Image 5", visible=False)
                    img6 = gr.Image(type="pil", label="Image 6", visible=False)
                image_inputs = [img1, img2, img3, img4, img5, img6]
                num_images.change(fn=update_image_visibility, inputs=num_images, outputs=image_inputs)
                gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="e.g., A man from Image1 standing on a surfboard from Image2...",
                    lines=3,
                    value="Combine the reference images to generate the final result."
                )
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=4.0, step=0.5,
                                          label="CFG Scale", info="Higher = more prompt alignment")
                    with gr.Row():
                        seed = gr.Number(value=42, label="Seed", info="For reproducibility", precision=0)
                        num_steps = gr.Slider(minimum=1, maximum=8, value=8, step=1,
                                              label="Steps", info="8 recommended for DMD")
                generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg",
                                         elem_classes=["generate-btn"])
            with gr.Column(scale=1):
                gr.HTML('<div class="section-header"><span>🎨</span><h3>Generated Result</h3></div>')
                output_image = gr.Image(type="pil", label="Output", elem_classes=["output-image"])
                status_text = gr.Textbox(
                    label="Status",
                    value="✨ Ready! First run takes ~60s to load models.",
                    interactive=False,
                )
                gr.HTML("""
                <div style="margin-top: 1.5rem; padding: 1rem; background: rgba(99, 102, 241, 0.1);
                            border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
                    <p style="color: #ffffff; font-weight: 600; margin-bottom: 0.5rem;">💡 Tips</p>
                    <ul style="color: #ffffff; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
                        <li>Reference images as "Image1", "Image2", etc.</li>
                        <li>First run loads models (~60s)</li>
                    </ul>
                </div>
                """)
        # Wire the button to the validation + generation callback.
        generate_btn.click(
            fn=process_images,
            inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
            outputs=[output_image, status_text]
        )
        gr.HTML('<div class="section-header" style="margin-top: 2rem;"><span>📚</span><h3>Example Prompts</h3></div>')
        gr.Examples(
            examples=[
                ["A person from Image1 wearing the outfit from Image2"],
                ["Combine Image1 and Image2 into a single cohesive scene"],
                ["The object from Image1 placed in the environment from Image2"],
            ],
            inputs=[prompt_input],
            label=""
        )
    return demo
# Build the interface at import time so Spaces can serve it directly.
demo = create_demo()

if __name__ == "__main__":
    demo.launch()