OrlandoHugBot committed on
Commit
2a111d5
·
verified ·
1 Parent(s): fd25dcd

Update app.py

Files changed (1)
  1. app.py +590 -219
app.py CHANGED
@@ -1,258 +1,629 @@
  """
- UniPic-3 DMD – ZeroGPU Final Architecture
- UI Always-On (CPU) + GPU On-Demand Inference
- with GPU Queue Status Indicator
  """

- import os
- import sys
- import time
- import threading
- import torch
  import gradio as gr
  from PIL import Image
- from huggingface_hub import snapshot_download
- from spaces import GPU
-
- # =============================================================================
- # Paths & Globals
- # =============================================================================
-
- MODEL_ID = "Skywork/Unipic3-DMD"
- CACHE_ROOT = "./hf_cache"
- LOCAL_MODEL_DIR = os.path.join(CACHE_ROOT, MODEL_ID)
-
- pipe = None
- model_lock = threading.Lock()
-
- # GPU state (for UI display)
- GPU_STATE = {
-     "status": "idle",  # idle | waiting | loading | ready | running | error
-     "message": "UI ready. GPU not requested yet."
- }
-
- # =============================================================================
- # CPU Stage: Pre-cache weights (NO GPU)
- # =============================================================================
-
- def precache_weights():
-     if os.path.exists(LOCAL_MODEL_DIR):
-         print("✅ Weights already cached")
-         return
-
-     print("📦 Pre-caching UniPic-3 DMD weights (CPU stage)...")
-
-     snapshot_download(
-         repo_id=MODEL_ID,
-         local_dir=LOCAL_MODEL_DIR,
-         local_dir_use_symlinks=False,
-         resume_download=True,
-         allow_patterns=[
-             "scheduler/*",
-             "text_encoder/*",
-             "tokenizer/*",
-             "processor/*",
-             "vae/*",
-             "ema_transformer/*",
-         ],
-     )
-
-     print("✅ Pre-cache complete")
-
- # =============================================================================
- # Local imports AFTER cache
- # =============================================================================
-
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
- from diffusers import (
-     FlowMatchEulerDiscreteScheduler,
-     QwenImageTransformer2DModel,
-     AutoencoderKLQwenImage,
- )
- from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor

  try:
-     from pipeline_qwenimage_edit import QwenImageEditPipeline
  except ImportError:
-     from diffusers import QwenImageEditPipeline

- # =============================================================================
- # GPU Stage: Model loader (NO network)
- # =============================================================================
-
- def load_model_on_gpu():
-     global pipe
-
-     with model_lock:
-         if pipe is not None:
-             return
-
-         GPU_STATE["status"] = "loading"
-         GPU_STATE["message"] = "Loading model onto GPU..."
-
-         device = torch.device("cuda")
-         dtype = torch.bfloat16
-
-         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-             LOCAL_MODEL_DIR, subfolder="scheduler"
-         )
-
-         text_encoder = AutoModel.from_pretrained(
-             LOCAL_MODEL_DIR,
-             subfolder="text_encoder",
-             torch_dtype=dtype,
-         ).to(device)
-
-         tokenizer = AutoTokenizer.from_pretrained(
-             LOCAL_MODEL_DIR, subfolder="tokenizer"
-         )

-         processor = Qwen2VLProcessor.from_pretrained(
-             LOCAL_MODEL_DIR, subfolder="processor"
          )

-         transformer = QwenImageTransformer2DModel.from_pretrained(
-             LOCAL_MODEL_DIR,
-             subfolder="ema_transformer",
-             torch_dtype=dtype,
-         ).to(device)
-
-         vae = AutoencoderKLQwenImage.from_pretrained(
-             LOCAL_MODEL_DIR,
-             subfolder="vae",
-             torch_dtype=dtype,
-         ).to(device)
-
-         pipe = QwenImageEditPipeline(
-             scheduler=scheduler,
-             vae=vae,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-             processor=processor,
-             transformer=transformer,
-         )
-
-         pipe.to(device)
-
-         GPU_STATE["status"] = "ready"
-         GPU_STATE["message"] = "GPU ready. Model loaded."
-
- # =============================================================================
- # GPU On-Demand Inference (THIS is the only @GPU function)
- # =============================================================================
-
- @GPU
- def run_inference(
-     img1, img2, img3, img4, img5, img6,
-     prompt, cfg, seed, steps
- ):
-     global pipe
-
-     try:
-         GPU_STATE["status"] = "waiting"
-         GPU_STATE["message"] = "Waiting for GPU..."
-
-         # ZeroGPU will block here until GPU is assigned
-         if not torch.cuda.is_available():
-             return None, "⏳ Waiting for GPU, please retry."
-
-         load_model_on_gpu()
-
-         GPU_STATE["status"] = "running"
-         GPU_STATE["message"] = "Running inference..."
-
-         images = [i for i in [img1, img2, img3, img4, img5, img6] if i is not None]
-         if not images:
-             return None, "❌ Please upload at least one image."
-
-         images = [img.convert("RGB") for img in images]
-
-         generator = torch.Generator(device="cuda").manual_seed(int(seed))
-
-         with torch.no_grad():
-             if len(images) == 1:
-                 out = pipe(
-                     images[0],
-                     prompt=prompt,
-                     height=768,
-                     width=768,
-                     num_inference_steps=steps,
-                     true_cfg_scale=cfg,
-                     generator=generator,
-                 ).images[0]
-             else:
-                 out = pipe(
-                     images=images,
-                     prompt=prompt,
-                     height=768,
-                     width=768,
-                     num_inference_steps=steps,
-                     true_cfg_scale=cfg,
-                     generator=generator,
-                 ).images[0]
-
-         GPU_STATE["status"] = "ready"
-         GPU_STATE["message"] = "Inference complete."
-
-         return out, "✅ Done"
-
-     except Exception as e:
-         GPU_STATE["status"] = "error"
-         GPU_STATE["message"] = str(e)
-         return None, f"❌ Error: {e}"
-
- # =============================================================================
- # UI Helpers (CPU)
- # =============================================================================
-
- def get_gpu_status():
-     return f"**GPU Status:** `{GPU_STATE['status']}`\n\n{GPU_STATE['message']}"
-
- # =============================================================================
- # UI (ALWAYS CPU, ALWAYS ON)
- # =============================================================================
-
- with gr.Blocks(title="UniPic-3 DMD (ZeroGPU)") as demo:
-     gr.Markdown("# 🔥 UniPic-3 DMD – ZeroGPU Demo")
-
-     status_box = gr.Markdown(get_gpu_status())
-
-     with gr.Row():
-         with gr.Column():
-             imgs = [gr.Image(type="pil", label=f"Image {i+1}") for i in range(6)]
-             prompt = gr.Textbox(label="Prompt", value="Combine the reference images.")
-             cfg = gr.Slider(1, 8, value=4, label="CFG")
-             seed = gr.Number(42, precision=0, label="Seed")
-             steps = gr.Slider(1, 6, value=6, label="Steps")
-
-             btn = gr.Button("🚀 Generate")
-
-         with gr.Column():
-             out = gr.Image(label="Output")
-             msg = gr.Textbox(label="Result")
-
-     btn.click(
-         run_inference,
-         inputs=[*imgs, prompt, cfg, seed, steps],
-         outputs=[out, msg],
-     )
-
-     # Periodic status refresh (CPU only)
-     demo.load(
-         fn=get_gpu_status,
-         inputs=[],
-         outputs=status_box,
-         every=1.0,
-     )
-
- # =============================================================================
- # Entry
- # =============================================================================
-
  if __name__ == "__main__":
-     # CPU phase: cache weights
-     precache_weights()
-
-     # UI always-on
-     demo.launch(ssr_mode=False)

  """
+ UniPic-3 DMD Multi-Image Composition
+ Hugging Face Space - persistent UI + on-demand GPU architecture
+
+ Core optimizations:
+ 1. Persistent UI - the page is always available, with no waiting for model loading
+ 2. On-demand GPU - the GPU is requested only at inference time, saving resources
+ 3. Polished frontend - a modern, clean UI design
  """

  import gradio as gr
+ import torch
  from PIL import Image
+ import os
+ import sys

+ # Hugging Face Spaces GPU decorator
  try:
+     from spaces import GPU
+     HF_SPACES = True
  except ImportError:
+     HF_SPACES = False
+     def GPU(duration=60):
+         def decorator(func):
+             return func
+         return decorator
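+
+ # Editor's note: the fallback above makes `@GPU(duration=120)` a no-op outside
+ # Spaces - GPU(...) returns `decorator`, which hands the function back unchanged,
+ # so the same code runs locally on whatever device is available.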
+
+ # Local pipeline import
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ # Model configuration
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
+ TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")
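+
+ # Editor's note: TRANSFORMER_PATH may be a local directory or an
+ # "<org>/<repo>/<subfolder>" spec; load_transformer() below splits the default
+ # into repo_id="Skywork/Unipic3-DMD" and subfolder="ema_transformer".
+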
+ # ============================================================
+ # GPU On-Demand: model loading happens inside the @GPU-decorated function
+ # ============================================================
+
+ def get_device():
+     """Get the appropriate device"""
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def get_dtype():
+     """Get the appropriate dtype"""
+     return torch.bfloat16 if torch.cuda.is_available() else torch.float32
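+
+ # On a ZeroGPU worker these helpers resolve to (cuda, torch.bfloat16);
+ # on a CPU-only local run they fall back to (cpu, torch.float32).
+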
+ @GPU(duration=120)
+ def generate_image(
+     images: list[Image.Image],
+     prompt: str,
+     true_cfg_scale: float,
+     seed: int,
+     num_steps: int
+ ) -> Image.Image:
+     """
+     GPU on-demand inference function.
+     The model is loaded fresh on each call to work with ZeroGPU.
+     """
+     # Import dependencies inside the GPU function for ZeroGPU compatibility
+     try:
+         from pipeline_qwenimage_edit import QwenImageEditPipeline
+     except ImportError:
+         from diffusers import QwenImageEditPipeline
+
+     from diffusers import (
+         FlowMatchEulerDiscreteScheduler,
+         QwenImageTransformer2DModel,
+         AutoencoderKLQwenImage
+     )
+     from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
+
+     device = get_device()
+     dtype = get_dtype()
+
+     print(f"🚀 Loading model on {device}...")
+
+     # Load scheduler
+     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         MODEL_NAME, subfolder='scheduler'
+     )
+
+     # Load text encoder
+     text_encoder = AutoModel.from_pretrained(
+         MODEL_NAME,
+         subfolder='text_encoder',
+         torch_dtype=dtype
+     ).to(device)
+
+     # Load tokenizer & processor
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder='tokenizer')
+     processor = Qwen2VLProcessor.from_pretrained(MODEL_NAME, subfolder='processor')
+
+     # Load transformer
+     transformer = load_transformer(device, dtype)
+
+     # Load VAE
+     vae = AutoencoderKLQwenImage.from_pretrained(
+         MODEL_NAME,
+         subfolder='vae',
+         torch_dtype=dtype
+     ).to(device)
+
+     # Create pipeline
+     pipe = QwenImageEditPipeline(
+         scheduler=scheduler,
+         vae=vae,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         processor=processor,
+         transformer=transformer
+     )
+
+     print(f"✅ Model loaded! Generating with {len(images)} image(s)...")
+
+     # Generate
+     with torch.no_grad():
+         generator = torch.Generator(device=device).manual_seed(int(seed))
+
+         if len(images) == 1:
+             result = pipe(
+                 images[0],
+                 prompt=prompt,
+                 height=1024,
+                 width=1024,
+                 negative_prompt=' ',
+                 num_inference_steps=num_steps,
+                 true_cfg_scale=true_cfg_scale,
+                 generator=generator
+             ).images[0]
+         else:
+             result = pipe(
+                 images=images,
+                 prompt=prompt,
+                 height=1024,
+                 width=1024,
+                 negative_prompt=' ',
+                 num_inference_steps=num_steps,
+                 true_cfg_scale=true_cfg_scale,
+                 generator=generator
+             ).images[0]
+
+     # Cleanup to free VRAM
+     del pipe, transformer, vae, text_encoder
+     torch.cuda.empty_cache()
+
+     return result
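+
+ # Hypothetical local smoke test (editor's sketch, not part of the Space UI);
+ # assumes an example.png sits next to app.py:
+ #   out = generate_image([Image.open("example.png")], "make it a watercolor", 4.0, 42, 8)
+ #   out.save("result.png")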
+
+
+ def load_transformer(device, dtype):
+     """Load the transformer with proper path handling"""
+     from diffusers import QwenImageTransformer2DModel
+
+     if os.path.exists(TRANSFORMER_PATH):
+         # Local path
+         if os.path.isdir(TRANSFORMER_PATH):
+             config_path = os.path.join(TRANSFORMER_PATH, "config.json")
+             if os.path.exists(config_path):
+                 return QwenImageTransformer2DModel.from_pretrained(
+                     TRANSFORMER_PATH,
+                     torch_dtype=dtype,
+                     use_safetensors=False
+                 ).to(device)
+             else:
+                 return QwenImageTransformer2DModel.from_pretrained(
+                     TRANSFORMER_PATH,
+                     subfolder='transformer',
+                     torch_dtype=dtype,
+                     use_safetensors=False
+                 ).to(device)
+         raise ValueError(f"Invalid transformer path: {TRANSFORMER_PATH}")
+     else:
+         # Hugging Face repo path, e.g. "org/repo/subfolder"
+         path_parts = TRANSFORMER_PATH.split('/')
+         if len(path_parts) >= 3:
+             repo_id = '/'.join(path_parts[:2])
+             subfolder = '/'.join(path_parts[2:])
+             return QwenImageTransformer2DModel.from_pretrained(
+                 repo_id,
+                 subfolder=subfolder,
+                 torch_dtype=dtype
+             ).to(device)
+         else:
+             return QwenImageTransformer2DModel.from_pretrained(
+                 TRANSFORMER_PATH,
+                 subfolder='transformer',
+                 torch_dtype=dtype
+             ).to(device)
+
+ # ============================================================
+ # UI Logic (CPU-only, always available)
+ # ============================================================
+
+ def process_images(
+     img1, img2, img3, img4, img5, img6,
+     prompt: str,
+     cfg_scale: float,
+     seed: int,
+     num_steps: int
+ ):
+     """Validate inputs, then call the GPU function"""
+
+     # Filter valid images
+     images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
+
+     # Validation
+     if len(images) == 0:
+         return None, "❌ Please upload at least one image"
+
+     if len(images) > 6:
+         return None, f"❌ Maximum 6 images allowed (got {len(images)})"
+
+     if not prompt or prompt.strip() == "":
+         return None, "❌ Please enter an editing instruction"
+
+     try:
+         # Convert to RGB
+         images = [img.convert("RGB") for img in images]
+
+         # Call GPU function
+         result = generate_image(
+             images=images,
+             prompt=prompt,
+             true_cfg_scale=cfg_scale,
+             seed=seed,
+             num_steps=num_steps
          )
+
+         return result, f"✅ Generated from {len(images)} image(s) in {num_steps} steps"
+
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return None, f"❌ Error: {str(e)}"
+
+
+ def update_image_visibility(num):
+     """Update visibility of the image upload slots"""
+     return [gr.update(visible=(i < num)) for i in range(6)]
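+
+ # e.g. update_image_visibility(2) yields
+ # [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), ...]
+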
+ # ============================================================
+ # Custom CSS for Beautiful UI
+ # ============================================================
+
+ CUSTOM_CSS = """
+ /* Import distinctive fonts */
+ @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
+
+ /* Root variables */
+ :root {
+     --primary: #6366f1;
+     --primary-dark: #4f46e5;
+     --accent: #f472b6;
+     --surface: #0f0f23;
+     --surface-light: #1a1a3e;
+     --surface-elevated: #252552;
+     --text: #e2e8f0;
+     --text-muted: #94a3b8;
+     --border: #334155;
+     --success: #10b981;
+     --error: #ef4444;
+     --gradient-1: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     --gradient-2: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+     --gradient-hero: linear-gradient(135deg, #0f0f23 0%, #1a1a3e 50%, #252552 100%);
+ }
+
+ /* Global styles */
+ .gradio-container {
+     font-family: 'Outfit', sans-serif !important;
+     background: var(--gradient-hero) !important;
+     min-height: 100vh;
+ }
+
+ /* Header styling */
+ .main-header {
+     text-align: center;
+     padding: 2rem 1rem;
+     background: linear-gradient(180deg, rgba(99, 102, 241, 0.1) 0%, transparent 100%);
+     border-radius: 24px;
+     margin-bottom: 2rem;
+     border: 1px solid rgba(99, 102, 241, 0.2);
+ }
+
+ .main-header h1 {
+     font-size: 2.5rem;
+     font-weight: 700;
+     background: linear-gradient(135deg, #fff 0%, #a5b4fc 50%, #f472b6 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     background-clip: text;
+     margin-bottom: 0.5rem;
+ }
+
+ .main-header p {
+     color: var(--text-muted);
+     font-size: 1.1rem;
+     max-width: 600px;
+     margin: 0 auto;
+ }
+
+ /* Feature badges */
+ .feature-badges {
+     display: flex;
+     gap: 1rem;
+     justify-content: center;
+     flex-wrap: wrap;
+     margin-top: 1.5rem;
+ }
+
+ .badge {
+     display: inline-flex;
+     align-items: center;
+     gap: 0.5rem;
+     padding: 0.5rem 1rem;
+     background: rgba(99, 102, 241, 0.15);
+     border: 1px solid rgba(99, 102, 241, 0.3);
+     border-radius: 9999px;
+     color: #a5b4fc;
+     font-size: 0.875rem;
+     font-weight: 500;
+ }
+
+ /* Section headers */
+ .section-header {
+     display: flex;
+     align-items: center;
+     gap: 0.75rem;
+     margin-bottom: 1rem;
+     padding-bottom: 0.75rem;
+     border-bottom: 1px solid var(--border);
+ }
+
+ .section-header h3 {
+     font-size: 1.125rem;
+     font-weight: 600;
+     color: var(--text);
+     margin: 0;
+ }
+
+ /* Card styling */
+ .card {
+     background: var(--surface-light) !important;
+     border: 1px solid var(--border) !important;
+     border-radius: 16px !important;
+     padding: 1.5rem !important;
+ }
+
+ /* Image upload grid */
+ .image-grid {
+     display: grid;
+     grid-template-columns: repeat(3, 1fr);
+     gap: 1rem;
+ }
+
+ /* Button styling */
+ .generate-btn {
+     background: var(--gradient-1) !important;
+     border: none !important;
+     border-radius: 12px !important;
+     padding: 1rem 2rem !important;
+     font-size: 1.1rem !important;
+     font-weight: 600 !important;
+     color: white !important;
+     cursor: pointer !important;
+     transition: all 0.3s ease !important;
+     box-shadow: 0 4px 15px rgba(99, 102, 241, 0.4) !important;
+ }
+
+ .generate-btn:hover {
+     transform: translateY(-2px) !important;
+     box-shadow: 0 6px 20px rgba(99, 102, 241, 0.5) !important;
+ }
+
+ /* Input styling */
+ .gr-textbox textarea,
+ .gr-textbox input {
+     background: var(--surface) !important;
+     border: 1px solid var(--border) !important;
+     border-radius: 12px !important;
+     color: var(--text) !important;
+     font-family: 'Outfit', sans-serif !important;
+ }
+
+ .gr-textbox textarea:focus,
+ .gr-textbox input:focus {
+     border-color: var(--primary) !important;
+     box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2) !important;
+ }
+
+ /* Slider styling */
+ .gr-slider input[type="range"] {
+     accent-color: var(--primary) !important;
+ }
+
+ /* Output image */
+ .output-image {
+     border-radius: 16px;
+     overflow: hidden;
+     border: 2px solid transparent;
+     background: linear-gradient(var(--surface-light), var(--surface-light)) padding-box,
+                 var(--gradient-1) border-box;
+ }
+
+ /* Status text */
+ .status-success {
+     color: var(--success) !important;
+     font-weight: 500;
+ }
+
+ .status-error {
+     color: var(--error) !important;
+     font-weight: 500;
+ }
+
+ /* Accordion */
+ .gr-accordion {
+     background: var(--surface-light) !important;
+     border: 1px solid var(--border) !important;
+     border-radius: 12px !important;
+ }
+
+ /* Labels */
+ label {
+     color: var(--text) !important;
+     font-weight: 500 !important;
+ }
+
+ /* Tooltip / info text */
+ .gr-info {
+     color: var(--text-muted) !important;
+     font-size: 0.875rem !important;
+ }
+
+ /* Responsive adjustments */
+ @media (max-width: 768px) {
+     .image-grid {
+         grid-template-columns: repeat(2, 1fr);
+     }
+
+     .main-header h1 {
+         font-size: 1.75rem;
+     }
+
+     .feature-badges {
+         flex-direction: column;
+         align-items: center;
+     }
+ }
+ """

+ # ============================================================
+ # Build Gradio Interface
+ # ============================================================
+
+ def create_demo():
+     with gr.Blocks(
+         title="UniPic-3 DMD",
+         theme=gr.themes.Base(
+             primary_hue="indigo",
+             secondary_hue="pink",
+             neutral_hue="slate",
+             font=("Outfit", "sans-serif"),
+         ),
+         css=CUSTOM_CSS
+     ) as demo:
+
+         # Header
+         gr.HTML("""
+         <div class="main-header">
+             <h1>🎨 UniPic-3 DMD</h1>
+             <p>Multi-Image Composition with Distribution-Matching Distillation</p>
+             <div class="feature-badges">
+                 <span class="badge">⚡ 8-Step Fast Inference</span>
+                 <span class="badge">🖼️ Up to 6 Images</span>
+                 <span class="badge">🚀 12.5× Speedup</span>
+             </div>
+         </div>
+         """)
+
+         with gr.Row(equal_height=True):
+             # Left Column - Inputs
+             with gr.Column(scale=1):
+
+                 # Image Upload Section
+                 gr.HTML('<div class="section-header"><span>📸</span><h3>Upload Images</h3></div>')
+
+                 num_images = gr.Slider(
+                     minimum=1,
+                     maximum=6,
+                     value=2,
+                     step=1,
+                     label="Number of Images",
+                     info="Select how many images to compose"
+                 )
+
+                 with gr.Row():
+                     img1 = gr.Image(type="pil", label="Image 1", visible=True)
+                     img2 = gr.Image(type="pil", label="Image 2", visible=True)
+
+                 with gr.Row():
+                     img3 = gr.Image(type="pil", label="Image 3", visible=False)
+                     img4 = gr.Image(type="pil", label="Image 4", visible=False)
+
+                 with gr.Row():
+                     img5 = gr.Image(type="pil", label="Image 5", visible=False)
+                     img6 = gr.Image(type="pil", label="Image 6", visible=False)
+
+                 image_inputs = [img1, img2, img3, img4, img5, img6]
+
+                 num_images.change(
+                     fn=update_image_visibility,
+                     inputs=num_images,
+                     outputs=image_inputs
+                 )
+
+                 # Prompt Section
+                 gr.HTML('<div class="section-header"><span>✍️</span><h3>Editing Instruction</h3></div>')
+
+                 prompt_input = gr.Textbox(
+                     label="Prompt",
+                     placeholder="e.g., A man from Image1 standing on a surfboard from Image2, riding ocean waves under a bright blue sky.",
+                     lines=3,
+                     value="Combine the reference images to generate the final result."
+                 )
+
+                 # Advanced Settings
+                 with gr.Accordion("⚙️ Advanced Settings", open=False):
+                     cfg_scale = gr.Slider(
+                         minimum=1.0,
+                         maximum=10.0,
+                         value=4.0,
+                         step=0.5,
+                         label="CFG Scale",
+                         info="Higher = more prompt alignment"
+                     )
+
+                     with gr.Row():
+                         seed = gr.Number(
+                             value=42,
+                             label="Seed",
+                             info="For reproducibility",
+                             precision=0
+                         )
+                         num_steps = gr.Slider(
+                             minimum=1,
+                             maximum=8,
+                             value=8,
+                             step=1,
+                             label="Steps",
+                             info="8 recommended for DMD"
+                         )
+
+                 # Generate Button
+                 generate_btn = gr.Button(
+                     "🚀 Generate Image",
+                     variant="primary",
+                     size="lg",
+                     elem_classes=["generate-btn"]
+                 )
+
+             # Right Column - Output
+             with gr.Column(scale=1):
+                 gr.HTML('<div class="section-header"><span>🎨</span><h3>Generated Result</h3></div>')
+
+                 output_image = gr.Image(
+                     type="pil",
+                     label="Output",
+                     elem_classes=["output-image"],
+                     show_download_button=True
+                 )
+
+                 status_text = gr.Textbox(
+                     label="Status",
+                     value="✨ Ready! Upload images and click Generate.",
+                     interactive=False,
+                     show_copy_button=False
+                 )
+
+                 # Tips
+                 gr.HTML("""
+                 <div style="
+                     margin-top: 1.5rem;
+                     padding: 1rem;
+                     background: rgba(99, 102, 241, 0.1);
+                     border-radius: 12px;
+                     border: 1px solid rgba(99, 102, 241, 0.2);
+                 ">
+                     <p style="color: #a5b4fc; font-weight: 600; margin-bottom: 0.5rem;">💡 Tips</p>
+                     <ul style="color: #94a3b8; font-size: 0.9rem; margin: 0; padding-left: 1.25rem;">
+                         <li>Refer to images as "Image1", "Image2", etc. in your prompt</li>
+                         <li>Use descriptive prompts for better composition</li>
+                         <li>The first run may take longer due to model loading</li>
+                     </ul>
+                 </div>
+                 """)
+
+         # Connect the generate button
+         generate_btn.click(
+             fn=process_images,
+             inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
+             outputs=[output_image, status_text]
+         )
+
+         # Examples
+         gr.HTML('<div class="section-header" style="margin-top: 2rem;"><span>📚</span><h3>Example Prompts</h3></div>')
+
+         gr.Examples(
+             examples=[
+                 ["A person from Image1 wearing the outfit from Image2"],
+                 ["Combine Image1 and Image2 into a single cohesive scene"],
+                 ["The object from Image1 placed in the environment from Image2"],
+                 ["Create a portrait using the face from Image1 and hairstyle from Image2"],
+             ],
+             inputs=[prompt_input],
+             label=""
+         )
+
+     return demo

+ # ============================================================
+ # Launch
+ # ============================================================
+
+ demo = create_demo()

  if __name__ == "__main__":
+     demo.launch()