hexware commited on
Commit
eeb4923
·
verified ·
1 Parent(s): 05ce8b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -168
app.py CHANGED
@@ -1,31 +1,67 @@
1
  import os
2
  import uuid
3
- import numpy as np
4
  import random
5
  import tempfile
6
  import zipfile
 
7
 
8
  import spaces
9
  import torch
10
  import gradio as gr
11
 
12
  from PIL import Image
13
- from diffusers import QwenImageLayeredPipeline
14
  from pptx import Presentation
 
15
 
16
  LOG_DIR = "/tmp/local"
17
  MAX_SEED = np.iinfo(np.int32).max
18
 
19
- # Optional HF login (works in Spaces if you set HF token as secret env var "hf")
 
 
 
20
  from huggingface_hub import login
21
  login(token=os.environ.get("hf"))
22
 
23
- dtype = torch.bfloat16
24
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
25
 
 
 
 
26
  pipeline = QwenImageLayeredPipeline.from_pretrained(
27
- "Qwen/Qwen-Image-Layered", torch_dtype=dtype
28
- ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def ensure_dirname(path: str):
@@ -33,26 +69,21 @@ def ensure_dirname(path: str):
33
  os.makedirs(path, exist_ok=True)
34
 
35
 
36
- def random_str(length=8):
37
- return uuid.uuid4().hex[:length]
38
-
39
-
40
  def imagelist_to_pptx(img_files):
41
  with Image.open(img_files[0]) as img:
42
  img_width_px, img_height_px = img.size
43
 
44
  def px_to_emu(px, dpi=96):
45
  inch = px / dpi
46
- emu = inch * 914400
47
- return int(emu)
48
 
49
  prs = Presentation()
50
  prs.slide_width = px_to_emu(img_width_px)
51
  prs.slide_height = px_to_emu(img_height_px)
52
 
53
  slide = prs.slides.add_slide(prs.slide_layouts[6])
54
-
55
  left = top = 0
 
56
  for img_path in img_files:
57
  slide.shapes.add_picture(
58
  img_path,
@@ -75,171 +106,214 @@ def _clamp_int(x, default: int, lo: int, hi: int) -> int:
75
  return max(lo, min(hi, v))
76
 
77
 
78
- # Dynamic duration callable: must accept the same args as infer(). It returns seconds.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def get_duration(
80
  input_image,
81
- seed=777,
82
- randomize_seed=False,
83
- prompt=None,
84
  neg_prompt=" ",
85
  true_guidance_scale=4.0,
86
  num_inference_steps=50,
87
- layer=4,
88
  cfg_norm=True,
89
  use_en_prompt=True,
90
- resolution=640,
91
- gpu_duration=1000,
92
- refine_enabled=False,
93
  refine_layer_index=1,
94
  refine_sub_layers=3,
95
  ):
96
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
97
 
98
 
99
- def _normalize_input_image(input_image):
100
- if isinstance(input_image, list):
101
- input_image = input_image[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- if isinstance(input_image, str):
104
- return Image.open(input_image).convert("RGB").convert("RGBA")
105
- if isinstance(input_image, Image.Image):
106
- return input_image.convert("RGB").convert("RGBA")
107
- if isinstance(input_image, np.ndarray):
108
- return Image.fromarray(input_image).convert("RGB").convert("RGBA")
109
 
110
- raise ValueError(f"Unsupported input_image type: {type(input_image)}")
111
 
 
 
112
 
113
- def _export_images_to_pptx_and_zip(pil_images, zip_prefix="layer"):
114
- temp_files = []
115
- for img in pil_images:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
117
  img.save(tmp.name)
118
- temp_files.append(tmp.name)
119
 
120
- pptx_path = imagelist_to_pptx(temp_files)
121
 
122
  with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
123
  with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
124
- for i, img_path in enumerate(temp_files):
125
- zipf.write(img_path, f"{zip_prefix}_{i+1}.png")
126
  zip_path = tmpzip.name
127
 
128
- return pptx_path, zip_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
 
131
  @spaces.GPU(duration=get_duration)
132
- def infer(
133
- input_image,
134
- seed=777,
135
- randomize_seed=False,
136
- prompt=None,
 
 
137
  neg_prompt=" ",
138
  true_guidance_scale=4.0,
139
  num_inference_steps=50,
140
- layer=4,
141
  cfg_norm=True,
142
  use_en_prompt=True,
143
- resolution=640,
144
- gpu_duration=1000,
145
- refine_enabled=False,
146
- refine_layer_index=1, # 1-based for UI convenience
147
- refine_sub_layers=3,
148
  ):
149
- # Seed
 
 
150
  if randomize_seed:
151
  seed = random.randint(0, MAX_SEED)
152
 
153
- # Normalize resolution input
154
- resolution = _clamp_int(resolution, default=640, lo=640, hi=1024)
 
 
 
 
 
155
  if resolution not in (640, 1024):
156
- resolution = 640
157
 
158
- # Normalize image input
159
- pil_image = _normalize_input_image(input_image)
160
 
161
- gen_device = "cuda" if torch.cuda.is_available() else "cpu"
162
- generator = torch.Generator(device=gen_device).manual_seed(seed)
163
 
164
- # First pass inputs
165
  inputs = {
166
- "image": pil_image,
167
- "generator": generator,
168
- "true_cfg_scale": true_guidance_scale,
169
- "prompt": prompt,
170
  "negative_prompt": neg_prompt,
171
- "num_inference_steps": num_inference_steps,
172
  "num_images_per_prompt": 1,
173
- "layers": layer,
174
- "resolution": resolution,
175
- "cfg_normalize": cfg_norm,
176
- "use_en_prompt": use_en_prompt,
177
  }
178
 
179
- print("INFER INPUTS:", inputs)
180
  print("REQUESTED GPU DURATION:", gpu_duration)
 
181
 
182
  with torch.inference_mode():
183
  out = pipeline(**inputs)
184
- output_layers = out.images[0] # list[PIL.Image]
185
 
186
- # Export first pass
187
- pptx_path, zip_path = _export_images_to_pptx_and_zip(output_layers, zip_prefix="layer")
188
-
189
- # Optional: Recursive (refine one layer into sub-layers) — no separate steps/resolution/cfg
190
  refined_gallery = []
191
- refined_pptx = None
192
- refined_zip = None
193
-
194
- if refine_enabled and len(output_layers) > 0:
195
- idx0 = _clamp_int(refine_layer_index, default=1, lo=1, hi=len(output_layers)) - 1
196
- refine_sub_layers = _clamp_int(refine_sub_layers, default=3, lo=2, hi=10)
197
-
198
- selected_layer = output_layers[idx0].convert("RGBA")
199
-
200
- refined_inputs = dict(inputs) # reuse same params
201
- refined_inputs["image"] = selected_layer
202
- refined_inputs["layers"] = refine_sub_layers
203
-
204
- print("REFINE ENABLED:", True)
205
- print("REFINE LAYER INDEX (1-based):", idx0 + 1)
206
- print("REFINE SUB-LAYERS:", refine_sub_layers)
207
- print("REFINED INPUTS:", {k: v for k, v in refined_inputs.items() if k != "image"})
208
 
209
- with torch.inference_mode():
210
- refined_out = pipeline(**refined_inputs)
211
- sub_layers = refined_out.images[0]
212
 
213
- refined_gallery = sub_layers
214
- refined_pptx, refined_zip = _export_images_to_pptx_and_zip(sub_layers, zip_prefix=f"sub_layer_{idx0+1}")
 
 
 
215
 
216
- return (
217
- output_layers,
218
- pptx_path,
219
- zip_path,
220
- refined_gallery,
221
- refined_pptx,
222
- refined_zip,
223
- )
224
 
225
 
226
  ensure_dirname(LOG_DIR)
227
-
228
- examples = [
229
- "assets/test_images/1.png",
230
- "assets/test_images/2.png",
231
- "assets/test_images/3.png",
232
- "assets/test_images/4.png",
233
- "assets/test_images/5.png",
234
- "assets/test_images/6.png",
235
- "assets/test_images/7.png",
236
- "assets/test_images/8.png",
237
- "assets/test_images/9.png",
238
- "assets/test_images/10.png",
239
- "assets/test_images/11.png",
240
- "assets/test_images/12.png",
241
- "assets/test_images/13.png",
242
- ]
243
 
244
  with gr.Blocks() as demo:
245
  with gr.Column(elem_id="col-container"):
@@ -249,10 +323,14 @@ with gr.Blocks() as demo:
249
  )
250
  gr.Markdown(
251
  """
252
- The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded (e.g., you may specify the text hidden behind a foreground object). It is not designed to control the semantic content of individual layers explicitly.
 
253
  """
254
  )
255
 
 
 
 
256
  with gr.Row():
257
  with gr.Column(scale=1):
258
  input_image = gr.Image(label="Input Image", image_mode="RGBA")
@@ -260,7 +338,7 @@ The text prompt is intended to describe the overall content of the input image
260
  with gr.Accordion("Advanced Settings", open=False):
261
  prompt = gr.Textbox(
262
  label="Prompt (Optional)",
263
- placeholder="Please enter the prompt to descibe the image. (Optional)",
264
  value="",
265
  lines=2,
266
  )
@@ -271,48 +349,27 @@ The text prompt is intended to describe the overall content of the input image
271
  lines=2,
272
  )
273
 
274
- seed = gr.Slider(
275
- label="Seed",
276
- minimum=0,
277
- maximum=MAX_SEED,
278
- step=1,
279
- value=0,
280
- )
281
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
282
 
283
  true_guidance_scale = gr.Slider(
284
- label="True guidance scale",
285
- minimum=1.0,
286
- maximum=10.0,
287
- step=0.1,
288
- value=4.0,
289
  )
290
 
291
  num_inference_steps = gr.Slider(
292
- label="Number of inference steps",
293
- minimum=1,
294
- maximum=100,
295
- step=1,
296
- value=50,
297
  )
298
 
299
- layer = gr.Slider(
300
- label="Layers",
301
- minimum=2,
302
- maximum=10,
303
- step=1,
304
- value=7,
305
- )
306
 
 
307
  resolution = gr.Radio(
308
  label="Processing resolution",
309
  choices=[640, 1024],
310
- value=640,
311
  )
312
 
313
- cfg_norm = gr.Checkbox(
314
- label="Whether enable CFG normalization", value=True
315
- )
316
  use_en_prompt = gr.Checkbox(
317
  label="Automatic caption language if no prompt provided, True for EN, False for ZH",
318
  value=True,
@@ -325,27 +382,24 @@ The text prompt is intended to describe the overall content of the input image
325
  placeholder="e.g. 60, 120, 300, 1000, 1500",
326
  )
327
 
328
- gr.Markdown("### Advanced: Recursive decomposition")
329
- refine_enabled = gr.Checkbox(
330
- label="Refine one layer into sub-layers",
331
- value=False,
332
- )
333
  refine_layer_index = gr.Slider(
334
- label="Refine layer index (1-based)",
335
  minimum=1,
336
- maximum=10,
337
  step=1,
338
  value=1,
339
  )
340
  refine_sub_layers = gr.Slider(
341
- label="Sub-layers (for refined layer)",
342
  minimum=2,
343
  maximum=10,
344
  step=1,
345
  value=3,
346
  )
347
-
348
- run_button = gr.Button("Decompose!", variant="primary")
349
 
350
  with gr.Column(scale=2):
351
  gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
@@ -353,12 +407,13 @@ The text prompt is intended to describe the overall content of the input image
353
  export_file = gr.File(label="Download PPTX")
354
  export_zip_file = gr.File(label="Download ZIP")
355
 
356
- gr.Markdown("### Refined sub-layers")
357
  refined_gallery = gr.Gallery(label="Sub-layers", columns=4, rows=1, format="png")
358
  with gr.Row():
359
  refined_export_file = gr.File(label="Download refined PPTX")
360
  refined_export_zip_file = gr.File(label="Download refined ZIP")
361
 
 
362
  gr.Examples(
363
  examples=examples,
364
  inputs=[input_image],
@@ -366,18 +421,21 @@ The text prompt is intended to describe the overall content of the input image
366
  gallery,
367
  export_file,
368
  export_zip_file,
 
369
  refined_gallery,
370
  refined_export_file,
371
  refined_export_zip_file,
 
372
  ],
373
- fn=infer,
374
  examples_per_page=14,
375
  cache_examples=False,
376
  run_on_click=True,
377
  )
378
 
379
- run_button.click(
380
- fn=infer,
 
381
  inputs=[
382
  input_image,
383
  seed,
@@ -391,18 +449,40 @@ The text prompt is intended to describe the overall content of the input image
391
  use_en_prompt,
392
  resolution,
393
  gpu_duration,
394
- refine_enabled,
395
- refine_layer_index,
396
- refine_sub_layers,
397
  ],
398
  outputs=[
399
  gallery,
400
  export_file,
401
  export_zip_file,
 
402
  refined_gallery,
403
  refined_export_file,
404
  refined_export_zip_file,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  ],
 
406
  )
407
 
408
  if __name__ == "__main__":
 
1
  import os
2
  import uuid
 
3
  import random
4
  import tempfile
5
  import zipfile
6
+ import numpy as np
7
 
8
  import spaces
9
  import torch
10
  import gradio as gr
11
 
12
  from PIL import Image
 
13
  from pptx import Presentation
14
+ from diffusers import QwenImageLayeredPipeline
15
 
16
  LOG_DIR = "/tmp/local"
17
  MAX_SEED = np.iinfo(np.int32).max
18
 
19
+ # Reduce allocator fragmentation (new name; old PYTORCH_CUDA_ALLOC_CONF is deprecated)
20
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
21
+
22
+ # Optional HF login (Spaces secret env var "hf")
23
  from huggingface_hub import login
24
  login(token=os.environ.get("hf"))
25
 
26
+ # ----------------------------
27
+ # Device / dtype (memory-safe)
28
+ # ----------------------------
29
+ has_cuda = torch.cuda.is_available()
30
+ device = "cuda" if has_cuda else ("mps" if torch.backends.mps.is_available() else "cpu")
31
+
32
+ # fp16 is typically best for VRAM; CPU uses fp32
33
+ torch_dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
34
 
35
+ # ----------------------------
36
+ # Load pipeline (avoid CPU RAM spikes)
37
+ # ----------------------------
38
  pipeline = QwenImageLayeredPipeline.from_pretrained(
39
+ "Qwen/Qwen-Image-Layered",
40
+ torch_dtype=torch_dtype,
41
+ low_cpu_mem_usage=True,
42
+ )
43
+
44
+ # Memory helpers (guarded)
45
+ if hasattr(pipeline, "enable_attention_slicing"):
46
+ pipeline.enable_attention_slicing()
47
+
48
+ # This pipeline may NOT expose enable_vae_slicing(), so guard both ways
49
+ if hasattr(pipeline, "enable_vae_slicing"):
50
+ pipeline.enable_vae_slicing()
51
+ elif hasattr(pipeline, "vae") and hasattr(pipeline.vae, "enable_slicing"):
52
+ pipeline.vae.enable_slicing()
53
+
54
+ if device == "cuda":
55
+ # Best for Spaces: keep CPU RAM lower and avoid huge peak VRAM at startup
56
+ # (requires accelerate, usually present in Spaces)
57
+ try:
58
+ pipeline.enable_model_cpu_offload()
59
+ except Exception:
60
+ pipeline.to("cuda")
61
+ elif device == "mps":
62
+ pipeline.to("mps")
63
+ else:
64
+ pipeline.to("cpu")
65
 
66
 
67
  def ensure_dirname(path: str):
 
69
  os.makedirs(path, exist_ok=True)
70
 
71
 
 
 
 
 
72
  def imagelist_to_pptx(img_files):
73
  with Image.open(img_files[0]) as img:
74
  img_width_px, img_height_px = img.size
75
 
76
  def px_to_emu(px, dpi=96):
77
  inch = px / dpi
78
+ return int(inch * 914400)
 
79
 
80
  prs = Presentation()
81
  prs.slide_width = px_to_emu(img_width_px)
82
  prs.slide_height = px_to_emu(img_height_px)
83
 
84
  slide = prs.slides.add_slide(prs.slide_layouts[6])
 
85
  left = top = 0
86
+
87
  for img_path in img_files:
88
  slide.shapes.add_picture(
89
  img_path,
 
106
  return max(lo, min(hi, v))
107
 
108
 
109
+ def _safe_open_rgba(img_like):
110
+ if isinstance(img_like, list):
111
+ img_like = img_like[0]
112
+ if isinstance(img_like, str):
113
+ return Image.open(img_like).convert("RGB").convert("RGBA")
114
+ if isinstance(img_like, Image.Image):
115
+ return img_like.convert("RGB").convert("RGBA")
116
+ if isinstance(img_like, np.ndarray):
117
+ return Image.fromarray(img_like).convert("RGB").convert("RGBA")
118
+ raise ValueError(f"Unsupported input_image type: {type(img_like)}")
119
+
120
+
121
+ def _update_refine_index_ui(n_layers: int, current_idx: int | None = None):
122
+ n_layers = max(1, int(n_layers))
123
+ if current_idx is None:
124
+ current_idx = 1
125
+ current_idx = max(1, min(int(current_idx), n_layers))
126
+ return gr.update(minimum=1, maximum=n_layers, value=current_idx)
127
+
128
+
129
+ # Dynamic duration callable: must accept same args as decompose() and refine()
130
  def get_duration(
131
  input_image,
132
+ seed=0,
133
+ randomize_seed=True,
134
+ prompt="",
135
  neg_prompt=" ",
136
  true_guidance_scale=4.0,
137
  num_inference_steps=50,
138
+ layer=7,
139
  cfg_norm=True,
140
  use_en_prompt=True,
141
+ resolution=1024,
142
+ gpu_duration="1000",
 
143
  refine_layer_index=1,
144
  refine_sub_layers=3,
145
  ):
146
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
147
 
148
 
149
+ @spaces.GPU(duration=get_duration)
150
+ def decompose(
151
+ input_image,
152
+ seed=0,
153
+ randomize_seed=True,
154
+ prompt="",
155
+ neg_prompt=" ",
156
+ true_guidance_scale=4.0,
157
+ num_inference_steps=50,
158
+ layer=7,
159
+ cfg_norm=True,
160
+ use_en_prompt=True,
161
+ resolution=1024,
162
+ gpu_duration="1000",
163
+ refine_layer_index=1, # passed in (so we can "clamp" it красиво)
164
+ refine_sub_layers=3, # unused here, but kept for duration signature parity
165
+ ):
166
+ if randomize_seed:
167
+ seed = random.randint(0, MAX_SEED)
168
 
169
+ resolution = _clamp_int(resolution, default=1024, lo=640, hi=1024)
170
+ if resolution not in (640, 1024):
171
+ resolution = 1024
 
 
 
172
 
173
+ pil_image = _safe_open_rgba(input_image)
174
 
175
+ # Generator on CPU works well with CPU offload too
176
+ gen = torch.Generator(device="cpu").manual_seed(seed)
177
 
178
+ inputs = {
179
+ "image": pil_image,
180
+ "generator": gen,
181
+ "true_cfg_scale": float(true_guidance_scale),
182
+ "prompt": prompt if prompt else None,
183
+ "negative_prompt": neg_prompt,
184
+ "num_inference_steps": int(num_inference_steps),
185
+ "num_images_per_prompt": 1,
186
+ "layers": int(layer),
187
+ "resolution": int(resolution),
188
+ "cfg_normalize": bool(cfg_norm),
189
+ "use_en_prompt": bool(use_en_prompt),
190
+ }
191
+
192
+ print("DECOMPOSE INPUTS:", {k: v for k, v in inputs.items() if k != "image"})
193
+ print("REQUESTED GPU DURATION:", gpu_duration)
194
+
195
+ with torch.inference_mode():
196
+ out = pipeline(**inputs)
197
+ output_images = out.images[0] # list[PIL.Image]
198
+
199
+ # Save layers for exports + for refine stage
200
+ layer_paths = []
201
+ gallery_out = []
202
+
203
+ for img in output_images:
204
+ gallery_out.append(img)
205
  tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
206
  img.save(tmp.name)
207
+ layer_paths.append(tmp.name)
208
 
209
+ pptx_path = imagelist_to_pptx(layer_paths)
210
 
211
  with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
212
  with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
213
+ for i, p in enumerate(layer_paths):
214
+ zipf.write(p, f"layer_{i+1}.png")
215
  zip_path = tmpzip.name
216
 
217
+ # Reset refined outputs on new decompose
218
+ refined_gallery = []
219
+ refined_pptx = None
220
+ refined_zip = None
221
+
222
+ # "совсем красиво": clamp current refine index to new [1..N]
223
+ refine_index_update = _update_refine_index_ui(len(layer_paths), refine_layer_index)
224
+
225
+ return (
226
+ gallery_out,
227
+ pptx_path,
228
+ zip_path,
229
+ layer_paths, # gr.State
230
+ refined_gallery,
231
+ refined_pptx,
232
+ refined_zip,
233
+ refine_index_update, # update refine slider bounds/value
234
+ )
235
 
236
 
237
  @spaces.GPU(duration=get_duration)
238
+ def refine_selected_layer(
239
+ layer_paths,
240
+ refine_layer_index=1,
241
+ refine_sub_layers=3,
242
+ seed=0,
243
+ randomize_seed=True,
244
+ prompt="",
245
  neg_prompt=" ",
246
  true_guidance_scale=4.0,
247
  num_inference_steps=50,
 
248
  cfg_norm=True,
249
  use_en_prompt=True,
250
+ resolution=1024,
251
+ gpu_duration="1000",
 
 
 
252
  ):
253
+ if not layer_paths:
254
+ return [], None, None
255
+
256
  if randomize_seed:
257
  seed = random.randint(0, MAX_SEED)
258
 
259
+ # Clamp index into existing layers
260
+ n = len(layer_paths)
261
+ idx = _clamp_int(refine_layer_index, default=1, lo=1, hi=n) - 1
262
+
263
+ sub_layers = _clamp_int(refine_sub_layers, default=3, lo=2, hi=10)
264
+
265
+ resolution = _clamp_int(resolution, default=1024, lo=640, hi=1024)
266
  if resolution not in (640, 1024):
267
+ resolution = 1024
268
 
269
+ selected_path = layer_paths[idx]
270
+ selected_layer_img = Image.open(selected_path).convert("RGBA")
271
 
272
+ gen = torch.Generator(device="cpu").manual_seed(seed)
 
273
 
 
274
  inputs = {
275
+ "image": selected_layer_img,
276
+ "generator": gen,
277
+ "true_cfg_scale": float(true_guidance_scale),
278
+ "prompt": prompt if prompt else None,
279
  "negative_prompt": neg_prompt,
280
+ "num_inference_steps": int(num_inference_steps),
281
  "num_images_per_prompt": 1,
282
+ "layers": int(sub_layers), # <-- ключевой параметр рекурсивной декомпозиции
283
+ "resolution": int(resolution), # тот же resolution (без отдельных опций для refine)
284
+ "cfg_normalize": bool(cfg_norm),
285
+ "use_en_prompt": bool(use_en_prompt),
286
  }
287
 
288
+ print("REFINE INPUTS:", {k: v for k, v in inputs.items() if k != "image"})
289
  print("REQUESTED GPU DURATION:", gpu_duration)
290
+ print(f"REFINE: base layer index={idx+1}/{n}, sub_layers={sub_layers}")
291
 
292
  with torch.inference_mode():
293
  out = pipeline(**inputs)
294
+ refined_images = out.images[0]
295
 
296
+ refined_paths = []
 
 
 
297
  refined_gallery = []
298
+ for img in refined_images:
299
+ refined_gallery.append(img)
300
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
301
+ img.save(tmp.name)
302
+ refined_paths.append(tmp.name)
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ refined_pptx = imagelist_to_pptx(refined_paths)
 
 
305
 
306
+ with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
307
+ with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
308
+ for i, p in enumerate(refined_paths):
309
+ zipf.write(p, f"sub_layer_{i+1}.png")
310
+ refined_zip = tmpzip.name
311
 
312
+ return refined_gallery, refined_pptx, refined_zip
 
 
 
 
 
 
 
313
 
314
 
315
  ensure_dirname(LOG_DIR)
316
+ examples = [f"assets/test_images/{i}.png" for i in range(1, 14)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
  with gr.Blocks() as demo:
319
  with gr.Column(elem_id="col-container"):
 
323
  )
324
  gr.Markdown(
325
  """
326
+ The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded.
327
+ It is not designed to control the semantic content of individual layers explicitly.
328
  """
329
  )
330
 
331
+ # State to store layer PNG paths from last Decompose
332
+ layer_paths_state = gr.State([])
333
+
334
  with gr.Row():
335
  with gr.Column(scale=1):
336
  input_image = gr.Image(label="Input Image", image_mode="RGBA")
 
338
  with gr.Accordion("Advanced Settings", open=False):
339
  prompt = gr.Textbox(
340
  label="Prompt (Optional)",
341
+ placeholder="Please enter the prompt to describe the image (optional)",
342
  value="",
343
  lines=2,
344
  )
 
349
  lines=2,
350
  )
351
 
352
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
 
 
 
 
 
 
353
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
354
 
355
  true_guidance_scale = gr.Slider(
356
+ label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=4.0
 
 
 
 
357
  )
358
 
359
  num_inference_steps = gr.Slider(
360
+ label="Number of inference steps", minimum=1, maximum=100, step=1, value=50
 
 
 
 
361
  )
362
 
363
+ layer = gr.Slider(label="Layers", minimum=2, maximum=10, step=1, value=7)
 
 
 
 
 
 
364
 
365
+ # default 1024 as you asked earlier
366
  resolution = gr.Radio(
367
  label="Processing resolution",
368
  choices=[640, 1024],
369
+ value=1024,
370
  )
371
 
372
+ cfg_norm = gr.Checkbox(label="Whether enable CFG normalization", value=True)
 
 
373
  use_en_prompt = gr.Checkbox(
374
  label="Automatic caption language if no prompt provided, True for EN, False for ZH",
375
  value=True,
 
382
  placeholder="e.g. 60, 120, 300, 1000, 1500",
383
  )
384
 
385
+ decompose_btn = gr.Button("Decompose!", variant="primary")
386
+
387
+ with gr.Accordion("Refine layer (Recursive Decomposition)", open=False):
 
 
388
  refine_layer_index = gr.Slider(
389
+ label="Refine layer index (1 = first layer)",
390
  minimum=1,
391
+ maximum=7,
392
  step=1,
393
  value=1,
394
  )
395
  refine_sub_layers = gr.Slider(
396
+ label="Sub-layers (how many to split selected layer into)",
397
  minimum=2,
398
  maximum=10,
399
  step=1,
400
  value=3,
401
  )
402
+ refine_btn = gr.Button("Refine selected layer", variant="secondary")
 
403
 
404
  with gr.Column(scale=2):
405
  gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
 
407
  export_file = gr.File(label="Download PPTX")
408
  export_zip_file = gr.File(label="Download ZIP")
409
 
410
+ gr.Markdown("### Refined (sub-layers)")
411
  refined_gallery = gr.Gallery(label="Sub-layers", columns=4, rows=1, format="png")
412
  with gr.Row():
413
  refined_export_file = gr.File(label="Download refined PPTX")
414
  refined_export_zip_file = gr.File(label="Download refined ZIP")
415
 
416
+ # Examples run Decompose
417
  gr.Examples(
418
  examples=examples,
419
  inputs=[input_image],
 
421
  gallery,
422
  export_file,
423
  export_zip_file,
424
+ layer_paths_state,
425
  refined_gallery,
426
  refined_export_file,
427
  refined_export_zip_file,
428
+ refine_layer_index, # update slider bounds/value
429
  ],
430
+ fn=decompose,
431
  examples_per_page=14,
432
  cache_examples=False,
433
  run_on_click=True,
434
  )
435
 
436
+ # Decompose button
437
+ decompose_btn.click(
438
+ fn=decompose,
439
  inputs=[
440
  input_image,
441
  seed,
 
449
  use_en_prompt,
450
  resolution,
451
  gpu_duration,
452
+ refine_layer_index, # so we can clamp nicely after new decomposition
453
+ refine_sub_layers, # for duration signature parity
 
454
  ],
455
  outputs=[
456
  gallery,
457
  export_file,
458
  export_zip_file,
459
+ layer_paths_state,
460
  refined_gallery,
461
  refined_export_file,
462
  refined_export_zip_file,
463
+ refine_layer_index, # update slider bounds/value
464
+ ],
465
+ )
466
+
467
+ # Refine button
468
+ refine_btn.click(
469
+ fn=refine_selected_layer,
470
+ inputs=[
471
+ layer_paths_state,
472
+ refine_layer_index,
473
+ refine_sub_layers,
474
+ seed,
475
+ randomize_seed,
476
+ prompt,
477
+ neg_prompt,
478
+ true_guidance_scale,
479
+ num_inference_steps,
480
+ cfg_norm,
481
+ use_en_prompt,
482
+ resolution,
483
+ gpu_duration,
484
  ],
485
+ outputs=[refined_gallery, refined_export_file, refined_export_zip_file],
486
  )
487
 
488
  if __name__ == "__main__":