OrlandoHugBot committed on
Commit cf16bb0 · verified · 1 Parent(s): 87f5c9f

Update app.py

Files changed (1)
  1. app.py +141 -235
app.py CHANGED
@@ -1,11 +1,11 @@
  """
  UniPic-3 DMD Multi-Image Composition
- Hugging Face Space - UI Persistent + GPU On-Demand Architecture
+ Hugging Face Space - ZeroGPU-optimized version
 
- Core optimizations:
- 1. Persistent UI - the page is always available, no waiting for the model to load
- 2. GPU on-demand - the GPU is only used during inference, saving resources
- 3. Polished front end - a modern, good-looking UI design
+ Architecture notes:
+ 1. Models are loaded at global scope (ZeroGPU intercepts the CUDA calls)
+ 2. Only actual inference runs under the @spaces.GPU decorator
+ 3. This avoids reloading the models on every request
  """
 
  import gradio as gr
@@ -16,14 +16,17 @@ import sys
 
  # Hugging Face Spaces GPU decorator
  try:
-     from spaces import GPU
+     import spaces
      HF_SPACES = True
  except ImportError:
      HF_SPACES = False
-     def GPU(duration=60):
-         def decorator(func):
-             return func
-         return decorator
+     # Mock for local development
+     class spaces:
+         @staticmethod
+         def GPU(duration=60):
+             def decorator(func):
+                 return func
+             return decorator
 
  # Local pipeline import
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -33,18 +36,119 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
  TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")
 
  # ============================================================
- # GPU On-Demand: Model loading happens inside @GPU decorated function
+ # Load the models at global scope (ZeroGPU intercepts the CUDA calls)
  # ============================================================
 
- def get_device():
-     """Get the appropriate device"""
-     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("🚀 Loading models...")
 
- def get_dtype():
-     """Get the appropriate dtype"""
-     return torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ try:
+     from pipeline_qwenimage_edit import QwenImageEditPipeline
+ except ImportError:
+     from diffusers import QwenImageEditPipeline
+
+ from diffusers import (
+     FlowMatchEulerDiscreteScheduler,
+     QwenImageTransformer2DModel,
+     AutoencoderKLQwenImage
+ )
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
+
+ # Choose the dtype
+ dtype = torch.bfloat16
+
+ # Load scheduler (CPU)
+ print(" Loading scheduler...")
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+     MODEL_NAME, subfolder='scheduler'
+ )
+
+ # Load tokenizer & processor (CPU)
+ print(" Loading tokenizer & processor...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
+ processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
+
+ # Load text encoder
+ print(" Loading text_encoder...")
+ text_encoder = AutoModel.from_pretrained(
+     MODEL_NAME,
+     subfolder='text_encoder',
+     torch_dtype=dtype,
+ ).eval()
+
+ # Load transformer
+ print(" Loading transformer...")
+ def load_transformer():
+     """Load transformer with proper path handling"""
+     if os.path.exists(TRANSFORMER_PATH):
+         # Local path
+         if os.path.isdir(TRANSFORMER_PATH):
+             config_path = os.path.join(TRANSFORMER_PATH, "config.json")
+             if os.path.exists(config_path):
+                 return QwenImageTransformer2DModel.from_pretrained(
+                     TRANSFORMER_PATH,
+                     torch_dtype=dtype,
+                     use_safetensors=False
+                 ).eval()
+             else:
+                 return QwenImageTransformer2DModel.from_pretrained(
+                     TRANSFORMER_PATH,
+                     subfolder='transformer',
+                     torch_dtype=dtype,
+                     use_safetensors=False
+                 ).eval()
+         raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
+     else:
+         # HuggingFace repo path
+         path_parts = TRANSFORMER_PATH.split('/')
+         if len(path_parts) >= 3:
+             repo_id = '/'.join(path_parts[:2])
+             subfolder = '/'.join(path_parts[2:])
+             return QwenImageTransformer2DModel.from_pretrained(
+                 repo_id,
+                 subfolder=subfolder,
+                 torch_dtype=dtype,
+                 use_safetensors=False
+             ).eval()
+         else:
+             return QwenImageTransformer2DModel.from_pretrained(
+                 TRANSFORMER_PATH,
+                 subfolder='transformer',
+                 torch_dtype=dtype,
+                 use_safetensors=False
+             ).eval()
+
+ transformer = load_transformer()
+
+ # Load VAE
+ print(" Loading VAE...")
+ vae = AutoencoderKLQwenImage.from_pretrained(
+     MODEL_NAME,
+     subfolder='vae',
+     torch_dtype=dtype,
+ ).eval()
+
+ # Create pipeline
+ print(" Creating pipeline...")
+ pipe = QwenImageEditPipeline(
+     scheduler=scheduler,
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     processor=processor,
+     transformer=transformer
+ )
+
+ # Move to CUDA (ZeroGPU intercepts this call)
+ pipe.to('cuda')
+
+ print("✅ Models loaded successfully!")
 
- @GPU(duration=180)
+
+ # ============================================================
+ # GPU inference function (contains only the actual inference logic)
+ # ============================================================
+
+ @spaces.GPU(duration=120)
  def generate_image(
      images: list[Image.Image],
      prompt: str,
@@ -53,81 +157,16 @@ def generate_image(
      num_steps: int
  ) -> Image.Image:
      """
-     GPU on-demand inference function.
-     Model is loaded fresh each call to work with ZeroGPU.
+     GPU inference function - contains only the actual inference logic.
+     The models are already loaded globally; only inference happens here.
      """
-     # Import dependencies inside GPU function for ZeroGPU compatibility
-     try:
-         from pipeline_qwenimage_edit import QwenImageEditPipeline
-     except ImportError:
-         from diffusers import QwenImageEditPipeline
-
-     from diffusers import (
-         FlowMatchEulerDiscreteScheduler,
-         QwenImageTransformer2DModel,
-         AutoencoderKLQwenImage
-     )
-     from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
-
-     # ZeroGPU: the device must be acquired inside the @GPU function
-     device = torch.device("cuda:0")  # explicitly select cuda:0
-     dtype = torch.bfloat16
-
-     print(f"🚀 Loading model on {device}...")
-     print(f"   CUDA available: {torch.cuda.is_available()}")
-     print(f"   CUDA device count: {torch.cuda.device_count()}")
-
-     # Load scheduler (CPU, no device needed)
-     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-         MODEL_NAME, subfolder='scheduler'
-     )
-
-     # Load tokenizer & processor (CPU, no device needed)
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
-     processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
-
-     # Load text encoder - load directly onto CUDA
-     print("   Loading text_encoder...")
-     text_encoder = AutoModel.from_pretrained(
-         MODEL_NAME,
-         subfolder='text_encoder',
-         torch_dtype=dtype,
-     ).to(device).eval()
-
-     # Load transformer
-     print("   Loading transformer...")
-     transformer = load_transformer(device, dtype)
-
-     # Load VAE
-     print("   Loading VAE...")
-     vae = AutoencoderKLQwenImage.from_pretrained(
-         MODEL_NAME,
-         subfolder='vae',
-         torch_dtype=dtype,
-     ).to(device).eval()
-
-     # Create pipeline
-     pipe = QwenImageEditPipeline(
-         scheduler=scheduler,
-         vae=vae,
-         text_encoder=text_encoder,
-         tokenizer=tokenizer,
-         processor=processor,
-         transformer=transformer
-     )
-
-     # Note: there is no need to set _execution_device manually;
-     # the patched pipeline_qwenimage_edit.py takes the device directly from the text_encoder
-
-     print(f"✅ Model loaded!")
-     print(f"   text_encoder device: {next(text_encoder.parameters()).device}")
-     print(f"   transformer device: {next(transformer.parameters()).device}")
-     print(f"   vae device: {next(vae.parameters()).device}")
-     print(f"   Generating with {len(images)} image(s)...")
+     print(f"🎨 Generating with {len(images)} image(s)...")
+     print(f"   Prompt: {prompt[:50]}...")
+     print(f"   Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
 
      # Generate
      with torch.no_grad():
-         generator = torch.Generator(device=device).manual_seed(int(seed))
+         generator = torch.Generator(device='cuda').manual_seed(int(seed))
 
          if len(images) == 1:
              result = pipe(
@@ -152,60 +191,12 @@ def generate_image(
              generator=generator
          ).images[0]
 
-     # Cleanup to free VRAM
-     del pipe, transformer, vae, text_encoder
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
+     print("✅ Generation complete!")
      return result
 
 
- def load_transformer(device, dtype):
-     """Load transformer with proper path handling for ZeroGPU"""
-     from diffusers import QwenImageTransformer2DModel
-
-     if os.path.exists(TRANSFORMER_PATH):
-         # Local path
-         if os.path.isdir(TRANSFORMER_PATH):
-             config_path = os.path.join(TRANSFORMER_PATH, "config.json")
-             if os.path.exists(config_path):
-                 return QwenImageTransformer2DModel.from_pretrained(
-                     TRANSFORMER_PATH,
-                     torch_dtype=dtype,
-                     use_safetensors=False  # use the .bin weights
-                 ).to(device).eval()
-             else:
-                 return QwenImageTransformer2DModel.from_pretrained(
-                     TRANSFORMER_PATH,
-                     subfolder='transformer',
-                     torch_dtype=dtype,
-                     use_safetensors=False
-                 ).to(device).eval()
-         raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
-     else:
-         # HuggingFace repo path
-         path_parts = TRANSFORMER_PATH.split('/')
-         if len(path_parts) >= 3:
-             # Path format: "Skywork/Unipic3-DMD/ema_transformer"
-             repo_id = '/'.join(path_parts[:2])  # "Skywork/Unipic3-DMD"
-             subfolder = '/'.join(path_parts[2:])  # "ema_transformer"
-             return QwenImageTransformer2DModel.from_pretrained(
-                 repo_id,
-                 subfolder=subfolder,
-                 torch_dtype=dtype,
-                 use_safetensors=False  # use the .bin weights
-             ).to(device).eval()
-         else:
-             return QwenImageTransformer2DModel.from_pretrained(
-                 TRANSFORMER_PATH,
-                 subfolder='transformer',
-                 torch_dtype=dtype,
-                 use_safetensors=False
-             ).to(device).eval()
-
-
  # ============================================================
- # UI Logic (CPU-only, always available)
+ # UI logic (CPU, always available)
  # ============================================================
 
  def process_images(
@@ -215,12 +206,12 @@ def process_images(
      seed: int,
      num_steps: int
  ):
-     """Process images - validates input then calls GPU function"""
+     """Process the images - validate the input, then call the GPU function"""
 
-     # Filter valid images
+     # Keep only the uploaded images
      images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
 
-     # Validation
+     # Validate
      if len(images) == 0:
          return None, "❌ Please upload at least one image"
 
@@ -231,10 +222,10 @@
          return None, "❌ Please enter an editing instruction"
 
      try:
-         # Convert to RGB
+         # Convert to RGB
          images = [img.convert("RGB") for img in images]
 
-         # Call GPU function
+         # Call the GPU function
          result = generate_image(
              images=images,
              prompt=prompt,
@@ -252,19 +243,17 @@
 
 
  def update_image_visibility(num):
-     """Update visibility of image upload slots"""
+     """Update the visibility of the image upload slots"""
      return [gr.update(visible=(i < num)) for i in range(6)]
 
 
  # ============================================================
- # Custom CSS for Beautiful UI
+ # Custom CSS
  # ============================================================
 
  CUSTOM_CSS = """
- /* Import distinctive fonts */
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
 
- /* Root variables */
  :root {
      --primary: #6366f1;
      --primary-dark: #4f46e5;
@@ -278,18 +267,15 @@ CUSTOM_CSS = """
      --success: #10b981;
      --error: #ef4444;
      --gradient-1: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     --gradient-2: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
      --gradient-hero: linear-gradient(135deg, #0f0f23 0%, #1a1a3e 50%, #252552 100%);
  }
 
- /* Global styles */
  .gradio-container {
      font-family: 'Outfit', sans-serif !important;
      background: var(--gradient-hero) !important;
      min-height: 100vh;
  }
 
- /* Header styling */
  .main-header {
      text-align: center;
      padding: 2rem 1rem;
@@ -316,7 +302,6 @@
      margin: 0 auto;
  }
 
- /* Feature badges */
  .feature-badges {
      display: flex;
      gap: 1rem;
@@ -338,7 +323,6 @@
      font-weight: 500;
  }
 
- /* Section headers */
  .section-header {
      display: flex;
      align-items: center;
@@ -355,22 +339,6 @@
      margin: 0;
  }
 
- /* Card styling */
- .card {
-     background: var(--surface-light) !important;
-     border: 1px solid var(--border) !important;
-     border-radius: 16px !important;
-     padding: 1.5rem !important;
- }
-
- /* Image upload grid */
- .image-grid {
-     display: grid;
-     grid-template-columns: repeat(3, 1fr);
-     gap: 1rem;
- }
-
- /* Button styling */
  .generate-btn {
      background: var(--gradient-1) !important;
      border: none !important;
@@ -389,28 +357,6 @@
      box-shadow: 0 6px 20px rgba(99, 102, 241, 0.5) !important;
  }
 
- /* Input styling */
- .gr-textbox textarea,
- .gr-textbox input {
-     background: var(--surface) !important;
-     border: 1px solid var(--border) !important;
-     border-radius: 12px !important;
-     color: var(--text) !important;
-     font-family: 'Outfit', sans-serif !important;
- }
-
- .gr-textbox textarea:focus,
- .gr-textbox input:focus {
-     border-color: var(--primary) !important;
-     box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2) !important;
- }
-
- /* Slider styling */
- .gr-slider input[type="range"] {
-     accent-color: var(--primary) !important;
- }
-
- /* Output image */
  .output-image {
      border-radius: 16px;
      overflow: hidden;
@@ -419,46 +365,10 @@
          var(--gradient-1) border-box;
  }
 
- /* Status text */
- .status-success {
-     color: var(--success) !important;
-     font-weight: 500;
- }
-
- .status-error {
-     color: var(--error) !important;
-     font-weight: 500;
- }
-
- /* Accordion */
- .gr-accordion {
-     background: var(--surface-light) !important;
-     border: 1px solid var(--border) !important;
-     border-radius: 12px !important;
- }
-
- /* Labels */
- label {
-     color: var(--text) !important;
-     font-weight: 500 !important;
- }
-
- /* Tooltip / info text */
- .gr-info {
-     color: var(--text-muted) !important;
-     font-size: 0.875rem !important;
- }
-
- /* Responsive adjustments */
  @media (max-width: 768px) {
-     .image-grid {
-         grid-template-columns: repeat(2, 1fr);
-     }
-
      .main-header h1 {
          font-size: 1.75rem;
      }
-
      .feature-badges {
          flex-direction: column;
          align-items: center;
@@ -466,8 +376,9 @@ label {
  }
  """
 
+
  # ============================================================
- # Build Gradio Interface
+ # Build the Gradio interface
  # ============================================================
 
  def create_demo():
@@ -499,7 +410,6 @@ def create_demo():
          # Left Column - Inputs
          with gr.Column(scale=1):
 
-             # Image Upload Section
              gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
 
              num_images = gr.Slider(
@@ -531,7 +441,6 @@
                  outputs=image_inputs
              )
 
-             # Prompt Section
              gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
 
              prompt_input = gr.Textbox(
@@ -541,7 +450,6 @@
                  value="Combine the reference images to generate the final result."
              )
 
-             # Advanced Settings
              with gr.Accordion("⚙️ Advanced Settings", open=False):
                  cfg_scale = gr.Slider(
                      minimum=1.0,
@@ -568,7 +476,6 @@
                      info="8 recommended for DMD"
                  )
 
-             # Generate Button
              generate_btn = gr.Button(
                  "🚀 Generate Image",
                  variant="primary",
@@ -594,7 +501,6 @@
                  show_copy_button=False
              )
 
-             # Tips
              gr.HTML("""
              <div style="
                  margin-top: 1.5rem;
@@ -607,7 +513,7 @@
              <ul style="color: #ffffff; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
                  <li>Reference images as "Image1", "Image2", etc. in your prompt</li>
                  <li>Use descriptive prompts for better composition</li>
-                 <li>First run may take longer due to model loading</li>
+                 <li>First run may take longer due to model warm-up</li>
              </ul>
              </div>
              """)
@@ -637,7 +543,7 @@
 
 
  # ============================================================
- # Launch
+ # Launch
  # ============================================================
 
  demo = create_demo()
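The module docstring in the updated app.py describes the ZeroGPU pattern this commit moves to: build every heavy component at import time and reserve the GPU only for the decorated inference call. A minimal, self-contained sketch of that pattern is shown below; the toy Linear model, the duration value, and the infer name are illustrative placeholders, not part of this commit.

import torch

try:
    import spaces  # available on Hugging Face Spaces (ZeroGPU)
except ImportError:
    # Local fallback so the module still imports outside Spaces
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator

# Module scope: on ZeroGPU the .to('cuda') call is intercepted, so declaring the
# model here does not hold a GPU permanently. (Assumes CUDA when run locally.)
model = torch.nn.Linear(8, 8).to('cuda')  # stand-in for the real pipeline

@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def infer(x: torch.Tensor) -> torch.Tensor:
    """Run one forward pass on the GPU and return the result on the CPU."""
    with torch.no_grad():
        return model(x.to('cuda')).cpu()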
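For reference, the new load_transformer helper treats TRANSFORMER_PATH either as a local directory or as a Hub repo id with a trailing subfolder. A small sketch of that split, using the default value from this commit (the print call is only for illustration):

import os

# Default from app.py; override via the TRANSFORMER_PATH environment variable.
path = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")

if not os.path.exists(path):
    parts = path.split("/")
    if len(parts) >= 3:
        repo_id = "/".join(parts[:2])    # "Skywork/Unipic3-DMD"
        subfolder = "/".join(parts[2:])  # "ema_transformer"
        print(f"repo_id={repo_id}, subfolder={subfolder}")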