OrlandoHugBot committed
Commit e28e511 · verified · 1 Parent(s): e04ab65

Update app.py

Files changed (1): app.py (+109, -139)
app.py CHANGED
@@ -1,6 +1,8 @@
 """
-UniPic-3 DMD Multi-Image Composition
-Hugging Face Space
+UniPic-3 DMD ZeroGPU friendly demo
+- Pre-cache all weights on CPU
+- GPU phase does ZERO network IO
+- SSR disabled
 """
 
 import os
@@ -9,16 +11,47 @@ import torch
 import gradio as gr
 from PIL import Image
 from spaces import GPU
+from huggingface_hub import snapshot_download
 
 # -----------------------------------------------------------------------------
-# Local imports
+# Paths
 # -----------------------------------------------------------------------------
-sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+MODEL_ID = "Skywork/Unipic3-DMD"
+CACHE_ROOT = "./hf_cache"
+LOCAL_MODEL_DIR = os.path.join(CACHE_ROOT, MODEL_ID)
 
-try:
-    from pipeline_qwenimage_edit import QwenImageEditPipeline
-except ImportError:
-    from diffusers import QwenImageEditPipeline
+# -----------------------------------------------------------------------------
+# Pre-cache weights (CPU ONLY)
+# -----------------------------------------------------------------------------
+def precache_weights():
+    if os.path.exists(LOCAL_MODEL_DIR):
+        print("✅ Weights already cached")
+        return
+
+    print("📦 Pre-caching UniPic-3 DMD weights (CPU stage)...")
+
+    snapshot_download(
+        repo_id=MODEL_ID,
+        local_dir=LOCAL_MODEL_DIR,
+        local_dir_use_symlinks=False,
+        resume_download=True,
+        allow_patterns=[
+            "scheduler/*",
+            "text_encoder/*",
+            "tokenizer/*",
+            "processor/*",
+            "vae/*",
+            "ema_transformer/*",
+        ],
+    )
+
+    print("✅ Pre-cache complete")
+
+
+# -----------------------------------------------------------------------------
+# Local imports AFTER cache
+# -----------------------------------------------------------------------------
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
 from diffusers import (
     FlowMatchEulerDiscreteScheduler,
@@ -27,18 +60,18 @@ from diffusers import (
 )
 from transformers import AutoModel, AutoTokenizer, Qwen2VLProcessor
 
+try:
+    from pipeline_qwenimage_edit import QwenImageEditPipeline
+except ImportError:
+    from diffusers import QwenImageEditPipeline
+
 # -----------------------------------------------------------------------------
 # Globals
 # -----------------------------------------------------------------------------
 pipe = None
 
-MODEL_NAME = os.environ.get("MODEL_NAME", "Skywork/Unipic3-DMD")
-TRANSFORMER_PATH = os.environ.get(
-    "TRANSFORMER_PATH", "Skywork/Unipic3-DMD/ema_transformer"
-)
-
 # -----------------------------------------------------------------------------
-# Model loader (LAZY)
+# Load model (GPU stage, NO DOWNLOAD)
 # -----------------------------------------------------------------------------
 def load_model():
     global pipe
@@ -47,62 +80,44 @@ def load_model():
         return
 
     if not torch.cuda.is_available():
-        raise RuntimeError(
-            "❌ GPU not available. This Space is GPU-only."
-        )
+        raise RuntimeError("❌ GPU not available")
 
     device = torch.device("cuda")
    dtype = torch.bfloat16
 
-    print("🚀 Loading UniPic-3 DMD on GPU")
+    print("🚀 Loading UniPic-3 DMD from local cache")
     print("Device:", device)
-    print("Dtype:", dtype)
 
-    # Scheduler
     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-        MODEL_NAME, subfolder="scheduler"
+        LOCAL_MODEL_DIR, subfolder="scheduler"
     )
 
-    # Text encoder
     text_encoder = AutoModel.from_pretrained(
-        MODEL_NAME,
+        LOCAL_MODEL_DIR,
         subfolder="text_encoder",
         torch_dtype=dtype,
-    ).to('cpu')
+    ).to(device)
 
-    # Tokenizer / Processor
     tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_NAME, subfolder="tokenizer"
+        LOCAL_MODEL_DIR, subfolder="tokenizer"
     )
+
     processor = Qwen2VLProcessor.from_pretrained(
-        MODEL_NAME, subfolder="processor"
+        LOCAL_MODEL_DIR, subfolder="processor"
    )
 
-    # Transformer (DMD)
-    if os.path.exists(TRANSFORMER_PATH):
-        transformer = QwenImageTransformer2DModel.from_pretrained(
-            TRANSFORMER_PATH,
-            torch_dtype=dtype,
-            use_safetensors=False,
-        ).to(device)
-    else:
-        # HF repo path: Skywork/Unipic3-DMD/ema_transformer
-        repo_id = "/".join(TRANSFORMER_PATH.split("/")[:2])
-        subfolder = TRANSFORMER_PATH.split("/")[-1]
-        transformer = QwenImageTransformer2DModel.from_pretrained(
-            repo_id,
-            subfolder=subfolder,
-            torch_dtype=dtype,
-        ).to(device)
-
-    # VAE
+    transformer = QwenImageTransformer2DModel.from_pretrained(
+        LOCAL_MODEL_DIR,
+        subfolder="ema_transformer",
+        torch_dtype=dtype,
+    ).to(device)
+
     vae = AutoencoderKLQwenImage.from_pretrained(
-        MODEL_NAME,
+        LOCAL_MODEL_DIR,
         subfolder="vae",
         torch_dtype=dtype,
     ).to(device)
 
-    # Pipeline
     pipe = QwenImageEditPipeline(
         scheduler=scheduler,
         vae=vae,
@@ -120,12 +135,9 @@ def load_model():
 # -----------------------------------------------------------------------------
 # Inference
 # -----------------------------------------------------------------------------
-def process_images(
+def run(
     img1, img2, img3, img4, img5, img6,
-    prompt,
-    true_cfg_scale,
-    seed,
-    num_steps,
+    prompt, cfg, seed, steps
 ):
     global pipe
 
@@ -133,112 +145,70 @@ def process_images(
         load_model()
 
     images = [i for i in [img1, img2, img3, img4, img5, img6] if i is not None]
-
-    if len(images) == 0:
-        return None, "❌ Please upload at least one image."
-
-    if len(images) > 6:
-        return None, "❌ Maximum 6 images allowed."
-
-    if not prompt.strip():
-        return None, "❌ Prompt cannot be empty."
+    if not images:
+        return None, "❌ Please upload at least one image"
 
     images = [img.convert("RGB") for img in images]
 
-    generator = torch.Generator(device="cuda").manual_seed(int(seed))
-
-    try:
-        with torch.no_grad():
-            if len(images) == 1:
-                result = pipe(
-                    images[0],
-                    prompt=prompt,
-                    height=512,
-                    width=512,
-                    negative_prompt=" ",
-                    num_inference_steps=num_steps,
-                    true_cfg_scale=true_cfg_scale,
-                    generator=generator,
-                ).images[0]
-            else:
-                result = pipe(
-                    images=images,
-                    prompt=prompt,
-                    height=512,
-                    width=512,
-                    negative_prompt=" ",
-                    num_inference_steps=num_steps,
-                    true_cfg_scale=true_cfg_scale,
-                    generator=generator,
-                ).images[0]
-
-        return result, f"✅ Generated from {len(images)} image(s)"
-
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        return None, f"❌ Error: {e}"
+    gen = torch.Generator(device="cuda").manual_seed(int(seed))
+
+    with torch.no_grad():
+        if len(images) == 1:
+            out = pipe(
+                images[0],
+                prompt=prompt,
+                height=768,
+                width=768,
+                num_inference_steps=steps,
+                true_cfg_scale=cfg,
+                generator=gen,
+            ).images[0]
+        else:
+            out = pipe(
+                images=images,
+                prompt=prompt,
+                height=768,
+                width=768,
+                num_inference_steps=steps,
+                true_cfg_scale=cfg,
+                generator=gen,
+            ).images[0]
+
+    return out, "✅ Done"
 
 
 # -----------------------------------------------------------------------------
 # UI
 # -----------------------------------------------------------------------------
-with gr.Blocks(
-    title="UniPic-3 DMD Multi-Image Composition",
-    theme=gr.themes.Soft(),
-) as demo:
-
-    gr.Markdown(
-        """
-        # 🔥 UniPic-3 DMD Multi-Image Composition
-
-        - **Model**: UniPic-3 DMD
-        - **Inference**: 8-step fast generation
-        """
-    )
-
-    with gr.Row():
-        with gr.Column():
-            image_inputs = [
-                gr.Image(type="pil", label=f"Image {i+1}", visible=(i < 2))
-                for i in range(6)
-            ]
-
-            num_images = gr.Slider(1, 6, value=2, step=1, label="Number of Images")
-
-            def update_visibility(n):
-                return [gr.update(visible=i < n) for i in range(6)]
-
-            num_images.change(update_visibility, num_images, image_inputs)
-
-            prompt = gr.Textbox(
-                label="Prompt",
-                lines=3,
-                value="Combine the reference images to generate the final result.",
-            )
-
-            cfg = gr.Slider(1.0, 10.0, value=4.0, step=0.5, label="CFG Scale")
-            seed = gr.Number(value=42, precision=0, label="Seed")
-            steps = gr.Slider(1, 8, value=8, step=1, label="Steps")
-
-            btn = gr.Button("🚀 Generate", variant="primary")
-
-        with gr.Column():
-            output = gr.Image(label="Output")
-            status = gr.Textbox(label="Status", interactive=False)
+with gr.Blocks(title="UniPic-3 DMD (ZeroGPU)") as demo:
+    gr.Markdown("# 🔥 UniPic-3 DMD (ZeroGPU + Precached)")
+
+    imgs = [gr.Image(type="pil", label=f"Image {i+1}") for i in range(6)]
+    prompt = gr.Textbox(label="Prompt", value="Combine the reference images.")
+    cfg = gr.Slider(1, 8, value=4)
+    seed = gr.Number(42)
+    steps = gr.Slider(1, 8, value=6)
+
+    btn = gr.Button("Generate")
+    out = gr.Image()
+    status = gr.Textbox()
 
     btn.click(
-        process_images,
-        inputs=[*image_inputs, prompt, cfg, seed, steps],
-        outputs=[output, status],
+        run,
+        inputs=[*imgs, prompt, cfg, seed, steps],
+        outputs=[out, status],
     )
 
 
 # -----------------------------------------------------------------------------
-# Entry (IMPORTANT)
+# Entry
 # -----------------------------------------------------------------------------
 @GPU
 def main():
+    # CPU stage (no GPU time)
+    precache_weights()
+
+    # Start Gradio (NO SSR)
     demo.launch(ssr_mode=False)
 
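The change above amounts to a two-stage load: precache_weights() runs while the Space is still on CPU and uses snapshot_download with allow_patterns to pull only the subfolders the pipeline needs into LOCAL_MODEL_DIR, and load_model() then builds every component via from_pretrained(LOCAL_MODEL_DIR, subfolder=...), so the @GPU-decorated phase performs no network IO. A minimal standalone sketch of that split, assuming a generic pipeline; the repo id "some-org/some-model" and the build_pipeline() loader are hypothetical placeholders, not part of this Space:

import os
import torch
from huggingface_hub import snapshot_download
from spaces import GPU

LOCAL_DIR = "./hf_cache/some-org/some-model"  # hypothetical cache path

def precache():
    # CPU stage: all network IO happens here, before any GPU is attached.
    if not os.path.exists(LOCAL_DIR):
        snapshot_download(
            repo_id="some-org/some-model",  # hypothetical repo id
            local_dir=LOCAL_DIR,
            allow_patterns=["scheduler/*", "vae/*"],  # fetch only what the pipeline needs
        )

@GPU
def generate(prompt):
    # GPU stage: weights come strictly from the local directory, so
    # from_pretrained never touches the network inside the GPU window.
    pipe = build_pipeline(LOCAL_DIR, torch_dtype=torch.bfloat16).to("cuda")  # placeholder loader
    return pipe(prompt).images[0]

precache()  # run once at startup, on CPU

Keeping snapshot_download out of the @GPU function matters on ZeroGPU: the GPU is attached only for the duration of the decorated call, so a download performed there would spend the quota window on network transfer instead of inference.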