OrlandoHugBot committed on
Commit
3638b57
·
verified ·
1 Parent(s): 095827d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -257
app.py CHANGED
@@ -1,338 +1,247 @@
1
  """
2
- Gradio Demo for UniPic-3 DMD Multi-Image Composition
3
- Hugging Face Space compatible version
4
-
5
- Upload up to 6 images and generate a composed result using DMD model with 8-step inference.
6
  """
7
 
8
- import gradio as gr
 
9
  import torch
 
10
  from PIL import Image
11
- import os
12
  from spaces import GPU
13
 
14
- # Use local pipeline to ensure compatibility
15
- import sys
 
16
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
17
  try:
18
  from pipeline_qwenimage_edit import QwenImageEditPipeline
19
  except ImportError:
20
- # Fallback to diffusers if local not available
21
- try:
22
- from diffusers import QwenImageEditPipeline
23
- except ImportError:
24
- raise ImportError(
25
- "QwenImageEditPipeline not found. Please ensure pipeline_qwenimage_edit.py "
26
- "is in the same directory or diffusers is installed."
27
- )
28
 
29
- from diffusers import FlowMatchEulerDiscreteScheduler, QwenImageTransformer2DModel, AutoencoderKLQwenImage
 
 
 
 
30
  from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
31
 
32
-
33
- # Global pipeline
 
34
  pipe = None
35
 
36
- # Model paths (can be set via environment variables)
37
-
38
  MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
39
- default_transformer = "Skywork/Unipic3-DMD"
40
- TRANSFORMER_PATH = os.environ.get("TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer")
 
41
 
 
 
 
42
  def load_model():
43
- """Load the DMD model and pipeline"""
44
  global pipe
45
-
46
  if pipe is not None:
47
- return pipe
48
-
49
- print(f"Loading model from {TRANSFORMER_PATH}...")
50
-
51
- # Load scheduler
 
 
 
 
 
 
 
 
 
 
52
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
53
- pretrained_model_name_or_path=MODEL_NAME, subfolder='scheduler'
54
  )
55
-
56
- # Load text encoder
57
  text_encoder = AutoModel.from_pretrained(
58
- pretrained_model_name_or_path=MODEL_NAME, subfolder='text_encoder',
59
- device_map='auto', torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
60
- )
61
-
62
- # Load tokenizer and processor
 
63
  tokenizer = AutoTokenizer.from_pretrained(
64
- pretrained_model_name_or_path=MODEL_NAME, subfolder='tokenizer',
65
  )
66
  processor = Qwen2VLProcessor.from_pretrained(
67
- pretrained_model_name_or_path=MODEL_NAME, subfolder='processor',
68
  )
69
-
70
- # Load transformer (DMD model)
71
- # Handle both local paths and HuggingFace repo paths
72
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
- print("device:", device)
74
-
75
  if os.path.exists(TRANSFORMER_PATH):
76
- # Local path - load directly to device (avoid device_map issues with .bin files)
77
- if os.path.isdir(TRANSFORMER_PATH):
78
- # Check if it's a direct transformer directory or has subfolder
79
- if os.path.exists(os.path.join(TRANSFORMER_PATH, "config.json")):
80
- transformer = QwenImageTransformer2DModel.from_pretrained(
81
- pretrained_model_name_or_path=TRANSFORMER_PATH,
82
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
83
- use_safetensors=False # Use .bin file
84
- ).to(device)
85
- else:
86
- transformer = QwenImageTransformer2DModel.from_pretrained(
87
- pretrained_model_name_or_path=TRANSFORMER_PATH,
88
- subfolder='transformer',
89
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
90
- use_safetensors=False
91
- ).to(device)
92
- else:
93
- raise ValueError(f"Transformer path does not exist: {TRANSFORMER_PATH}")
94
- else:
95
- # HuggingFace repo path
96
- # Handle paths like "Skywork/Unipic3-DMD/ema_transformer"
97
- path_parts = TRANSFORMER_PATH.split('/')
98
- if len(path_parts) >= 3:
99
- # Has subfolder: "Skywork/Unipic3-DMD/ema_transformer"
100
- repo_id = '/'.join(path_parts[:2]) # "Skywork/Unipic3-DMD"
101
- subfolder = path_parts[2] # "ema_transformer"
102
- transformer = QwenImageTransformer2DModel.from_pretrained(
103
- pretrained_model_name_or_path=repo_id,
104
- subfolder=subfolder,
105
- device_map='auto',
106
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
107
- )
108
- elif len(path_parts) == 2:
109
- # Just repo: "Skywork/Unipic3-DMD"
110
- transformer = QwenImageTransformer2DModel.from_pretrained(
111
- pretrained_model_name_or_path=TRANSFORMER_PATH,
112
- subfolder='transformer',
113
- device_map='auto',
114
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
115
- )
116
- else:
117
- # Single name, assume it's a repo ID
118
- transformer = QwenImageTransformer2DModel.from_pretrained(
119
- pretrained_model_name_or_path=TRANSFORMER_PATH,
120
- subfolder='transformer',
121
- device_map='auto',
122
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
123
- )
124
-
125
- # Load VAE
126
- # Get device from transformer (handle both .device and device_map cases)
127
- if hasattr(transformer, 'device'):
128
- vae_device = transformer.device
129
- elif hasattr(transformer, 'hf_device_map'):
130
- # If using device_map, get the first device
131
- vae_device = device
132
  else:
133
- vae_device = device
134
-
 
 
 
 
 
 
 
 
135
  vae = AutoencoderKLQwenImage.from_pretrained(
136
- pretrained_model_name_or_path=MODEL_NAME,
137
- subfolder='vae',
138
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
139
- ).to(vae_device)
140
-
141
- # Create pipeline
142
  pipe = QwenImageEditPipeline(
143
  scheduler=scheduler,
144
  vae=vae,
145
  text_encoder=text_encoder,
146
  tokenizer=tokenizer,
147
  processor=processor,
148
- transformer=transformer
149
  )
150
-
151
- print("Model loaded successfully!")
152
- # Don't return pipe for demo.load() - it expects no return value
153
 
 
154
 
 
 
 
 
 
 
155
  def process_images(
156
  img1, img2, img3, img4, img5, img6,
157
- prompt: str,
158
- true_cfg_scale: float = 4.0,
159
- seed: int = 42,
160
- num_steps: int = 8
161
- ) -> tuple:
162
- """Process multiple images and generate composed result"""
163
  global pipe
164
-
165
- # Ensure model is loaded (should be loaded by demo.load() on startup)
166
  if pipe is None:
167
- return None, "⏳ Model is still loading, please wait a moment and try again..."
168
-
169
- # Filter out None images
170
- images = [img for img in [img1, img2, img3, img4, img5, img6] if img is not None]
171
-
172
- # Validate inputs
173
  if len(images) == 0:
174
- return None, "❌ Error: Please upload at least one image."
175
-
176
  if len(images) > 6:
177
- return None, f"❌ Error: Maximum 6 images allowed. You uploaded {len(images)} images."
178
-
179
- if not prompt or prompt.strip() == "":
180
- return None, "❌ Error: Please enter an editing instruction."
181
-
 
 
 
 
182
  try:
183
- # Convert to RGB
184
- images = [img.convert("RGB") for img in images]
185
-
186
- print(f"Processing {len(images)} images with prompt: '{prompt}'")
187
- print(f"Steps: {num_steps}, CFG Scale: {true_cfg_scale}, Seed: {seed}")
188
-
189
- # Generate image
190
- # Note: images can be passed as first positional argument or as keyword argument
191
  with torch.no_grad():
192
- # Try positional argument first (as shown in pipeline examples)
193
  if len(images) == 1:
194
- # Single image: pass as first positional argument
195
  result = pipe(
196
  images[0],
197
  prompt=prompt,
198
  height=1024,
199
  width=1024,
200
- negative_prompt=' ',
201
  num_inference_steps=num_steps,
202
  true_cfg_scale=true_cfg_scale,
203
- generator=torch.manual_seed(int(seed))
204
  ).images[0]
205
  else:
206
- # Multiple images: pass as keyword argument
207
  result = pipe(
208
  images=images,
209
  prompt=prompt,
210
  height=1024,
211
  width=1024,
212
- negative_prompt=' ',
213
  num_inference_steps=num_steps,
214
  true_cfg_scale=true_cfg_scale,
215
- generator=torch.manual_seed(int(seed))
216
  ).images[0]
217
-
218
- return result, f"✅ Success! Generated from {len(images)} image(s) in {num_steps} steps."
219
-
220
  except Exception as e:
221
- error_msg = f"❌ Error: {str(e)}"
222
- print(error_msg)
223
  import traceback
224
  traceback.print_exc()
225
- return None, error_msg
226
-
227
-
228
- # Create Gradio interface
229
- with gr.Blocks(title="UniPic-3 DMD Multi-Image Composition", theme=gr.themes.Soft()) as demo:
230
- gr.Markdown("""
231
- # 🔥 UniPic-3 DMD Multi-Image Composition
232
-
233
- Upload up to **6 images** and provide an editing instruction to generate a composed result.
234
-
235
- **Model**: DMD (Distribution-Matching Distillation) - **8-step fast inference (12.5× speedup)**
236
-
237
- **Features**:
238
- - Support 1-6 input images
239
- - Fast 8-step inference
240
- - High-quality multi-image composition
241
- """)
242
-
 
 
 
243
  with gr.Row():
244
- with gr.Column(scale=1):
245
- gr.Markdown("### 📸 Upload Images (1-6 images)")
246
  image_inputs = [
247
  gr.Image(type="pil", label=f"Image {i+1}", visible=(i < 2))
248
  for i in range(6)
249
  ]
250
-
251
- num_images = gr.Slider(
252
- minimum=1,
253
- maximum=6,
254
- value=2,
255
- step=1,
256
- label="Number of Images",
257
- info="Select how many images you want to upload"
258
- )
259
-
260
- def update_image_visibility(num):
261
- return [gr.update(visible=(i < num)) for i in range(6)]
262
-
263
- num_images.change(
264
- fn=update_image_visibility,
265
- inputs=num_images,
266
- outputs=image_inputs
267
- )
268
-
269
- gr.Markdown("### ✍️ Editing Instruction")
270
- prompt_input = gr.Textbox(
271
  label="Prompt",
272
- placeholder="e.g., A man from Image1 is standing on a surfboard from Image2, riding the ocean waves under a bright blue sky.",
273
  lines=3,
274
- value="Combine the reference images to generate the final result."
275
- )
276
-
277
- with gr.Accordion("⚙️ Advanced Settings", open=False):
278
- cfg_scale = gr.Slider(
279
- minimum=1.0,
280
- maximum=10.0,
281
- value=4.0,
282
- step=0.5,
283
- label="CFG Scale",
284
- info="Higher values make the output more aligned with the prompt"
285
- )
286
- seed = gr.Number(
287
- value=42,
288
- label="Seed",
289
- info="Random seed for reproducibility",
290
- precision=0
291
- )
292
- num_steps = gr.Slider(
293
- minimum=1,
294
- maximum=8,
295
- value=8,
296
- step=1,
297
- label="Inference Steps",
298
- info="Number of denoising steps (8 is recommended for DMD)"
299
- )
300
-
301
- generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
302
-
303
- with gr.Column(scale=1):
304
- gr.Markdown("### 🎨 Generated Result")
305
- output_image = gr.Image(type="pil", label="Output Image")
306
- status_text = gr.Textbox(
307
- label="Status",
308
- value="Ready. Upload images and enter a prompt, then click Generate.",
309
- interactive=False
310
  )
311
-
312
- # Load model on startup
313
- def load_model_wrapper():
314
- """Wrapper to load model without returning value"""
315
- load_model()
316
- return None
317
-
318
- demo.load(
319
- fn=load_model_wrapper,
320
- inputs=[],
321
- outputs=[],
322
- show_progress=True
323
- )
324
-
325
- # Generate button
326
- generate_btn.click(
327
- fn=process_images,
328
- inputs=[*image_inputs, prompt_input, cfg_scale, seed, num_steps],
329
- outputs=[output_image, status_text]
330
  )
331
 
332
- # @GPU
333
- # def main():
334
- # demo.launch()
335
-
336
- if __name__ == "__main__":
 
337
  demo.launch()
338
 
 
 
 
 
1
  """
2
+ UniPic-3 DMD Multi-Image Composition
3
+ Hugging Face Space
 
 
4
  """
5
 
6
+ import os
7
+ import sys
8
  import torch
9
+ import gradio as gr
10
  from PIL import Image
 
11
  from spaces import GPU
12
 
13
+ # -----------------------------------------------------------------------------
14
+ # Local imports
15
+ # -----------------------------------------------------------------------------
16
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
17
+
18
  try:
19
  from pipeline_qwenimage_edit import QwenImageEditPipeline
20
  except ImportError:
21
+ from diffusers import QwenImageEditPipeline
 
 
 
 
 
 
 
22
 
23
+ from diffusers import (
24
+ FlowMatchEulerDiscreteScheduler,
25
+ QwenImageTransformer2DModel,
26
+ AutoencoderKLQwenImage,
27
+ )
28
  from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
29
 
30
+ # -----------------------------------------------------------------------------
31
+ # Globals
32
+ # -----------------------------------------------------------------------------
33
  pipe = None
34
 
 
 
35
  MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
36
+ TRANSFORMER_PATH = os.environ.get(
37
+ "TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer"
38
+ )
39
 
40
+ # -----------------------------------------------------------------------------
41
+ # Model loader (LAZY)
42
+ # -----------------------------------------------------------------------------
43
def load_model():
    """Lazily load the UniPic-3 DMD pipeline onto the GPU.

    Populates the module-level ``pipe`` singleton on first call and is a
    no-op afterwards.

    Raises:
        RuntimeError: when no CUDA device is available (GPU-only Space).
    """
    global pipe

    if pipe is not None:
        return  # already loaded

    if not torch.cuda.is_available():
        raise RuntimeError(
            "❌ GPU not available. This Space is GPU-only."
        )

    device = torch.device("cuda")
    dtype = torch.bfloat16

    print("🚀 Loading UniPic-3 DMD on GPU")
    print("Device:", device)
    print("Dtype:", dtype)

    # Scheduler
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        MODEL_NAME, subfolder="scheduler"
    )

    # Text encoder
    text_encoder = AutoModel.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=dtype,
    ).to(device)

    # Tokenizer / Processor
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, subfolder="tokenizer"
    )
    processor = Qwen2VLProcessor.from_pretrained(
        MODEL_NAME, subfolder="processor"
    )

    # Transformer (DMD weights). TRANSFORMER_PATH is either a local
    # directory or a HuggingFace repo path like "org/repo[/subfolder]".
    if os.path.exists(TRANSFORMER_PATH):
        transformer = QwenImageTransformer2DModel.from_pretrained(
            TRANSFORMER_PATH,
            torch_dtype=dtype,
            use_safetensors=False,  # checkpoint ships as .bin
        ).to(device)
    else:
        parts = TRANSFORMER_PATH.split("/")
        if len(parts) >= 3:
            # "org/repo/sub[/dir...]" -> repo id + (possibly nested) subfolder.
            repo_id = "/".join(parts[:2])
            subfolder = "/".join(parts[2:])
        else:
            # Bare repo id ("org/repo"): fall back to the conventional
            # "transformer" subfolder. (Previously the last path component
            # was reused as the subfolder, which pointed "org/repo" at a
            # non-existent subfolder named after the repo itself.)
            repo_id = TRANSFORMER_PATH
            subfolder = "transformer"
        transformer = QwenImageTransformer2DModel.from_pretrained(
            repo_id,
            subfolder=subfolder,
            torch_dtype=dtype,
        ).to(device)

    # VAE
    vae = AutoencoderKLQwenImage.from_pretrained(
        MODEL_NAME,
        subfolder="vae",
        torch_dtype=dtype,
    ).to(device)

    # Assemble the pipeline and pin everything to the GPU.
    pipe = QwenImageEditPipeline(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        processor=processor,
        transformer=transformer,
    )
    pipe.to(device)

    print("✅ Model loaded successfully")
118
+
119
+
120
+ # -----------------------------------------------------------------------------
121
+ # Inference
122
+ # -----------------------------------------------------------------------------
123
def process_images(
    img1, img2, img3, img4, img5, img6,
    prompt,
    true_cfg_scale,
    seed,
    num_steps,
):
    """Compose up to six input images into one result via the DMD pipeline.

    Args:
        img1..img6: PIL images (any of them may be None / empty slots).
        prompt: editing instruction; must be non-empty.
        true_cfg_scale: classifier-free-guidance scale passed to the pipe.
        seed: integer seed for reproducible generation.
        num_steps: number of denoising steps.

    Returns:
        (image_or_None, status_message) tuple for the Gradio outputs.
    """
    global pipe

    # Validate inputs BEFORE the lazy model load so that bad requests
    # fail fast without triggering a multi-GB checkpoint download.
    images = [i for i in [img1, img2, img3, img4, img5, img6] if i is not None]

    if not images:
        return None, "❌ Please upload at least one image."

    if len(images) > 6:
        return None, "❌ Maximum 6 images allowed."

    # Guard against prompt being None (gr.Textbox can yield None), which
    # previously crashed on .strip().
    if not prompt or not prompt.strip():
        return None, "❌ Prompt cannot be empty."

    try:
        if pipe is None:
            load_model()

        images = [img.convert("RGB") for img in images]

        # Seeded generator on the actual inference device. Created inside
        # the try-block so any failure is reported via the status string
        # instead of escaping to Gradio as an unhandled exception.
        gen_device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(int(seed))

        with torch.no_grad():
            if len(images) == 1:
                # Single image goes through the positional-image code path.
                result = pipe(
                    images[0],
                    prompt=prompt,
                    height=1024,
                    width=1024,
                    negative_prompt=" ",
                    num_inference_steps=num_steps,
                    true_cfg_scale=true_cfg_scale,
                    generator=generator,
                ).images[0]
            else:
                # Multiple images are passed via the `images=` keyword.
                result = pipe(
                    images=images,
                    prompt=prompt,
                    height=1024,
                    width=1024,
                    negative_prompt=" ",
                    num_inference_steps=num_steps,
                    true_cfg_scale=true_cfg_scale,
                    generator=generator,
                ).images[0]

        return result, f"✅ Generated from {len(images)} image(s)"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {e}"
181
+
182
+
183
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# Gradio front-end: six image slots (visibility driven by a slider), a prompt
# box, sampling controls, and a generate button wired to `process_images`.
with gr.Blocks(
    title="UniPic-3 DMD Multi-Image Composition",
    theme=gr.themes.Soft(),
) as demo:

    gr.Markdown(
        """
        # 🔥 UniPic-3 DMD Multi-Image Composition

        - **Model**: UniPic-3 DMD
        - **Inference**: 8-step fast generation
        - **GPU-only Hugging Face Space**
        """
    )

    with gr.Row():
        with gr.Column():
            # Six image slots; only the first two are visible initially.
            image_inputs = [
                gr.Image(type="pil", label=f"Image {i+1}", visible=(i < 2))
                for i in range(6)
            ]

            num_images = gr.Slider(1, 6, value=2, step=1, label="Number of Images")

            def update_visibility(n):
                # Show the first `n` image slots, hide the rest.
                return [gr.update(visible=i < n) for i in range(6)]

            num_images.change(update_visibility, num_images, image_inputs)

            prompt = gr.Textbox(
                label="Prompt",
                lines=3,
                value="Combine the reference images to generate the final result.",
            )

            # Sampling controls; defaults match the DMD 8-step recipe.
            cfg = gr.Slider(1.0, 10.0, value=4.0, step=0.5, label="CFG Scale")
            seed = gr.Number(value=42, precision=0, label="Seed")
            steps = gr.Slider(1, 8, value=8, step=1, label="Steps")

            btn = gr.Button("🚀 Generate", variant="primary")

        with gr.Column():
            output = gr.Image(label="Output")
            status = gr.Textbox(label="Status", interactive=False)

    # Wire the button: images + prompt + sampler settings -> (image, status).
    btn.click(
        process_images,
        inputs=[*image_inputs, prompt, cfg, seed, steps],
        outputs=[output, status],
    )
236
 
237
+
238
# -----------------------------------------------------------------------------
# Entry (IMPORTANT)
# -----------------------------------------------------------------------------
# NOTE(review): decorating the *launcher* with @GPU is unusual — for HF
# ZeroGPU Spaces, `spaces.GPU` is normally placed on the function that does
# the GPU work (here `process_images`), not on `demo.launch()`. Confirm the
# intended Space hardware before relying on this placement.
@GPU
def main():
    """Start the Gradio server (blocks until shutdown)."""
    demo.launch()


if __name__ == "__main__":
    main()