"""
UniPic-3 DMD Multi-Image Composition
Hugging Face Space - ZeroGPU 优化版本 V5
关键策略:
1. 全局只加载不需要 GPU 的组件(scheduler, tokenizer, processor)
2. 需要 GPU 的模型在 @spaces.GPU 内部加载,显式指定 device='cuda'
3. 不使用 device_map='auto',因为它可能在 ZeroGPU 外部被错误地分配
"""
import gradio as gr
import torch
from PIL import Image
import os
import sys
# Hugging Face Spaces GPU decorator
try:
import spaces
HF_SPACES = True
print("✅ Running in Hugging Face Spaces with ZeroGPU")
except ImportError:
HF_SPACES = False
print("⚠️ Running locally (no ZeroGPU)")
class spaces:
@staticmethod
def GPU(duration=60):
def decorator(func):
return func
return decorator
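# Note: this stub only mirrors the parameterized form used below
# (@spaces.GPU(duration=...)); a bare @spaces.GPU would need an extra branch.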
# Local pipeline import
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Model configuration
MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")
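# TRANSFORMER_PATH may be a local directory or a "<repo_id>/<subfolder>" Hub
# path; generate_image() below handles both forms.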
dtype = torch.bfloat16
# ============================================================
# Load lightweight components globally (no GPU required)
# ============================================================
print("🚀 Loading lightweight components (CPU)...")
from diffusers import (
FlowMatchEulerDiscreteScheduler,
QwenImageTransformer2DModel,
AutoencoderKLQwenImage
)
from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
try:
from pipeline_qwenimage_edit import QwenImageEditPipeline
except ImportError:
from diffusers import QwenImageEditPipeline
# These components do not need a GPU, so they can be loaded at module level
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
MODEL_NAME, subfolder='scheduler'
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
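# None of the above hold CUDA tensors, so loading them at import time is safe
# on the CPU-only host that runs module-level code in a ZeroGPU Space.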
print("✅ Lightweight components loaded!")
# ============================================================
# Pipeline state
# ============================================================
pipe = None
_models_loaded = False
# ============================================================
# GPU inference function - models are loaded here
# ============================================================
@spaces.GPU(duration=180)
def generate_image(
images: list[Image.Image],
prompt: str,
true_cfg_scale: float,
seed: int,
num_steps: int
) -> Image.Image:
"""
GPU 推理函数
关键:所有需要 GPU 的模型都在这里加载,确保在真实 GPU 环境中
"""
global pipe, _models_loaded
print(f"🎨 Generating with {len(images)} image(s)...")
print(f" Prompt: {prompt[:50]}...")
print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
    # Load models inside the real GPU environment (first call only)
if not _models_loaded:
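        # Models are cached in module globals, so warm calls served by the
        # same worker skip the slow reload.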
print(" [INIT] Loading models on real GPU...")
device = 'cuda'
        # Load the text_encoder onto the GPU
print(" [INIT] Loading text_encoder...")
text_encoder = AutoModel.from_pretrained(
MODEL_NAME,
subfolder='text_encoder',
torch_dtype=dtype,
).to(device).eval()
        # Load the transformer onto the GPU
print(" [INIT] Loading transformer...")
        if os.path.isdir(TRANSFORMER_PATH):
config_path = os.path.join(TRANSFORMER_PATH, "config.json")
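            # A config.json directly in the folder means this *is* the model
            # directory; otherwise assume a nested "transformer" subfolder.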
if os.path.exists(config_path):
transformer = QwenImageTransformer2DModel.from_pretrained(
TRANSFORMER_PATH,
torch_dtype=dtype,
use_safetensors=False
).to(device).eval()
else:
transformer = QwenImageTransformer2DModel.from_pretrained(
TRANSFORMER_PATH,
subfolder='transformer',
torch_dtype=dtype,
use_safetensors=False
).to(device).eval()
else:
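            # Not a local dir: treat TRANSFORMER_PATH like
            # "Skywork/Unipic3-DMD/ema_transformer", i.e. a Hub repo id plus a
            # subfolder, and split it into the pieces from_pretrained expects.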
path_parts = TRANSFORMER_PATH.split('/')
if len(path_parts) >= 3:
repo_id = '/'.join(path_parts[:2])
subfolder = '/'.join(path_parts[2:])
transformer = QwenImageTransformer2DModel.from_pretrained(
repo_id,
subfolder=subfolder,
torch_dtype=dtype,
use_safetensors=False
).to(device).eval()
else:
transformer = QwenImageTransformer2DModel.from_pretrained(
TRANSFORMER_PATH,
subfolder='transformer',
torch_dtype=dtype,
use_safetensors=False
).to(device).eval()
        # Load the VAE onto the GPU
print(" [INIT] Loading VAE...")
vae = AutoencoderKLQwenImage.from_pretrained(
MODEL_NAME,
subfolder='vae',
torch_dtype=dtype,
).to(device).eval()
        # Assemble the pipeline
print(" [INIT] Creating pipeline...")
pipe = QwenImageEditPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
processor=processor,
transformer=transformer
)
_models_loaded = True
print(" [INIT] ✅ Models loaded successfully!")
    # Verify device placement
print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
print(f" [DEBUG] vae device: {next(pipe.vae.parameters()).device}")
# Generate
with torch.no_grad():
generator = torch.Generator(device='cuda').manual_seed(int(seed))
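        # Seeding a CUDA-side generator keeps sampling reproducible without
        # touching the global torch RNG state.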
if len(images) == 1:
result = pipe(
images[0],
prompt=prompt,
height=1024,
width=1024,
negative_prompt=' ',
num_inference_steps=num_steps,
true_cfg_scale=true_cfg_scale,
generator=generator
).images[0]
else:
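            # Multi-image composition: the local pipeline variant takes the
            # reference images as a list via the `images=` keyword.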
result = pipe(
images=images,
prompt=prompt,
height=1024,
width=1024,
negative_prompt=' ',
num_inference_steps=num_steps,
true_cfg_scale=true_cfg_scale,
generator=generator
).images[0]
print("✅ Generation complete!")
return result
# ============================================================
# UI logic (CPU, always available)
# ============================================================
def process_images(
img1, img2, img3, img4, img5, img6,
prompt: str,
cfg_scale: float,
seed: int,
num_steps: int
):
"""处理图像 - 验证输入后调用 GPU 函数"""
images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
if len(images) == 0:
return None, "❌ Please upload at least one image"
if len(images) > 6:
return None, f"❌ Maximum 6 images allowed (got {len(images)})"
if not prompt or prompt.strip() == "":
return None, "❌ Please enter an editing instruction"
try:
images = [img.convert("RGB") for img in images]
result = generate_image(
images=images,
prompt=prompt,
true_cfg_scale=cfg_scale,
seed=seed,
num_steps=num_steps
)
return result, f"✅ Generated from {len(images)} image(s) in {num_steps} steps"
except Exception as e:
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}"
def update_image_visibility(num):
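    # Show the first `num` image slots and hide the rest, driven by the slider.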
return [gr.update(visible=(i < num)) for i in range(6)]
# ============================================================
# Custom CSS
# ============================================================
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
:root {
--primary: #6366f1;
--primary-dark: #4f46e5;
--accent: #f472b6;
--surface: #0f0f23;
--surface-light: #1a1a3e;
--surface-elevated: #252552;
--text: #e2e8f0;
--text-muted: #94a3b8;
--border: #334155;
--success: #10b981;
--error: #ef4444;
--gradient-1: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
--gradient-hero: linear-gradient(135deg, #0f0f23 0%, #1a1a3e 50%, #252552 100%);
}
.gradio-container {
font-family: 'Outfit', sans-serif !important;
background: var(--gradient-hero) !important;
min-height: 100vh;
}
.main-header {
text-align: center;
padding: 2rem 1rem;
background: linear-gradient(180deg, rgba(99, 102, 241, 0.1) 0%, transparent 100%);
border-radius: 24px;
margin-bottom: 2rem;
border: 1px solid rgba(99, 102, 241, 0.2);
}
.main-header h1 {
font-size: 2.5rem;
font-weight: 700;
background: linear-gradient(135deg, #fff 0%, #a5b4fc 50%, #f472b6 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
margin-bottom: 0.5rem;
}
.main-header p {
color: var(--text-muted);
font-size: 1.1rem;
max-width: 600px;
margin: 0 auto;
}
.feature-badges {
display: flex;
gap: 1rem;
justify-content: center;
flex-wrap: wrap;
margin-top: 1.5rem;
}
.badge {
display: inline-flex;
align-items: center;
gap: 0.5rem;
padding: 0.5rem 1rem;
background: rgba(99, 102, 241, 0.15);
border: 1px solid rgba(99, 102, 241, 0.3);
border-radius: 9999px;
color: #a5b4fc;
font-size: 0.875rem;
font-weight: 500;
}
.section-header {
display: flex;
align-items: center;
gap: 0.75rem;
margin-bottom: 1rem;
padding-bottom: 0.75rem;
border-bottom: 1px solid var(--border);
}
.section-header h3 {
font-size: 1.125rem;
font-weight: 600;
color: var(--text);
margin: 0;
}
.generate-btn {
background: var(--gradient-1) !important;
border: none !important;
border-radius: 12px !important;
padding: 1rem 2rem !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
color: white !important;
cursor: pointer !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 15px rgba(99, 102, 241, 0.4) !important;
}
.generate-btn:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 20px rgba(99, 102, 241, 0.5) !important;
}
.output-image {
border-radius: 16px;
overflow: hidden;
border: 2px solid transparent;
background: linear-gradient(var(--surface-light), var(--surface-light)) padding-box,
var(--gradient-1) border-box;
}
@media (max-width: 768px) {
.main-header h1 { font-size: 1.75rem; }
.feature-badges { flex-direction: column; align-items: center; }
}
"""
# ============================================================
# Build the Gradio UI
# ============================================================
def create_demo():
with gr.Blocks(
title="UniPic-3 DMD",
theme=gr.themes.Base(
primary_hue="indigo",
secondary_hue="pink",
neutral_hue="slate",
font=("Outfit", "sans-serif"),
),
css=CUSTOM_CSS
) as demo:
gr.HTML("""
<div class="main-header">
<h1>🎨 UniPic-3 DMD</h1>
<p>Multi-Image Composition with Distribution-Matching Distillation</p>
<div class="feature-badges">
<span class="badge">⚡ 8-Step Fast Inference</span>
<span class="badge">🖼️ Up to 6 Images</span>
<span class="badge">🚀 12.5× Speedup</span>
</div>
</div>
""")
with gr.Row(equal_height=True):
with gr.Column(scale=1):
gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
num_images = gr.Slider(minimum=1, maximum=6, value=2, step=1,
label="Number of Images", info="Select how many images to compose")
with gr.Row():
img1 = gr.Image(type="pil", label="Image 1", visible=True)
img2 = gr.Image(type="pil", label="Image 2", visible=True)
with gr.Row():
img3 = gr.Image(type="pil", label="Image 3", visible=False)
img4 = gr.Image(type="pil", label="Image 4", visible=False)
with gr.Row():
img5 = gr.Image(type="pil", label="Image 5", visible=False)
img6 = gr.Image(type="pil", label="Image 6", visible=False)
image_inputs = [img1, img2, img3, img4, img5, img6]
num_images.change(fn=update_image_visibility, inputs=num_images, outputs=image_inputs)
gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
prompt_input = gr.Textbox(
label="Prompt",
placeholder="e.g., A man from Image1 standing on a surfboard from Image2...",
lines=3,
value="Combine the reference images to generate the final result."
)
with gr.Accordion("⚙️ Advanced Settings", open=False):
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=4.0, step=0.5,
label="CFG Scale", info="Higher = more prompt alignment")
with gr.Row():
seed = gr.Number(value=42, label="Seed", info="For reproducibility", precision=0)
num_steps = gr.Slider(minimum=1, maximum=8, value=8, step=1,
label="Steps", info="8 recommended for DMD")
generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg",
elem_classes=["generate-btn"])
with gr.Column(scale=1):
gr.HTML('<div class="section-header"><span>🎨</span><h3>Generated Result</h3></div>')
output_image = gr.Image(type="pil", label="Output", elem_classes=["output-image"])
status_text = gr.Textbox(
label="Status",
value="✨ Ready! First run takes ~60s to load models.",
interactive=False,
)
gr.HTML("""
<div style="margin-top: 1.5rem; padding: 1rem; background: rgba(99, 102, 241, 0.1);
border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
<p style="color: #ffffff; font-weight: 600; margin-bottom: 0.5rem;">💡 Tips</p>
<ul style="color: #ffffff; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
                <li>Refer to your uploads as "Image1", "Image2", etc. in the prompt</li>
<li>First run loads models (~60s)</li>
</ul>
</div>
""")
generate_btn.click(
fn=process_images,
inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
outputs=[output_image, status_text]
)
gr.HTML('<div class="section-header" style="margin-top: 2rem;"><span>📚</span><h3>Example Prompts</h3></div>')
gr.Examples(
examples=[
["A person from Image1 wearing the outfit from Image2"],
["Combine Image1 and Image2 into a single cohesive scene"],
["The object from Image1 placed in the environment from Image2"],
],
inputs=[prompt_input],
label=""
)
return demo
demo = create_demo()
if __name__ == "__main__":
demo.launch()