hexware committed on
Commit
05ce8b7
·
verified ·
1 Parent(s): 5cb19bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -37
app.py CHANGED
@@ -18,7 +18,6 @@ MAX_SEED = np.iinfo(np.int32).max
18
 
19
  # Optional HF login (works in Spaces if you set HF token as secret env var "hf")
20
  from huggingface_hub import login
21
-
22
  login(token=os.environ.get("hf"))
23
 
24
  dtype = torch.bfloat16
@@ -89,12 +88,46 @@ def get_duration(
89
  cfg_norm=True,
90
  use_en_prompt=True,
91
  resolution=640,
92
- gpu_duration=1000, # <-- NEW
 
 
 
93
  ):
94
- # Allow user override via UI (text field), but keep it sane
95
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  @spaces.GPU(duration=get_duration)
99
  def infer(
100
  input_image,
@@ -108,7 +141,10 @@ def infer(
108
  cfg_norm=True,
109
  use_en_prompt=True,
110
  resolution=640,
111
- gpu_duration=1000, # <-- NEW (must match get_duration signature)
 
 
 
112
  ):
113
  # Seed
114
  if randomize_seed:
@@ -120,30 +156,22 @@ def infer(
120
  resolution = 640
121
 
122
  # Normalize image input
123
- if isinstance(input_image, list):
124
- input_image = input_image[0]
125
-
126
- if isinstance(input_image, str):
127
- pil_image = Image.open(input_image).convert("RGB").convert("RGBA")
128
- elif isinstance(input_image, Image.Image):
129
- pil_image = input_image.convert("RGB").convert("RGBA")
130
- elif isinstance(input_image, np.ndarray):
131
- pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA")
132
- else:
133
- raise ValueError(f"Unsupported input_image type: {type(input_image)}")
134
 
135
  gen_device = "cuda" if torch.cuda.is_available() else "cpu"
 
136
 
 
137
  inputs = {
138
  "image": pil_image,
139
- "generator": torch.Generator(device=gen_device).manual_seed(seed),
140
  "true_cfg_scale": true_guidance_scale,
141
  "prompt": prompt,
142
  "negative_prompt": neg_prompt,
143
  "num_inference_steps": num_inference_steps,
144
  "num_images_per_prompt": 1,
145
  "layers": layer,
146
- "resolution": resolution, # 640 or 1024
147
  "cfg_normalize": cfg_norm,
148
  "use_en_prompt": use_en_prompt,
149
  }
@@ -153,27 +181,46 @@ def infer(
153
 
154
  with torch.inference_mode():
155
  out = pipeline(**inputs)
156
- output_images = out.images[0] # list of PIL images (layers)
157
 
158
- # Prepare gallery + export files
159
- gallery_out = []
160
- temp_files = []
161
 
162
- for img in output_images:
163
- gallery_out.append(img)
164
- tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
165
- img.save(tmp.name)
166
- temp_files.append(tmp.name)
167
 
168
- pptx_path = imagelist_to_pptx(temp_files)
 
 
169
 
170
- with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
171
- with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
172
- for i, img_path in enumerate(temp_files):
173
- zipf.write(img_path, f"layer_{i+1}.png")
174
- zip_path = tmpzip.name
175
 
176
- return gallery_out, pptx_path, zip_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  ensure_dirname(LOG_DIR)
@@ -271,7 +318,6 @@ The text prompt is intended to describe the overall content of the input image
271
  value=True,
272
  )
273
 
274
- # NEW: text field for GPU duration override (seconds)
275
  gpu_duration = gr.Textbox(
276
  label="GPU duration override (seconds, 20..1500)",
277
  value="1000",
@@ -279,6 +325,26 @@ The text prompt is intended to describe the overall content of the input image
279
  placeholder="e.g. 60, 120, 300, 1000, 1500",
280
  )
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  run_button = gr.Button("Decompose!", variant="primary")
283
 
284
  with gr.Column(scale=2):
@@ -287,10 +353,23 @@ The text prompt is intended to describe the overall content of the input image
287
  export_file = gr.File(label="Download PPTX")
288
  export_zip_file = gr.File(label="Download ZIP")
289
 
 
 
 
 
 
 
290
  gr.Examples(
291
  examples=examples,
292
  inputs=[input_image],
293
- outputs=[gallery, export_file, export_zip_file],
 
 
 
 
 
 
 
294
  fn=infer,
295
  examples_per_page=14,
296
  cache_examples=False,
@@ -311,9 +390,19 @@ The text prompt is intended to describe the overall content of the input image
311
  cfg_norm,
312
  use_en_prompt,
313
  resolution,
314
- gpu_duration, # <-- NEW
 
 
 
 
 
 
 
 
 
 
 
315
  ],
316
- outputs=[gallery, export_file, export_zip_file],
317
  )
318
 
319
  if __name__ == "__main__":
 
18
 
19
  # Optional HF login (works in Spaces if you set HF token as secret env var "hf")
20
  from huggingface_hub import login
 
21
  login(token=os.environ.get("hf"))
22
 
23
  dtype = torch.bfloat16
 
88
  cfg_norm=True,
89
  use_en_prompt=True,
90
  resolution=640,
91
+ gpu_duration=1000,
92
+ refine_enabled=False,
93
+ refine_layer_index=1,
94
+ refine_sub_layers=3,
95
  ):
 
96
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
97
 
98
 
99
def _normalize_input_image(input_image):
    """Coerce the various Gradio image payloads into one RGBA PIL image.

    Accepts a file path, a PIL image, a numpy array, or a list containing
    any of those (only the first element is used). The image is flattened
    to RGB first (discarding any source alpha) and then promoted to RGBA,
    so downstream code always sees a uniform 4-channel mode.

    Raises:
        ValueError: if the payload is none of the supported types.
    """
    # Gallery / multi-file widgets hand us a list; keep only the first item.
    if isinstance(input_image, list):
        input_image = input_image[0]

    if isinstance(input_image, str):
        pil = Image.open(input_image)
    elif isinstance(input_image, Image.Image):
        pil = input_image
    elif isinstance(input_image, np.ndarray):
        pil = Image.fromarray(input_image)
    else:
        raise ValueError(f"Unsupported input_image type: {type(input_image)}")

    # RGB first to drop any alpha, then RGBA for a consistent mode.
    return pil.convert("RGB").convert("RGBA")
111
+
112
+
113
def _export_images_to_pptx_and_zip(pil_images, zip_prefix="layer"):
    """Persist a list of PIL images and bundle them as a PPTX and a ZIP.

    Each image is written to its own temporary PNG (kept on disk via
    delete=False so Gradio can serve it afterwards), then the PNG paths
    are fed to imagelist_to_pptx and also archived into a ZIP under the
    names "<zip_prefix>_<i>.png" (1-based index).

    Args:
        pil_images: iterable of PIL.Image objects to export.
        zip_prefix: base name for the entries inside the ZIP archive.

    Returns:
        (pptx_path, zip_path): filesystem paths of the generated files.
        Neither file is cleaned up here — callers/Gradio own the paths.
    """
    temp_files = []
    for img in pil_images:
        # Close the tempfile handle BEFORE PIL writes to the path:
        # keeping it open leaks a file descriptor per image and fails on
        # Windows, where a file held open by NamedTemporaryFile cannot be
        # reopened for writing.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            png_path = tmp.name
        img.save(png_path)
        temp_files.append(png_path)

    pptx_path = imagelist_to_pptx(temp_files)

    # Same pattern for the ZIP: reserve a name, close the handle, then
    # let zipfile open the path itself.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
        zip_path = tmpzip.name
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for i, img_path in enumerate(temp_files):
            zipf.write(img_path, f"{zip_prefix}_{i+1}.png")

    return pptx_path, zip_path
129
+
130
+
131
  @spaces.GPU(duration=get_duration)
132
  def infer(
133
  input_image,
 
141
  cfg_norm=True,
142
  use_en_prompt=True,
143
  resolution=640,
144
+ gpu_duration=1000,
145
+ refine_enabled=False,
146
+ refine_layer_index=1, # 1-based for UI convenience
147
+ refine_sub_layers=3,
148
  ):
149
  # Seed
150
  if randomize_seed:
 
156
  resolution = 640
157
 
158
  # Normalize image input
159
+ pil_image = _normalize_input_image(input_image)
 
 
 
 
 
 
 
 
 
 
160
 
161
  gen_device = "cuda" if torch.cuda.is_available() else "cpu"
162
+ generator = torch.Generator(device=gen_device).manual_seed(seed)
163
 
164
+ # First pass inputs
165
  inputs = {
166
  "image": pil_image,
167
+ "generator": generator,
168
  "true_cfg_scale": true_guidance_scale,
169
  "prompt": prompt,
170
  "negative_prompt": neg_prompt,
171
  "num_inference_steps": num_inference_steps,
172
  "num_images_per_prompt": 1,
173
  "layers": layer,
174
+ "resolution": resolution,
175
  "cfg_normalize": cfg_norm,
176
  "use_en_prompt": use_en_prompt,
177
  }
 
181
 
182
  with torch.inference_mode():
183
  out = pipeline(**inputs)
184
+ output_layers = out.images[0] # list[PIL.Image]
185
 
186
+ # Export first pass
187
+ pptx_path, zip_path = _export_images_to_pptx_and_zip(output_layers, zip_prefix="layer")
 
188
 
189
+ # Optional: Recursive (refine one layer into sub-layers) — no separate steps/resolution/cfg
190
+ refined_gallery = []
191
+ refined_pptx = None
192
+ refined_zip = None
 
193
 
194
+ if refine_enabled and len(output_layers) > 0:
195
+ idx0 = _clamp_int(refine_layer_index, default=1, lo=1, hi=len(output_layers)) - 1
196
+ refine_sub_layers = _clamp_int(refine_sub_layers, default=3, lo=2, hi=10)
197
 
198
+ selected_layer = output_layers[idx0].convert("RGBA")
199
+
200
+ refined_inputs = dict(inputs) # reuse same params
201
+ refined_inputs["image"] = selected_layer
202
+ refined_inputs["layers"] = refine_sub_layers
203
 
204
+ print("REFINE ENABLED:", True)
205
+ print("REFINE LAYER INDEX (1-based):", idx0 + 1)
206
+ print("REFINE SUB-LAYERS:", refine_sub_layers)
207
+ print("REFINED INPUTS:", {k: v for k, v in refined_inputs.items() if k != "image"})
208
+
209
+ with torch.inference_mode():
210
+ refined_out = pipeline(**refined_inputs)
211
+ sub_layers = refined_out.images[0]
212
+
213
+ refined_gallery = sub_layers
214
+ refined_pptx, refined_zip = _export_images_to_pptx_and_zip(sub_layers, zip_prefix=f"sub_layer_{idx0+1}")
215
+
216
+ return (
217
+ output_layers,
218
+ pptx_path,
219
+ zip_path,
220
+ refined_gallery,
221
+ refined_pptx,
222
+ refined_zip,
223
+ )
224
 
225
 
226
  ensure_dirname(LOG_DIR)
 
318
  value=True,
319
  )
320
 
 
321
  gpu_duration = gr.Textbox(
322
  label="GPU duration override (seconds, 20..1500)",
323
  value="1000",
 
325
  placeholder="e.g. 60, 120, 300, 1000, 1500",
326
  )
327
 
328
+ gr.Markdown("### Advanced: Recursive decomposition")
329
+ refine_enabled = gr.Checkbox(
330
+ label="Refine one layer into sub-layers",
331
+ value=False,
332
+ )
333
+ refine_layer_index = gr.Slider(
334
+ label="Refine layer index (1-based)",
335
+ minimum=1,
336
+ maximum=10,
337
+ step=1,
338
+ value=1,
339
+ )
340
+ refine_sub_layers = gr.Slider(
341
+ label="Sub-layers (for refined layer)",
342
+ minimum=2,
343
+ maximum=10,
344
+ step=1,
345
+ value=3,
346
+ )
347
+
348
  run_button = gr.Button("Decompose!", variant="primary")
349
 
350
  with gr.Column(scale=2):
 
353
  export_file = gr.File(label="Download PPTX")
354
  export_zip_file = gr.File(label="Download ZIP")
355
 
356
+ gr.Markdown("### Refined sub-layers")
357
+ refined_gallery = gr.Gallery(label="Sub-layers", columns=4, rows=1, format="png")
358
+ with gr.Row():
359
+ refined_export_file = gr.File(label="Download refined PPTX")
360
+ refined_export_zip_file = gr.File(label="Download refined ZIP")
361
+
362
  gr.Examples(
363
  examples=examples,
364
  inputs=[input_image],
365
+ outputs=[
366
+ gallery,
367
+ export_file,
368
+ export_zip_file,
369
+ refined_gallery,
370
+ refined_export_file,
371
+ refined_export_zip_file,
372
+ ],
373
  fn=infer,
374
  examples_per_page=14,
375
  cache_examples=False,
 
390
  cfg_norm,
391
  use_en_prompt,
392
  resolution,
393
+ gpu_duration,
394
+ refine_enabled,
395
+ refine_layer_index,
396
+ refine_sub_layers,
397
+ ],
398
+ outputs=[
399
+ gallery,
400
+ export_file,
401
+ export_zip_file,
402
+ refined_gallery,
403
+ refined_export_file,
404
+ refined_export_zip_file,
405
  ],
 
406
  )
407
 
408
  if __name__ == "__main__":