OrlandoHugBot commited on
Commit
9ac11e6
·
verified ·
1 Parent(s): ed53052

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -197
app.py CHANGED
@@ -1,9 +1,11 @@
1
  """
2
  UniPic-3 DMD Multi-Image Composition
3
- Hugging Face Space - ZeroGPU 优化版本 V3
4
 
5
- 关键修复:完全在 @spaces.GPU 内部加载模型
6
- 参考 Qwen 官方的 app.py 实现方式
 
 
7
  """
8
 
9
  import gradio as gr
@@ -20,7 +22,6 @@ try:
20
  except ImportError:
21
  HF_SPACES = False
22
  print("⚠️ Running locally (no ZeroGPU)")
23
- # 本地开发时的 mock
24
  class spaces:
25
  @staticmethod
26
  def GPU(duration=60):
@@ -35,123 +36,47 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
35
  MODEL_NAME = os.environ.get("MODEL_NAME", "/data_genie/genie/chris/Unipic3-DMD")
36
  TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "/data_genie/genie/chris/Unipic3-DMD/ema_transformer")
37
 
 
 
38
  # ============================================================
39
- # 全局变量
40
  # ============================================================
41
- pipe = None
42
- dtype = torch.bfloat16
43
 
 
44
 
45
- def load_pipeline():
46
- """
47
- 加载完整的 Pipeline
48
- 这个函数应该在 @spaces.GPU 装饰的函数内部调用
49
- """
50
- global pipe
51
-
52
- if pipe is not None:
53
- return pipe
54
-
55
- print("🚀 Loading pipeline...")
56
-
57
- try:
58
- from pipeline_qwenimage_edit import QwenImageEditPipeline
59
- except ImportError:
60
- from diffusers import QwenImageEditPipeline
61
-
62
- from diffusers import (
63
- FlowMatchEulerDiscreteScheduler,
64
- QwenImageTransformer2DModel,
65
- AutoencoderKLQwenImage
66
- )
67
- from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
68
-
69
- device = 'cuda'
70
-
71
- # Load scheduler
72
- print(" Loading scheduler...")
73
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
74
- MODEL_NAME, subfolder='scheduler'
75
- )
76
-
77
- # Load tokenizer & processor
78
- print(" Loading tokenizer & processor...")
79
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
80
- processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
81
-
82
- # Load text encoder - 直接加载到 GPU
83
- print(" Loading text_encoder...")
84
- text_encoder = AutoModel.from_pretrained(
85
- MODEL_NAME,
86
- subfolder='text_encoder',
87
- torch_dtype=dtype,
88
- ).to(device).eval()
89
-
90
- # Load transformer - 直接加载到 GPU
91
- print(" Loading transformer...")
92
- if os.path.exists(TRANSFORMER_PATH):
93
- if os.path.isdir(TRANSFORMER_PATH):
94
- config_path = os.path.join(TRANSFORMER_PATH, "config.json")
95
- if os.path.exists(config_path):
96
- transformer = QwenImageTransformer2DModel.from_pretrained(
97
- TRANSFORMER_PATH,
98
- torch_dtype=dtype,
99
- use_safetensors=False
100
- ).to(device).eval()
101
- else:
102
- transformer = QwenImageTransformer2DModel.from_pretrained(
103
- TRANSFORMER_PATH,
104
- subfolder='transformer',
105
- torch_dtype=dtype,
106
- use_safetensors=False
107
- ).to(device).eval()
108
- else:
109
- path_parts = TRANSFORMER_PATH.split('/')
110
- if len(path_parts) >= 3:
111
- repo_id = '/'.join(path_parts[:2])
112
- subfolder = '/'.join(path_parts[2:])
113
- transformer = QwenImageTransformer2DModel.from_pretrained(
114
- repo_id,
115
- subfolder=subfolder,
116
- torch_dtype=dtype,
117
- use_safetensors=False
118
- ).to(device).eval()
119
- else:
120
- transformer = QwenImageTransformer2DModel.from_pretrained(
121
- TRANSFORMER_PATH,
122
- subfolder='transformer',
123
- torch_dtype=dtype,
124
- use_safetensors=False
125
- ).to(device).eval()
126
-
127
- # Load VAE - 直接加载到 GPU
128
- print(" Loading VAE...")
129
- vae = AutoencoderKLQwenImage.from_pretrained(
130
- MODEL_NAME,
131
- subfolder='vae',
132
- torch_dtype=dtype,
133
- ).to(device).eval()
134
-
135
- # Create Pipeline
136
- print(" Creating pipeline...")
137
- pipe = QwenImageEditPipeline(
138
- scheduler=scheduler,
139
- vae=vae,
140
- text_encoder=text_encoder,
141
- tokenizer=tokenizer,
142
- processor=processor,
143
- transformer=transformer
144
- )
145
-
146
- print("✅ Pipeline loaded successfully!")
147
- return pipe
148
 
149
 
150
  # ============================================================
151
  # GPU 推理函数 - 模型在这里加载
152
  # ============================================================
153
 
154
- @spaces.GPU(duration=180) # 增加时间以包含首次加载
155
  def generate_image(
156
  images: list[Image.Image],
157
  prompt: str,
@@ -161,17 +86,85 @@ def generate_image(
161
  ) -> Image.Image:
162
  """
163
  GPU 推理函数
164
- 关键:Pipeline 完全在这里加载,确保在真实 GPU 环境中初始化
165
  """
166
- global pipe
167
 
168
  print(f"🎨 Generating with {len(images)} image(s)...")
169
  print(f" Prompt: {prompt[:50]}...")
170
  print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
171
 
172
  # 在真实 GPU 环境中加载模型(首次调用时)
173
- if pipe is None:
174
- load_pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # 验证设备
177
  print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
@@ -222,10 +215,8 @@ def process_images(
222
  ):
223
  """处理图像 - 验证输入后调用 GPU 函数"""
224
 
225
- # 过滤有效图像
226
  images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
227
 
228
- # 验证
229
  if len(images) == 0:
230
  return None, "❌ Please upload at least one image"
231
 
@@ -236,10 +227,8 @@ def process_images(
236
  return None, "❌ Please enter an editing instruction"
237
 
238
  try:
239
- # 转换为 RGB
240
  images = [img.convert("RGB") for img in images]
241
 
242
- # 调用 GPU 函数
243
  result = generate_image(
244
  images=images,
245
  prompt=prompt,
@@ -257,7 +246,6 @@ def process_images(
257
 
258
 
259
  def update_image_visibility(num):
260
- """更新图像上传槽的可见性"""
261
  return [gr.update(visible=(i < num)) for i in range(6)]
262
 
263
 
@@ -367,13 +355,8 @@ CUSTOM_CSS = """
367
  var(--gradient-1) border-box;
368
  }
369
  @media (max-width: 768px) {
370
- .main-header h1 {
371
- font-size: 1.75rem;
372
- }
373
- .feature-badges {
374
- flex-direction: column;
375
- align-items: center;
376
- }
377
  }
378
  """
379
 
@@ -394,7 +377,6 @@ def create_demo():
394
  css=CUSTOM_CSS
395
  ) as demo:
396
 
397
- # Header
398
  gr.HTML("""
399
  <div class="main-header">
400
  <h1>🎨 UniPic-3 DMD</h1>
@@ -408,19 +390,11 @@ def create_demo():
408
  """)
409
 
410
  with gr.Row(equal_height=True):
411
- # Left Column - Inputs
412
  with gr.Column(scale=1):
413
-
414
  gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
415
 
416
- num_images = gr.Slider(
417
- minimum=1,
418
- maximum=6,
419
- value=2,
420
- step=1,
421
- label="Number of Images",
422
- info="Select how many images to compose"
423
- )
424
 
425
  with gr.Row():
426
  img1 = gr.Image(type="pil", label="Image 1", visible=True)
@@ -435,96 +409,57 @@ def create_demo():
435
  img6 = gr.Image(type="pil", label="Image 6", visible=False)
436
 
437
  image_inputs = [img1, img2, img3, img4, img5, img6]
438
-
439
- num_images.change(
440
- fn=update_image_visibility,
441
- inputs=num_images,
442
- outputs=image_inputs
443
- )
444
 
445
  gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
446
 
447
  prompt_input = gr.Textbox(
448
  label="Prompt",
449
- placeholder="e.g., A man from Image1 standing on a surfboard from Image2, riding ocean waves under a bright blue sky.",
450
  lines=3,
451
  value="Combine the reference images to generate the final result."
452
  )
453
 
454
  with gr.Accordion("⚙️ Advanced Settings", open=False):
455
- cfg_scale = gr.Slider(
456
- minimum=1.0,
457
- maximum=10.0,
458
- value=4.0,
459
- step=0.5,
460
- label="CFG Scale",
461
- info="Higher = more prompt alignment"
462
- )
463
 
464
  with gr.Row():
465
- seed = gr.Number(
466
- value=42,
467
- label="Seed",
468
- info="For reproducibility",
469
- precision=0
470
- )
471
- num_steps = gr.Slider(
472
- minimum=1,
473
- maximum=8,
474
- value=8,
475
- step=1,
476
- label="Steps",
477
- info="8 recommended for DMD"
478
- )
479
 
480
- generate_btn = gr.Button(
481
- "🚀 Generate Image",
482
- variant="primary",
483
- size="lg",
484
- elem_classes=["generate-btn"]
485
- )
486
 
487
- # Right Column - Output
488
  with gr.Column(scale=1):
489
  gr.HTML('<div class="section-header"><span>🎨</span><h3>Generated Result</h3></div>')
490
 
491
- output_image = gr.Image(
492
- type="pil",
493
- label="Output",
494
- elem_classes=["output-image"],
495
- )
496
 
497
  status_text = gr.Textbox(
498
  label="Status",
499
- value="✨ Ready! Upload images and click Generate. First run will take longer to load the model.",
500
  interactive=False,
501
  )
502
 
503
  gr.HTML("""
504
- <div style="
505
- margin-top: 1.5rem;
506
- padding: 1rem;
507
- background: rgba(99, 102, 241, 0.1);
508
- border-radius: 12px;
509
- border: 1px solid rgba(99, 102, 241, 0.2);
510
- ">
511
  <p style="color: #ffffff; font-weight: 600; margin-bottom: 0.5rem;">💡 Tips</p>
512
  <ul style="color: #ffffff; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
513
- <li>Reference images as "Image1", "Image2", etc. in your prompt</li>
514
- <li>Use descriptive prompts for better composition</li>
515
- <li>First run will take ~60s to load the model</li>
516
  </ul>
517
  </div>
518
  """)
519
 
520
- # Connect generate button
521
  generate_btn.click(
522
  fn=process_images,
523
  inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
524
  outputs=[output_image, status_text]
525
  )
526
 
527
- # Examples
528
  gr.HTML('<div class="section-header" style="margin-top: 2rem;"><span>📚</span><h3>Example Prompts</h3></div>')
529
 
530
  gr.Examples(
@@ -532,7 +467,6 @@ def create_demo():
532
  ["A person from Image1 wearing the outfit from Image2"],
533
  ["Combine Image1 and Image2 into a single cohesive scene"],
534
  ["The object from Image1 placed in the environment from Image2"],
535
- ["Create a portrait using the face from Image1 and hairstyle from Image2"],
536
  ],
537
  inputs=[prompt_input],
538
  label=""
@@ -541,10 +475,6 @@ def create_demo():
541
  return demo
542
 
543
 
544
- # ============================================================
545
- # 启动
546
- # ============================================================
547
-
548
  demo = create_demo()
549
 
550
  if __name__ == "__main__":
 
1
  """
2
  UniPic-3 DMD Multi-Image Composition
3
+ Hugging Face Space - ZeroGPU 优化版本 V5
4
 
5
+ 关键策略:
6
+ 1. 全局只加载不需要 GPU 的组件(scheduler, tokenizer, processor)
7
+ 2. 需要 GPU 的模型在 @spaces.GPU 内部加载,显式指定 device='cuda'
8
+ 3. 不使用 device_map='auto',因为它可能在 ZeroGPU 外部被错误地分配
9
  """
10
 
11
  import gradio as gr
 
22
  except ImportError:
23
  HF_SPACES = False
24
  print("⚠️ Running locally (no ZeroGPU)")
 
25
  class spaces:
26
  @staticmethod
27
  def GPU(duration=60):
 
36
  MODEL_NAME = os.environ.get("MODEL_NAME", "/data_genie/genie/chris/Unipic3-DMD")
37
  TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "/data_genie/genie/chris/Unipic3-DMD/ema_transformer")
38
 
39
+ dtype = torch.bfloat16
40
+
41
  # ============================================================
42
+ # 全局加载轻量级组件(不需要 GPU)
43
  # ============================================================
 
 
44
 
45
+ print("🚀 Loading lightweight components (CPU)...")
46
 
47
+ from diffusers import (
48
+ FlowMatchEulerDiscreteScheduler,
49
+ QwenImageTransformer2DModel,
50
+ AutoencoderKLQwenImage
51
+ )
52
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
53
+
54
+ try:
55
+ from pipeline_qwenimage_edit import QwenImageEditPipeline
56
+ except ImportError:
57
+ from diffusers import QwenImageEditPipeline
58
+
59
+ # 这些组件不需要 GPU,可以在全局加载
60
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
61
+ MODEL_NAME, subfolder='scheduler'
62
+ )
63
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
64
+ processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
65
+
66
+ print("✅ Lightweight components loaded!")
67
+
68
+ # ============================================================
69
+ # Pipeline 状态
70
+ # ============================================================
71
+ pipe = None
72
+ _models_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  # ============================================================
76
  # GPU 推理函数 - 模型在这里加载
77
  # ============================================================
78
 
79
+ @spaces.GPU(duration=180)
80
  def generate_image(
81
  images: list[Image.Image],
82
  prompt: str,
 
86
  ) -> Image.Image:
87
  """
88
  GPU 推理函数
89
+ 关键:所有需要 GPU 的模型都在这里加载,确保在真实 GPU 环境中
90
  """
91
+ global pipe, _models_loaded
92
 
93
  print(f"🎨 Generating with {len(images)} image(s)...")
94
  print(f" Prompt: {prompt[:50]}...")
95
  print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
96
 
97
  # 在真实 GPU 环境中加载模型(首次调用时)
98
+ if not _models_loaded:
99
+ print(" [INIT] Loading models on real GPU...")
100
+
101
+ device = 'cuda'
102
+
103
+ # 加载 text_encoder 到 GPU
104
+ print(" [INIT] Loading text_encoder...")
105
+ text_encoder = AutoModel.from_pretrained(
106
+ MODEL_NAME,
107
+ subfolder='text_encoder',
108
+ torch_dtype=dtype,
109
+ ).to(device).eval()
110
+
111
+ # 加载 transformer 到 GPU
112
+ print(" [INIT] Loading transformer...")
113
+ if os.path.exists(TRANSFORMER_PATH) and os.path.isdir(TRANSFORMER_PATH):
114
+ config_path = os.path.join(TRANSFORMER_PATH, "config.json")
115
+ if os.path.exists(config_path):
116
+ transformer = QwenImageTransformer2DModel.from_pretrained(
117
+ TRANSFORMER_PATH,
118
+ torch_dtype=dtype,
119
+ use_safetensors=False
120
+ ).to(device).eval()
121
+ else:
122
+ transformer = QwenImageTransformer2DModel.from_pretrained(
123
+ TRANSFORMER_PATH,
124
+ subfolder='transformer',
125
+ torch_dtype=dtype,
126
+ use_safetensors=False
127
+ ).to(device).eval()
128
+ else:
129
+ path_parts = TRANSFORMER_PATH.split('/')
130
+ if len(path_parts) >= 3:
131
+ repo_id = '/'.join(path_parts[:2])
132
+ subfolder = '/'.join(path_parts[2:])
133
+ transformer = QwenImageTransformer2DModel.from_pretrained(
134
+ repo_id,
135
+ subfolder=subfolder,
136
+ torch_dtype=dtype,
137
+ use_safetensors=False
138
+ ).to(device).eval()
139
+ else:
140
+ transformer = QwenImageTransformer2DModel.from_pretrained(
141
+ TRANSFORMER_PATH,
142
+ subfolder='transformer',
143
+ torch_dtype=dtype,
144
+ use_safetensors=False
145
+ ).to(device).eval()
146
+
147
+ # 加载 VAE 到 GPU
148
+ print(" [INIT] Loading VAE...")
149
+ vae = AutoencoderKLQwenImage.from_pretrained(
150
+ MODEL_NAME,
151
+ subfolder='vae',
152
+ torch_dtype=dtype,
153
+ ).to(device).eval()
154
+
155
+ # 创建 Pipeline
156
+ print(" [INIT] Creating pipeline...")
157
+ pipe = QwenImageEditPipeline(
158
+ scheduler=scheduler,
159
+ vae=vae,
160
+ text_encoder=text_encoder,
161
+ tokenizer=tokenizer,
162
+ processor=processor,
163
+ transformer=transformer
164
+ )
165
+
166
+ _models_loaded = True
167
+ print(" [INIT] ✅ Models loaded successfully!")
168
 
169
  # 验证设备
170
  print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
 
215
  ):
216
  """处理图像 - 验证输入后调用 GPU 函数"""
217
 
 
218
  images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
219
 
 
220
  if len(images) == 0:
221
  return None, "❌ Please upload at least one image"
222
 
 
227
  return None, "❌ Please enter an editing instruction"
228
 
229
  try:
 
230
  images = [img.convert("RGB") for img in images]
231
 
 
232
  result = generate_image(
233
  images=images,
234
  prompt=prompt,
 
246
 
247
 
248
  def update_image_visibility(num):
 
249
  return [gr.update(visible=(i < num)) for i in range(6)]
250
 
251
 
 
355
  var(--gradient-1) border-box;
356
  }
357
  @media (max-width: 768px) {
358
+ .main-header h1 { font-size: 1.75rem; }
359
+ .feature-badges { flex-direction: column; align-items: center; }
 
 
 
 
 
360
  }
361
  """
362
 
 
377
  css=CUSTOM_CSS
378
  ) as demo:
379
 
 
380
  gr.HTML("""
381
  <div class="main-header">
382
  <h1>🎨 UniPic-3 DMD</h1>
 
390
  """)
391
 
392
  with gr.Row(equal_height=True):
 
393
  with gr.Column(scale=1):
 
394
  gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
395
 
396
+ num_images = gr.Slider(minimum=1, maximum=6, value=2, step=1,
397
+ label="Number of Images", info="Select how many images to compose")
 
 
 
 
 
 
398
 
399
  with gr.Row():
400
  img1 = gr.Image(type="pil", label="Image 1", visible=True)
 
409
  img6 = gr.Image(type="pil", label="Image 6", visible=False)
410
 
411
  image_inputs = [img1, img2, img3, img4, img5, img6]
412
+ num_images.change(fn=update_image_visibility, inputs=num_images, outputs=image_inputs)
 
 
 
 
 
413
 
414
  gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
415
 
416
  prompt_input = gr.Textbox(
417
  label="Prompt",
418
+ placeholder="e.g., A man from Image1 standing on a surfboard from Image2...",
419
  lines=3,
420
  value="Combine the reference images to generate the final result."
421
  )
422
 
423
  with gr.Accordion("⚙️ Advanced Settings", open=False):
424
+ cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=4.0, step=0.5,
425
+ label="CFG Scale", info="Higher = more prompt alignment")
 
 
 
 
 
 
426
 
427
  with gr.Row():
428
+ seed = gr.Number(value=42, label="Seed", info="For reproducibility", precision=0)
429
+ num_steps = gr.Slider(minimum=1, maximum=8, value=8, step=1,
430
+ label="Steps", info="8 recommended for DMD")
 
 
 
 
 
 
 
 
 
 
 
431
 
432
+ generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg",
433
+ elem_classes=["generate-btn"])
 
 
 
 
434
 
 
435
  with gr.Column(scale=1):
436
  gr.HTML('<div class="section-header"><span>🎨</span><h3>Generated Result</h3></div>')
437
 
438
+ output_image = gr.Image(type="pil", label="Output", elem_classes=["output-image"])
 
 
 
 
439
 
440
  status_text = gr.Textbox(
441
  label="Status",
442
+ value="✨ Ready! First run takes ~60s to load models.",
443
  interactive=False,
444
  )
445
 
446
  gr.HTML("""
447
+ <div style="margin-top: 1.5rem; padding: 1rem; background: rgba(99, 102, 241, 0.1);
448
+ border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
 
 
 
 
 
449
  <p style="color: #ffffff; font-weight: 600; margin-bottom: 0.5rem;">💡 Tips</p>
450
  <ul style="color: #ffffff; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
451
+ <li>Reference images as "Image1", "Image2", etc.</li>
452
+ <li>First run loads models (~60s)</li>
 
453
  </ul>
454
  </div>
455
  """)
456
 
 
457
  generate_btn.click(
458
  fn=process_images,
459
  inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
460
  outputs=[output_image, status_text]
461
  )
462
 
 
463
  gr.HTML('<div class="section-header" style="margin-top: 2rem;"><span>📚</span><h3>Example Prompts</h3></div>')
464
 
465
  gr.Examples(
 
467
  ["A person from Image1 wearing the outfit from Image2"],
468
  ["Combine Image1 and Image2 into a single cohesive scene"],
469
  ["The object from Image1 placed in the environment from Image2"],
 
470
  ],
471
  inputs=[prompt_input],
472
  label=""
 
475
  return demo
476
 
477
 
 
 
 
 
478
  demo = create_demo()
479
 
480
  if __name__ == "__main__":