hexware commited on
Commit
bd98a2d
·
verified ·
1 Parent(s): 230a765

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1046 -671
app.py CHANGED
@@ -1,13 +1,15 @@
1
  import os
2
- import uuid
 
3
  import json
4
  import time
5
- import shutil
6
  import numpy as np
7
  import random
8
  import tempfile
9
  import zipfile
10
- from typing import Any, Dict, List, Optional, Tuple
 
11
 
12
  import spaces
13
  import torch
@@ -17,14 +19,23 @@ from PIL import Image
17
  from diffusers import QwenImageLayeredPipeline
18
  from pptx import Presentation
19
 
 
 
 
20
  LOG_DIR = "/tmp/local"
21
  MAX_SEED = np.iinfo(np.int32).max
22
 
23
  # Optional HF login (works in Spaces if you set HF token as secret env var "hf")
24
- from huggingface_hub import login, HfApi, hf_hub_download
25
-
26
  login(token=os.environ.get("hf"))
27
 
 
 
 
 
 
 
 
 
28
  dtype = torch.bfloat16
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
@@ -32,56 +43,26 @@ pipeline = QwenImageLayeredPipeline.from_pretrained(
32
  "Qwen/Qwen-Image-Layered", torch_dtype=dtype
33
  ).to(device)
34
 
35
- # ----------------------------
36
- # Dataset repo persistence (no /data needed)
37
- # ----------------------------
38
- HF_TOKEN = os.environ.get("hf") # secret
39
- HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO") # e.g. "hexware/qwen-layered-sessions"
40
- _hf_api: Optional[HfApi] = None
41
- _persist_enabled = False
42
-
43
- def _init_dataset_repo() -> Tuple[bool, str]:
44
- """
45
- Returns (enabled, message)
46
- """
47
- global _hf_api, _persist_enabled
48
- if not HF_TOKEN:
49
- _persist_enabled = False
50
- return False, "Persistence: disabled (no secret env var 'hf')."
51
- if not HF_DATASET_REPO:
52
- _persist_enabled = False
53
- return False, "Persistence: disabled (set env var HF_DATASET_REPO to enable)."
54
 
55
- try:
56
- _hf_api = HfApi(token=HF_TOKEN)
57
- # Create dataset repo if missing (private). If exists, this is no-op.
58
- # NOTE: create_repo is available via HfApi in most versions.
59
- _hf_api.create_repo(
60
- repo_id=HF_DATASET_REPO,
61
- repo_type="dataset",
62
- private=True,
63
- exist_ok=True,
64
- )
65
- _persist_enabled = True
66
- return True, f"Persistence: enabled (dataset repo: {HF_DATASET_REPO})."
67
- except Exception as e:
68
- _persist_enabled = False
69
- return False, f"Persistence: failed to init dataset repo: {type(e).__name__}: {e}"
70
 
71
- _enabled, _enabled_msg = _init_dataset_repo()
72
 
73
- # ----------------------------
74
- # Helpers
75
- # ----------------------------
76
  def ensure_dirname(path: str):
77
  if path and not os.path.exists(path):
78
  os.makedirs(path, exist_ok=True)
79
 
 
80
  def random_str(length=8):
81
  return uuid.uuid4().hex[:length]
82
 
83
- def _now_ts() -> float:
84
- return time.time()
 
 
85
 
86
  def _clamp_int(x, default: int, lo: int, hi: int) -> int:
87
  try:
@@ -90,7 +71,9 @@ def _clamp_int(x, default: int, lo: int, hi: int) -> int:
90
  v = default
91
  return max(lo, min(hi, v))
92
 
93
- def _pil_rgba(input_image) -> Image.Image:
 
 
94
  if isinstance(input_image, list):
95
  input_image = input_image[0]
96
 
@@ -102,9 +85,11 @@ def _pil_rgba(input_image) -> Image.Image:
102
  pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA")
103
  else:
104
  raise ValueError(f"Unsupported input_image type: {type(input_image)}")
 
105
  return pil_image
106
 
107
- def imagelist_to_pptx(img_files: List[str]) -> str:
 
108
  with Image.open(img_files[0]) as img:
109
  img_width_px, img_height_px = img.size
110
 
@@ -133,29 +118,288 @@ def imagelist_to_pptx(img_files: List[str]) -> str:
133
  prs.save(tmp.name)
134
  return tmp.name
135
 
136
- def export_node_layers(layers: List[Image.Image]) -> Tuple[str, str]:
137
- """
138
- Returns (pptx_path, zip_path)
139
- """
140
- temp_files: List[str] = []
141
  for img in layers:
142
  tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
143
  img.save(tmp.name)
144
  temp_files.append(tmp.name)
 
145
 
146
- pptx_path = imagelist_to_pptx(temp_files)
147
 
 
148
  with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
149
  with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
150
- for i, img_path in enumerate(temp_files):
151
  zipf.write(img_path, f"layer_{i+1}.png")
152
- zip_path = tmpzip.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- return pptx_path, zip_path
155
 
156
- # ----------------------------
157
- # ZeroGPU duration
158
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def get_duration(
160
  input_image,
161
  seed=777,
@@ -172,24 +416,22 @@ def get_duration(
172
  ):
173
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
174
 
175
- # ----------------------------
176
- # GPU pipeline runners
177
- # ----------------------------
178
  @spaces.GPU(duration=get_duration)
179
  def gpu_run_pipeline(
180
- input_pil_image: Image.Image,
181
- seed=777,
182
- randomize_seed=False,
183
- prompt=None,
184
- neg_prompt=" ",
185
- true_guidance_scale=4.0,
186
- num_inference_steps=50,
187
- layer=4,
188
- cfg_norm=True,
189
- use_en_prompt=True,
190
- resolution=640,
191
- gpu_duration=1000,
192
- ) -> List[Image.Image]:
193
  # Seed
194
  if randomize_seed:
195
  seed = random.randint(0, MAX_SEED)
@@ -202,7 +444,7 @@ def gpu_run_pipeline(
202
  gen_device = "cuda" if torch.cuda.is_available() else "cpu"
203
 
204
  inputs = {
205
- "image": input_pil_image,
206
  "generator": torch.Generator(device=gen_device).manual_seed(seed),
207
  "true_cfg_scale": true_guidance_scale,
208
  "prompt": prompt,
@@ -215,510 +457,553 @@ def gpu_run_pipeline(
215
  "use_en_prompt": use_en_prompt,
216
  }
217
 
218
- with torch.inference_mode():
219
- out = pipeline(**inputs)
220
 
221
- # out.images[0] => list of PIL images
222
- return out.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- # ----------------------------
225
- # Session / History model
226
- # ----------------------------
227
- def new_state() -> Dict[str, Any]:
228
- sid = uuid.uuid4().hex
229
- return {
230
- "session_id": sid,
231
- "nodes": {}, # node_id -> node dict
232
- "root_id": None,
233
- "current_id": None,
234
- "selected_layer_idx": 0,
235
- "last_refined_id": None,
236
- }
237
 
238
- def _node_label(node: Dict[str, Any]) -> str:
239
- name = node.get("name") or node["id"][:8]
240
- return f"{name} ({node['id'][:8]})"
241
-
242
- def _build_history_choices(st: Dict[str, Any]) -> List[Tuple[str, str]]:
243
- # returns list of (label, value=node_id)
244
- out = []
245
- for nid, node in st["nodes"].items():
246
- out.append((_node_label(node), nid))
247
- # stable order by created
248
- out.sort(key=lambda x: st["nodes"][x[1]].get("created_at", 0.0))
249
- return out
250
-
251
- def _get_node(st: Dict[str, Any], node_id: Optional[str]) -> Optional[Dict[str, Any]]:
252
- if not node_id:
253
- return None
254
- return st["nodes"].get(node_id)
255
-
256
- def _current_node(st: Dict[str, Any]) -> Optional[Dict[str, Any]]:
257
- return _get_node(st, st.get("current_id"))
258
-
259
- def _chips_text(st: Dict[str, Any], node_id: Optional[str]) -> str:
260
- node = _get_node(st, node_id)
261
- if not node:
262
- return ""
263
- chips = []
264
- if node_id == st.get("root_id"):
265
- chips.append("[root]")
266
- if node.get("parent_id"):
267
- chips.append("[parent]")
268
- children = node.get("children_ids") or []
269
- if children:
270
- chips.append(f"[children:{len(children)}]")
271
- return " ".join(chips)
272
 
273
- def _make_node(
274
- st: Dict[str, Any],
 
 
 
275
  layers: List[Image.Image],
 
276
  parent_id: Optional[str],
277
- name: Optional[str] = None,
278
- ) -> str:
279
- nid = uuid.uuid4().hex
280
- node = {
281
- "id": nid,
282
- "name": name or ("root" if parent_id is None else "refine"),
283
- "parent_id": parent_id,
284
- "children_ids": [],
285
- "created_at": _now_ts(),
286
- "layers": layers,
287
- }
288
- st["nodes"][nid] = node
289
- if parent_id:
290
- parent = st["nodes"].get(parent_id)
291
- if parent is not None:
292
- parent.setdefault("children_ids", []).append(nid)
293
- return nid
294
-
295
- def _set_current(st: Dict[str, Any], node_id: str):
296
- st["current_id"] = node_id
297
- st["selected_layer_idx"] = 0
298
-
299
- # ----------------------------
300
- # Persistence: save/load whole session as one zip in dataset repo
301
- # ----------------------------
302
- def _serialize_session_to_zip(st: Dict[str, Any]) -> str:
303
- """
304
- Create a zip file with:
305
- session.json
306
- nodes/<node_id>/layer_1.png ...
307
- Returns local zip path.
308
- """
309
- tmpdir = tempfile.mkdtemp(prefix="sess_")
310
- try:
311
- sess_meta = {
312
- "session_id": st["session_id"],
313
- "root_id": st["root_id"],
314
- "current_id": st["current_id"],
315
- "selected_layer_idx": st.get("selected_layer_idx", 0),
316
- "last_refined_id": st.get("last_refined_id"),
317
- "nodes": {},
318
- }
319
-
320
- for nid, node in st["nodes"].items():
321
- node_dir = os.path.join(tmpdir, "nodes", nid)
322
- os.makedirs(node_dir, exist_ok=True)
323
- layers: List[Image.Image] = node.get("layers") or []
324
- for i, img in enumerate(layers):
325
- img_path = os.path.join(node_dir, f"layer_{i+1}.png")
326
- img.save(img_path)
327
- sess_meta["nodes"][nid] = {
328
- "id": nid,
329
- "name": node.get("name"),
330
- "parent_id": node.get("parent_id"),
331
- "children_ids": node.get("children_ids") or [],
332
- "created_at": node.get("created_at", 0.0),
333
- "layer_count": len(layers),
334
- }
335
-
336
- meta_path = os.path.join(tmpdir, "session.json")
337
- with open(meta_path, "w", encoding="utf-8") as f:
338
- json.dump(sess_meta, f, ensure_ascii=False, indent=2)
339
-
340
- out_zip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
341
- with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
342
- for root, _, files in os.walk(tmpdir):
343
- for fn in files:
344
- abs_path = os.path.join(root, fn)
345
- rel_path = os.path.relpath(abs_path, tmpdir)
346
- zf.write(abs_path, rel_path)
347
- return out_zip
348
- finally:
349
- shutil.rmtree(tmpdir, ignore_errors=True)
350
-
351
- def _deserialize_session_from_zip(zip_path: str) -> Dict[str, Any]:
352
- tmpdir = tempfile.mkdtemp(prefix="sess_load_")
353
- try:
354
- with zipfile.ZipFile(zip_path, "r") as zf:
355
- zf.extractall(tmpdir)
356
-
357
- meta_path = os.path.join(tmpdir, "session.json")
358
- with open(meta_path, "r", encoding="utf-8") as f:
359
- meta = json.load(f)
360
-
361
- st = new_state()
362
- st["session_id"] = meta["session_id"]
363
- st["root_id"] = meta.get("root_id")
364
- st["current_id"] = meta.get("current_id")
365
- st["selected_layer_idx"] = meta.get("selected_layer_idx", 0)
366
- st["last_refined_id"] = meta.get("last_refined_id")
367
-
368
- nodes_meta: Dict[str, Any] = meta.get("nodes", {})
369
- # First pass: create node shells
370
- for nid, nm in nodes_meta.items():
371
- st["nodes"][nid] = {
372
- "id": nid,
373
- "name": nm.get("name"),
374
- "parent_id": nm.get("parent_id"),
375
- "children_ids": nm.get("children_ids") or [],
376
- "created_at": nm.get("created_at", 0.0),
377
- "layers": [],
378
- }
379
- # Second pass: load layers images
380
- for nid, nm in nodes_meta.items():
381
- layer_count = int(nm.get("layer_count", 0))
382
- node_dir = os.path.join(tmpdir, "nodes", nid)
383
- layers: List[Image.Image] = []
384
- for i in range(layer_count):
385
- p = os.path.join(node_dir, f"layer_{i+1}.png")
386
- if os.path.exists(p):
387
- layers.append(Image.open(p).convert("RGBA"))
388
- st["nodes"][nid]["layers"] = layers
389
-
390
- return st
391
- finally:
392
- shutil.rmtree(tmpdir, ignore_errors=True)
393
-
394
- def save_session_to_hub(st: Dict[str, Any]) -> Tuple[str, str]:
395
- """
396
- Returns (status_text, session_id)
397
- """
398
- if not _persist_enabled or _hf_api is None:
399
- return "Save: disabled (set HF_DATASET_REPO and secret hf write token).", st.get("session_id", "")
400
- try:
401
- zip_path = _serialize_session_to_zip(st)
402
- path_in_repo = f"sessions/{st['session_id']}.zip"
403
- _hf_api.upload_file(
404
- path_or_fileobj=zip_path,
405
- path_in_repo=path_in_repo,
406
- repo_id=HF_DATASET_REPO,
407
- repo_type="dataset",
408
- commit_message=f"Save session {st['session_id']}",
409
- )
410
- return f"Saved to dataset repo: {path_in_repo}", st["session_id"]
411
- except Exception as e:
412
- return f"Save failed: {type(e).__name__}: {e}", st.get("session_id", "")
413
- finally:
414
- try:
415
- if "zip_path" in locals() and os.path.exists(zip_path):
416
- os.remove(zip_path)
417
- except Exception:
418
- pass
419
-
420
- def load_session_from_hub(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
421
- if not _persist_enabled:
422
- return None, "Load: disabled (set HF_DATASET_REPO and secret hf write token)."
423
- session_id = (session_id or "").strip()
424
- if not session_id:
425
- return None, "Load: please enter a Session ID."
426
- try:
427
- filename = f"sessions/{session_id}.zip"
428
- local_zip = hf_hub_download(
429
- repo_id=HF_DATASET_REPO,
430
- repo_type="dataset",
431
- filename=filename,
432
- token=HF_TOKEN,
433
  )
434
- st = _deserialize_session_from_zip(local_zip)
435
- return st, f"Loaded session: {session_id}"
436
- except Exception as e:
437
- return None, f"Load failed: {type(e).__name__}: {e}"
438
-
439
- # ----------------------------
440
- # UI Callbacks
441
- # ----------------------------
442
- def ui_boot() -> Tuple[str, Dict[str, Any]]:
443
- ensure_dirname(LOG_DIR)
444
- st = new_state()
445
- return _enabled_msg, st
446
 
447
- def on_new_session(st: Dict[str, Any]) -> Tuple[Dict[str, Any], str, gr.Dropdown, List[Image.Image], List[Image.Image], str, str, str, Optional[str], Optional[str]]:
448
- st = new_state()
 
 
 
 
 
449
  return (
450
- st,
451
- st["session_id"],
452
- gr.Dropdown(choices=[], value=None),
 
 
 
 
 
453
  [],
454
- [],
455
- "",
456
- "",
457
- "",
458
- None,
459
- None,
460
  )
461
 
462
- def _render_from_state(st: Dict[str, Any]) -> Tuple[
463
- gr.Dropdown,
464
- List[Image.Image],
465
- List[Image.Image],
466
- gr.Number,
467
- str
468
- ]:
469
- choices = _build_history_choices(st)
470
- current = _current_node(st)
471
- layers = current["layers"] if current else []
472
- idx = st.get("selected_layer_idx", 0)
473
- if layers:
474
- idx = max(0, min(idx, len(layers) - 1))
475
- st["selected_layer_idx"] = idx
476
- chips = _chips_text(st, st.get("current_id"))
477
  return (
478
- gr.Dropdown(choices=choices, value=st.get("current_id")),
479
- layers,
480
- layers, # mini gallery mirrors current layers
481
- idx,
482
- chips,
483
  )
484
 
485
- def on_history_select(node_id: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
486
- if node_id and node_id in st["nodes"]:
487
- st["current_id"] = node_id
488
- st["selected_layer_idx"] = 0
489
- dd, layers, mini, idx, chips = _render_from_state(st)
490
- return st, dd, layers, mini, idx, chips
491
 
492
- def on_layer_gallery_select(evt: gr.SelectData, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Number]:
493
- # evt.index is int for Gallery
494
- idx = int(evt.index) if evt and evt.index is not None else 0
495
- current = _current_node(st)
496
- if current:
497
- layers = current.get("layers") or []
498
- if layers:
499
- idx = max(0, min(idx, len(layers) - 1))
500
- else:
501
- idx = 0
502
- else:
503
- idx = 0
504
- st["selected_layer_idx"] = idx
505
- return st, idx
506
-
507
- def on_back_to_parent(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
508
- cur = _current_node(st)
509
- if cur and cur.get("parent_id"):
510
- st["current_id"] = cur["parent_id"]
511
- st["selected_layer_idx"] = 0
512
- dd, layers, mini, idx, chips = _render_from_state(st)
513
- return st, dd, layers, mini, idx, chips
514
-
515
- def on_duplicate_node(st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown, List[Image.Image], List[Image.Image], gr.Number, str]:
516
- cur = _current_node(st)
517
- if cur:
518
- # Duplicate current node as sibling (same parent)
519
- layers = cur.get("layers") or []
520
- parent_id = cur.get("parent_id")
521
- name = (cur.get("name") or "node") + " copy"
522
- new_id = _make_node(st, layers=layers, parent_id=parent_id, name=name)
523
- _set_current(st, new_id)
524
- if st.get("root_id") is None and parent_id is None:
525
- st["root_id"] = new_id
526
- dd, layers, mini, idx, chips = _render_from_state(st)
527
- return st, dd, layers, mini, idx, chips
528
-
529
- def on_rename_node(new_name: str, st: Dict[str, Any]) -> Tuple[Dict[str, Any], gr.Dropdown]:
530
- cur = _current_node(st)
531
- if cur:
532
- nn = (new_name or "").strip()
533
- if nn:
534
- cur["name"] = nn
535
- dd, _, _, _, _ = _render_from_state(st)
536
- return st, dd
537
-
538
- def on_export_selected(st: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
539
- cur = _current_node(st)
540
- if not cur:
541
- return None, None
542
- pptx_path, zip_path = export_node_layers(cur.get("layers") or [])
543
- return pptx_path, zip_path
544
-
545
- def on_save_session(st: Dict[str, Any]) -> Tuple[str, str]:
546
- status, sid = save_session_to_hub(st)
547
- return status, sid
548
-
549
- def on_load_session(session_id: str, st: Dict[str, Any]) -> Tuple[
550
- Dict[str, Any],
551
- str,
552
- gr.Dropdown,
553
- List[Image.Image],
554
- List[Image.Image],
555
- gr.Number,
556
- str
557
- ]:
558
- loaded, msg = load_session_from_hub(session_id)
559
- if loaded is None:
560
- dd, layers, mini, idx, chips = _render_from_state(st)
561
- return st, msg, dd, layers, mini, idx, chips
562
-
563
- st = loaded
564
- dd, layers, mini, idx, chips = _render_from_state(st)
565
- return st, msg, dd, layers, mini, idx, chips
566
-
567
- # GPU click handlers
568
  def on_decompose_click(
569
  input_image,
570
- seed=777,
571
- randomize_seed=False,
572
- prompt=None,
573
- neg_prompt=" ",
574
- true_guidance_scale=4.0,
575
- num_inference_steps=50,
576
- layer=4,
577
- cfg_norm=True,
578
- use_en_prompt=True,
579
- resolution=640,
580
- gpu_duration=1000,
581
- st: Optional[Dict[str, Any]] = None,
582
  ):
583
- if st is None:
584
- st = new_state()
585
-
586
- pil_image = _pil_rgba(input_image)
587
- layers_out = gpu_run_pipeline(
588
- pil_image,
589
- seed=seed,
590
- randomize_seed=randomize_seed,
 
591
  prompt=prompt,
592
  neg_prompt=neg_prompt,
593
- true_guidance_scale=true_guidance_scale,
594
- num_inference_steps=num_inference_steps,
595
- layer=layer,
596
- cfg_norm=cfg_norm,
597
- use_en_prompt=use_en_prompt,
598
- resolution=resolution,
599
- gpu_duration=gpu_duration,
600
  )
601
 
602
- # Reset session tree on Decompose (new root)
603
- sid = st.get("session_id") or uuid.uuid4().hex
604
- st = new_state()
605
- st["session_id"] = sid
 
 
 
 
 
 
 
 
 
606
 
607
- root_id = _make_node(st, layers=layers_out, parent_id=None, name="root")
608
- st["root_id"] = root_id
609
- _set_current(st, root_id)
 
 
 
 
 
610
 
611
- dd, layers, mini, idx, chips = _render_from_state(st)
 
 
 
 
612
  return (
613
- st,
614
- dd,
615
- layers,
616
- mini,
617
- idx,
618
  chips,
619
- gr.update(open=False), # refined accordion closed
620
- [], # refined gallery cleared
621
- None,
622
- None,
 
623
  )
624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  def on_refine_click(
626
- sub_layers_count: int,
627
- seed=777,
628
- randomize_seed=False,
629
- prompt=None,
630
- neg_prompt=" ",
631
- true_guidance_scale=4.0,
632
- num_inference_steps=50,
633
- cfg_norm=True,
634
- use_en_prompt=True,
635
- resolution=640,
636
- gpu_duration=1000,
637
- st: Optional[Dict[str, Any]] = None,
638
  ):
639
- if st is None:
640
- st = new_state()
641
-
642
- cur = _current_node(st)
643
- if not cur:
644
- dd, layers, mini, idx, chips = _render_from_state(st)
645
  return (
646
- st,
647
- dd,
648
- layers,
649
- mini,
650
- idx,
651
- chips,
652
- "Refine: no current node.",
653
- gr.update(open=False),
654
  [],
 
 
655
  None,
656
  None,
657
  )
658
 
659
- layers_list: List[Image.Image] = cur.get("layers") or []
660
- if not layers_list:
661
- dd, layers, mini, idx, chips = _render_from_state(st)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  return (
663
- st,
664
- dd,
665
- layers,
666
- mini,
667
- idx,
668
- chips,
669
- "Refine: current node has no layers.",
670
- gr.update(open=False),
671
  [],
 
 
672
  None,
673
  None,
674
  )
675
 
676
- idx = int(st.get("selected_layer_idx", 0))
677
- idx = max(0, min(idx, len(layers_list) - 1))
678
- selected_layer = layers_list[idx].convert("RGBA")
 
 
 
 
 
 
 
 
679
 
680
- # Run pipeline again on selected layer, producing sub-layers
681
- sub_layers_count = _clamp_int(sub_layers_count, default=3, lo=2, hi=10)
 
682
 
683
- sub_layers = gpu_run_pipeline(
684
- selected_layer,
685
- seed=seed,
686
- randomize_seed=randomize_seed,
687
  prompt=prompt,
688
  neg_prompt=neg_prompt,
689
- true_guidance_scale=true_guidance_scale,
690
- num_inference_steps=num_inference_steps,
691
- layer=sub_layers_count, # <-- only change: layers = sub_layers_count
692
- cfg_norm=cfg_norm,
693
- use_en_prompt=use_en_prompt,
694
- resolution=resolution,
695
- gpu_duration=gpu_duration,
696
  )
697
 
698
- new_id = _make_node(st, layers=sub_layers, parent_id=cur["id"], name=f"refine L{idx+1}")
699
- _set_current(st, new_id)
700
- st["last_refined_id"] = new_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
- dd, layers, mini, idx2, chips = _render_from_state(st)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
 
704
- # Export files for current node on-demand (not automatic)
705
  return (
706
- st,
707
- dd,
708
- layers,
709
- mini,
710
- idx2,
711
  chips,
712
- f"Refined: created node {_node_label(st['nodes'][new_id])}",
713
- gr.update(open=True), # open refined accordion
714
- sub_layers, # show refined layers
715
- None,
716
- None,
717
  )
718
 
719
- # ----------------------------
720
- # App UI
721
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
  ensure_dirname(LOG_DIR)
723
 
724
  examples = [
@@ -737,31 +1022,18 @@ examples = [
737
  "assets/test_images/13.png",
738
  ]
739
 
 
740
  with gr.Blocks() as demo:
741
- st = gr.State(value=new_state())
742
 
743
  with gr.Column(elem_id="col-container"):
744
  gr.HTML(
745
  '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
746
  'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
747
  )
748
-
749
- persist_status = gr.Markdown(_enabled_msg)
750
-
751
- with gr.Row():
752
- new_session_btn = gr.Button("New session", variant="secondary")
753
- session_id_box = gr.Textbox(label="Session ID", value="", interactive=False)
754
- save_btn = gr.Button("Save session to Dataset repo", variant="primary")
755
- save_status = gr.Textbox(label="Save/Load status", value="", interactive=False)
756
-
757
- with gr.Row():
758
- load_session_id = gr.Textbox(label="Load Session ID", value="", placeholder="paste Session ID here")
759
- load_btn = gr.Button("Load", variant="secondary")
760
-
761
  gr.Markdown(
762
  """
763
- The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded.
764
- It is not designed to control the semantic content of individual layers explicitly.
765
  """
766
  )
767
 
@@ -837,80 +1109,105 @@ It is not designed to control the semantic content of individual layers explicit
837
  placeholder="e.g. 60, 120, 300, 1000, 1500",
838
  )
839
 
840
- run_button = gr.Button("Decompose!", variant="primary")
 
 
841
 
842
- gr.Markdown("### History")
843
- history_dd = gr.Dropdown(
844
- label="Nodes",
845
- choices=[],
846
- value=None,
847
- interactive=True,
848
- )
849
- chips_md = gr.Markdown("")
850
 
851
- with gr.Row():
852
- back_btn = gr.Button("← back to parent")
853
- dup_btn = gr.Button("Duplicate node (branch)")
854
 
855
- with gr.Row():
856
- rename_inp = gr.Textbox(label="Branch name", value="", placeholder="rename current node")
857
- rename_btn = gr.Button("Rename")
 
 
 
 
 
858
 
859
- with gr.Row():
860
- export_btn = gr.Button("Export selected node")
861
- export_file = gr.File(label="Download PPTX")
862
- export_zip_file = gr.File(label="Download ZIP")
 
 
 
863
 
864
- with gr.Column(scale=2):
865
- gr.Markdown("### Layers (current node)")
866
- gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
867
 
868
- with gr.Accordion("Layer picker (mini, click like Photoshop)", open=True):
869
- mini_gallery = gr.Gallery(label="Pick layer to refine", columns=7, rows=1, format="png")
870
- selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, interactive=False)
871
 
872
- with gr.Accordion("Refine selected layer", open=True):
873
- refine_info = gr.Textbox(label="Refine status", value="", interactive=False)
 
 
 
 
 
 
874
  with gr.Row():
875
- sub_layers_count = gr.Slider(
876
- label="Sub-layers (refine)",
877
- minimum=2,
878
- maximum=10,
879
- step=1,
880
- value=3,
881
- )
882
- refine_btn = gr.Button("Refine selected layer", variant="primary")
 
 
 
 
883
 
884
- refined_acc = gr.Accordion("Refined layers (latest)", open=False)
885
- with refined_acc:
886
- refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
 
 
 
 
 
887
 
888
- # Examples run Decompose
 
 
 
 
 
 
 
 
889
  gr.Examples(
890
  examples=examples,
891
  inputs=[input_image],
892
  outputs=[gallery, export_file, export_zip_file],
893
- fn=lambda img: ([], None, None),
 
894
  cache_examples=False,
895
  run_on_click=False,
896
  )
897
 
898
- # Boot / init
899
- demo.load(
900
- fn=ui_boot,
901
  inputs=[],
902
- outputs=[persist_status, st],
903
- ).then(
904
- fn=lambda st: st.get("session_id", ""),
905
- inputs=[st],
906
- outputs=[session_id_box],
907
  )
908
 
909
- # New session
910
- new_session_btn.click(
911
- fn=on_new_session,
912
- inputs=[st],
913
- outputs=[st, session_id_box, history_dd, gallery, mini_gallery, chips_md, refine_info, save_status, export_file, export_zip_file],
914
  )
915
 
916
  # Decompose
@@ -929,70 +1226,69 @@ It is not designed to control the semantic content of individual layers explicit
929
  use_en_prompt,
930
  resolution,
931
  gpu_duration,
932
- st,
933
  ],
934
  outputs=[
935
- st,
936
- history_dd,
937
  gallery,
938
- mini_gallery,
939
- selected_layer_idx,
940
- chips_md,
941
- refined_acc,
942
- refined_gallery,
943
  export_file,
944
  export_zip_file,
 
 
 
945
  ],
946
- ).then(
947
- fn=lambda st: st.get("session_id", ""),
948
- inputs=[st],
949
- outputs=[session_id_box],
950
  )
951
 
952
- # History selection
953
- history_dd.change(
954
- fn=on_history_select,
955
- inputs=[history_dd, st],
956
- outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
957
- )
958
-
959
- # Mini gallery click -> choose layer index
960
- mini_gallery.select(
961
- fn=on_layer_gallery_select,
962
- inputs=[st],
963
- outputs=[st, selected_layer_idx],
964
  )
965
 
966
- # Back to parent
967
- back_btn.click(
968
- fn=on_back_to_parent,
969
- inputs=[st],
970
- outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
971
  )
972
 
973
- # Duplicate node
974
- dup_btn.click(
975
- fn=on_duplicate_node,
976
- inputs=[st],
977
- outputs=[st, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
978
- )
979
-
980
- # Rename node
981
- rename_btn.click(
982
- fn=on_rename_node,
983
- inputs=[rename_inp, st],
984
- outputs=[st, history_dd],
985
- ).then(
986
- fn=lambda: "",
987
- inputs=[],
988
- outputs=[rename_inp],
 
 
 
 
 
 
 
 
 
 
989
  )
990
 
991
- # Refine selected layer
992
- refine_btn.click(
993
- fn=on_refine_click,
994
  inputs=[
995
- sub_layers_count,
996
  seed,
997
  randomize_seed,
998
  prompt,
@@ -1003,51 +1299,130 @@ It is not designed to control the semantic content of individual layers explicit
1003
  use_en_prompt,
1004
  resolution,
1005
  gpu_duration,
1006
- st,
 
 
 
 
 
 
 
 
 
1007
  ],
 
 
 
 
 
 
1008
  outputs=[
1009
- st,
1010
- history_dd,
1011
  gallery,
1012
- mini_gallery,
1013
- selected_layer_idx,
1014
- chips_md,
1015
- refine_info,
1016
- refined_acc,
 
 
1017
  refined_gallery,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1018
  export_file,
1019
  export_zip_file,
 
 
 
1020
  ],
1021
  )
1022
 
1023
- # Export current node
1024
- export_btn.click(
1025
- fn=on_export_selected,
1026
- inputs=[st],
1027
- outputs=[export_file, export_zip_file],
 
 
 
 
 
 
 
 
 
 
 
 
1028
  )
1029
 
1030
- # Save session
1031
  save_btn.click(
1032
- fn=on_save_session,
1033
- inputs=[st],
1034
- outputs=[save_status, session_id_box],
 
 
 
 
 
 
 
1035
  )
1036
 
1037
  # Load session
1038
- load_btn.click(
1039
  fn=on_load_session,
1040
- inputs=[load_session_id, st],
1041
- outputs=[st, save_status, history_dd, gallery, mini_gallery, selected_layer_idx, chips_md],
1042
- ).then(
1043
- fn=lambda st: st.get("session_id", ""),
1044
- inputs=[st],
1045
- outputs=[session_id_box],
1046
- ).then(
1047
- fn=lambda: "",
1048
- inputs=[],
1049
- outputs=[load_session_id],
 
 
 
 
1050
  )
1051
 
 
 
 
1052
  if __name__ == "__main__":
1053
  demo.launch()
 
1
  import os
2
+ import io
3
+ import gc
4
  import json
5
  import time
6
+ import uuid
7
  import numpy as np
8
  import random
9
  import tempfile
10
  import zipfile
11
+ from dataclasses import dataclass
12
+ from typing import Dict, Any, List, Optional, Tuple
13
 
14
  import spaces
15
  import torch
 
19
  from diffusers import QwenImageLayeredPipeline
20
  from pptx import Presentation
21
 
22
+ from huggingface_hub import login, HfApi, hf_hub_download
23
+
24
+
25
  LOG_DIR = "/tmp/local"
26
  MAX_SEED = np.iinfo(np.int32).max
27
 
28
  # Optional HF login (works in Spaces if you set HF token as secret env var "hf")
 
 
29
  login(token=os.environ.get("hf"))
30
 
31
+ # Dataset persistence (optional). Example: "username/qwen-layered-sessions"
32
+ DATASET_REPO = os.environ.get("DATASET_REPO", "").strip()
33
+ DATASET_BRANCH = os.environ.get("DATASET_BRANCH", "main").strip()
34
+
35
+ # If you want to reduce allocator weirdness on some CUDA envs, you can set this.
36
+ # Keep as "setdefault" so you can override in Space Variables.
37
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""))
38
+
39
  dtype = torch.bfloat16
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
 
43
  "Qwen/Qwen-Image-Layered", torch_dtype=dtype
44
  ).to(device)
45
 
46
+ pipeline.set_progress_bar_config(disable=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ try:
49
+ torch.backends.cuda.matmul.allow_tf32 = True
50
+ except Exception:
51
+ pass
 
 
 
 
 
 
 
 
 
 
 
52
 
 
53
 
 
 
 
54
  def ensure_dirname(path: str):
55
  if path and not os.path.exists(path):
56
  os.makedirs(path, exist_ok=True)
57
 
58
+
59
  def random_str(length=8):
60
  return uuid.uuid4().hex[:length]
61
 
62
+
63
+ def _now_ts() -> int:
64
+ return int(time.time())
65
+
66
 
67
  def _clamp_int(x, default: int, lo: int, hi: int) -> int:
68
  try:
 
71
  v = default
72
  return max(lo, min(hi, v))
73
 
74
+
75
+ def _safe_img_rgba(input_image):
76
+ # Normalize image input
77
  if isinstance(input_image, list):
78
  input_image = input_image[0]
79
 
 
85
  pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA")
86
  else:
87
  raise ValueError(f"Unsupported input_image type: {type(input_image)}")
88
+
89
  return pil_image
90
 
91
+
92
+ def imagelist_to_pptx(img_files):
93
  with Image.open(img_files[0]) as img:
94
  img_width_px, img_height_px = img.size
95
 
 
118
  prs.save(tmp.name)
119
  return tmp.name
120
 
121
+
122
+ def _write_layers_to_temp_pngs(layers: List[Image.Image]) -> List[str]:
123
+ temp_files = []
 
 
124
  for img in layers:
125
  tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
126
  img.save(tmp.name)
127
  temp_files.append(tmp.name)
128
+ return temp_files
129
 
 
130
 
131
+ def _build_zip_from_pngs(png_paths: List[str]) -> str:
132
  with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmpzip:
133
  with zipfile.ZipFile(tmpzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
134
+ for i, img_path in enumerate(png_paths):
135
  zipf.write(img_path, f"layer_{i+1}.png")
136
+ return tmpzip.name
137
+
138
+
139
+ def _mk_node_name(kind: str, depth: int) -> str:
140
+ if kind == "root":
141
+ return "Root"
142
+ return f"Refine d{depth}"
143
+
144
+
145
+ def _ds_enabled() -> bool:
146
+ return bool(DATASET_REPO) and bool(os.environ.get("hf"))
147
+
148
+
149
+ def _ds_api() -> HfApi:
150
+ return HfApi(token=os.environ.get("hf"))
151
+
152
+
153
+ def _ds_path(*parts: str) -> str:
154
+ return "/".join([p.strip("/") for p in parts if p is not None and p != ""])
155
+
156
+
157
+ def ds_list_sessions() -> List[str]:
158
+ if not _ds_enabled():
159
+ return []
160
+ api = _ds_api()
161
+ files = api.list_repo_files(repo_id=DATASET_REPO, repo_type="dataset", revision=DATASET_BRANCH)
162
+ sessions = set()
163
+ for f in files:
164
+ if f.startswith("sessions/") and f.endswith("/index.json"):
165
+ # sessions/<sid>/index.json
166
+ parts = f.split("/")
167
+ if len(parts) >= 3:
168
+ sessions.add(parts[1])
169
+ return sorted(list(sessions))
170
+
171
+
172
+ def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str):
173
+ api = _ds_api()
174
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
175
+ tmp.write(data)
176
+ tmp.flush()
177
+ api.upload_file(
178
+ path_or_fileobj=tmp.name,
179
+ path_in_repo=path_in_repo,
180
+ repo_id=DATASET_REPO,
181
+ repo_type="dataset",
182
+ revision=DATASET_BRANCH,
183
+ commit_message=commit_message,
184
+ )
185
+
186
+
187
+ def ds_upload_file(local_path: str, path_in_repo: str, commit_message: str):
188
+ api = _ds_api()
189
+ api.upload_file(
190
+ path_or_fileobj=local_path,
191
+ path_in_repo=path_in_repo,
192
+ repo_id=DATASET_REPO,
193
+ repo_type="dataset",
194
+ revision=DATASET_BRANCH,
195
+ commit_message=commit_message,
196
+ )
197
+
198
+
199
+ def ds_download_json(path_in_repo: str) -> Dict[str, Any]:
200
+ local = hf_hub_download(
201
+ repo_id=DATASET_REPO,
202
+ repo_type="dataset",
203
+ revision=DATASET_BRANCH,
204
+ filename=path_in_repo,
205
+ )
206
+ with open(local, "r", encoding="utf-8") as f:
207
+ return json.load(f)
208
+
209
+
210
+ def ds_download_file(path_in_repo: str) -> str:
211
+ return hf_hub_download(
212
+ repo_id=DATASET_REPO,
213
+ repo_type="dataset",
214
+ revision=DATASET_BRANCH,
215
+ filename=path_in_repo,
216
+ )
217
 
 
218
 
219
+ @dataclass
220
+ class Node:
221
+ node_id: str
222
+ parent_id: Optional[str]
223
+ kind: str # "root" or "refine"
224
+ depth: int
225
+ name: str
226
+ created_at: int
227
+ params: Dict[str, Any]
228
+ refine_meta: Dict[str, Any]
229
+
230
+ layers: List[Image.Image] # PIL layers (in-memory)
231
+ png_paths: List[str] # local temp pngs
232
+ pptx_path: Optional[str]
233
+ zip_path: Optional[str]
234
+
235
+ children: List[str]
236
+
237
+
238
+ def new_state() -> Dict[str, Any]:
239
+ return {
240
+ "session_id": random_str(12),
241
+ "nodes": {}, # node_id -> Node
242
+ "root_id": None,
243
+ "current_id": None,
244
+ "selected_layer_idx": 0,
245
+ "autosave": False,
246
+ "dataset_repo": DATASET_REPO,
247
+ "dataset_branch": DATASET_BRANCH,
248
+ "last_source_for_redo": None, # (from_node_id, layer_idx, sub_layers, params)
249
+ }
250
+
251
+
252
+ def _node_to_brief(n: Node) -> str:
253
+ short = n.node_id[:8]
254
+ return f"{n.name} · {short} · {n.kind} · depth {n.depth}"
255
+
256
+
257
+ def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]:
258
+ nodes: Dict[str, Node] = state["nodes"]
259
+ if not nodes:
260
+ return []
261
+ # sort by created_at
262
+ ordered = sorted(nodes.values(), key=lambda x: x.created_at)
263
+ return [(_node_to_brief(n), n.node_id) for n in ordered]
264
+
265
+
266
+ def _chips_html(state: Dict[str, Any], node_id: Optional[str]) -> str:
267
+ if not node_id or node_id not in state["nodes"]:
268
+ return ""
269
+ n: Node = state["nodes"][node_id]
270
+ root_id = state.get("root_id")
271
+ parent = n.parent_id
272
+ children = len(n.children or [])
273
+ chips = []
274
+ if n.node_id == root_id:
275
+ chips.append("[root]")
276
+ if parent:
277
+ chips.append("[parent]")
278
+ chips.append(f"[children:{children}]")
279
+ return " ".join(chips)
280
+
281
+
282
+ def _ensure_exports(node: Node) -> Node:
283
+ if node.png_paths is None or len(node.png_paths) == 0:
284
+ node.png_paths = _write_layers_to_temp_pngs(node.layers)
285
+ if not node.pptx_path:
286
+ node.pptx_path = imagelist_to_pptx(node.png_paths)
287
+ if not node.zip_path:
288
+ node.zip_path = _build_zip_from_pngs(node.png_paths)
289
+ return node
290
+
291
+
292
+ def _persist_node_to_dataset(state: Dict[str, Any], node: Node):
293
+ if not _ds_enabled():
294
+ return
295
+
296
+ sid = state["session_id"]
297
+ nid = node.node_id
298
+ base = _ds_path("sessions", sid, "nodes", nid)
299
+
300
+ # node meta
301
+ meta = {
302
+ "node_id": node.node_id,
303
+ "parent_id": node.parent_id,
304
+ "kind": node.kind,
305
+ "depth": node.depth,
306
+ "name": node.name,
307
+ "created_at": node.created_at,
308
+ "params": node.params,
309
+ "refine_meta": node.refine_meta,
310
+ "children": node.children,
311
+ "layer_count": len(node.layers),
312
+ }
313
+ ds_upload_bytes(_ds_path(base, "node.json"), json.dumps(meta, ensure_ascii=False, indent=2).encode("utf-8"),
314
+ commit_message=f"save node meta {sid}/{nid}")
315
+
316
+ # ensure exports
317
+ node = _ensure_exports(node)
318
+
319
+ # pngs
320
+ for i, p in enumerate(node.png_paths):
321
+ ds_upload_file(p, _ds_path(base, f"layer_{i+1}.png"), commit_message=f"save layer png {sid}/{nid}")
322
+
323
+ # pptx/zip
324
+ ds_upload_file(node.pptx_path, _ds_path(base, "layers.pptx"), commit_message=f"save pptx {sid}/{nid}")
325
+ ds_upload_file(node.zip_path, _ds_path(base, "layers.zip"), commit_message=f"save zip {sid}/{nid}")
326
+
327
+ # session index
328
+ nodes: Dict[str, Node] = state["nodes"]
329
+ index = {
330
+ "session_id": sid,
331
+ "saved_at": _now_ts(),
332
+ "root_id": state["root_id"],
333
+ "current_id": state["current_id"],
334
+ "nodes": [
335
+ {
336
+ "node_id": x.node_id,
337
+ "parent_id": x.parent_id,
338
+ "kind": x.kind,
339
+ "depth": x.depth,
340
+ "name": x.name,
341
+ "created_at": x.created_at,
342
+ "layer_count": len(x.layers),
343
+ }
344
+ for x in sorted(nodes.values(), key=lambda z: z.created_at)
345
+ ],
346
+ }
347
+ ds_upload_bytes(_ds_path("sessions", sid, "index.json"),
348
+ json.dumps(index, ensure_ascii=False, indent=2).encode("utf-8"),
349
+ commit_message=f"save session index {sid}")
350
+
351
+
352
+ def _load_session_from_dataset(session_id: str) -> Dict[str, Any]:
353
+ st = new_state()
354
+ st["session_id"] = session_id
355
+
356
+ index = ds_download_json(_ds_path("sessions", session_id, "index.json"))
357
+ st["root_id"] = index.get("root_id")
358
+ st["current_id"] = index.get("current_id")
359
+
360
+ nodes: Dict[str, Node] = {}
361
+
362
+ for item in index.get("nodes", []):
363
+ nid = item["node_id"]
364
+ meta = ds_download_json(_ds_path("sessions", session_id, "nodes", nid, "node.json"))
365
+
366
+ layer_count = int(meta.get("layer_count", 0))
367
+ layers = []
368
+ png_paths = []
369
+ for i in range(layer_count):
370
+ fp = ds_download_file(_ds_path("sessions", session_id, "nodes", nid, f"layer_{i+1}.png"))
371
+ layers.append(Image.open(fp).convert("RGBA"))
372
+ png_paths.append(fp)
373
+
374
+ pptx_path = ds_download_file(_ds_path("sessions", session_id, "nodes", nid, "layers.pptx"))
375
+ zip_path = ds_download_file(_ds_path("sessions", session_id, "nodes", nid, "layers.zip"))
376
+
377
+ n = Node(
378
+ node_id=nid,
379
+ parent_id=meta.get("parent_id"),
380
+ kind=meta.get("kind", "root"),
381
+ depth=int(meta.get("depth", 0)),
382
+ name=meta.get("name", _mk_node_name(meta.get("kind", "root"), int(meta.get("depth", 0)))),
383
+ created_at=int(meta.get("created_at", _now_ts())),
384
+ params=meta.get("params", {}),
385
+ refine_meta=meta.get("refine_meta", {}),
386
+ layers=layers,
387
+ png_paths=png_paths,
388
+ pptx_path=pptx_path,
389
+ zip_path=zip_path,
390
+ children=meta.get("children", []),
391
+ )
392
+ nodes[nid] = n
393
+
394
+ st["nodes"] = nodes
395
+
396
+ if st["current_id"] is None and st["root_id"] in nodes:
397
+ st["current_id"] = st["root_id"]
398
+
399
+ return st
400
+
401
+
402
+ # Dynamic duration callable: must accept the same args as GPU function. It returns seconds.
403
  def get_duration(
404
  input_image,
405
  seed=777,
 
416
  ):
417
  return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
418
 
419
+
 
 
420
  @spaces.GPU(duration=get_duration)
421
  def gpu_run_pipeline(
422
+ pil_image_rgba: Image.Image,
423
+ seed: int,
424
+ randomize_seed: bool,
425
+ prompt: str,
426
+ neg_prompt: str,
427
+ true_guidance_scale: float,
428
+ num_inference_steps: int,
429
+ layer: int,
430
+ cfg_norm: bool,
431
+ use_en_prompt: bool,
432
+ resolution: int,
433
+ gpu_duration: int,
434
+ ):
435
  # Seed
436
  if randomize_seed:
437
  seed = random.randint(0, MAX_SEED)
 
444
  gen_device = "cuda" if torch.cuda.is_available() else "cpu"
445
 
446
  inputs = {
447
+ "image": pil_image_rgba,
448
  "generator": torch.Generator(device=gen_device).manual_seed(seed),
449
  "true_cfg_scale": true_guidance_scale,
450
  "prompt": prompt,
 
457
  "use_en_prompt": use_en_prompt,
458
  }
459
 
460
+ print("INFER INPUTS:", {k: (str(v)[:200] if isinstance(v, str) else v) for k, v in inputs.items()})
461
+ print("REQUESTED GPU DURATION:", gpu_duration)
462
 
463
+ # Self-heal retry for rare CUDA/NVML allocator glitches on some envs
464
+ try:
465
+ with torch.inference_mode():
466
+ out = pipeline(**inputs)
467
+ output_images = out.images[0] # list of PIL images (layers)
468
+ except RuntimeError as e:
469
+ msg = str(e)
470
+ if "NVML_SUCCESS" in msg or "CUDACachingAllocator" in msg:
471
+ print("Caught allocator/NVML error, retrying once after cache cleanup:", msg)
472
+ try:
473
+ torch.cuda.empty_cache()
474
+ except Exception:
475
+ pass
476
+ gc.collect()
477
+ time.sleep(0.2)
478
+ with torch.inference_mode():
479
+ out = pipeline(**inputs)
480
+ output_images = out.images[0]
481
+ else:
482
+ raise
483
 
484
+ # Ensure RGBA
485
+ fixed = []
486
+ for im in output_images:
487
+ if not isinstance(im, Image.Image):
488
+ im = Image.fromarray(np.array(im))
489
+ fixed.append(im.convert("RGBA"))
 
 
 
 
 
 
 
490
 
491
+ try:
492
+ torch.cuda.empty_cache()
493
+ except Exception:
494
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
 
496
+ return fixed, seed
497
+
498
+
499
+ def _create_node(
500
+ state: Dict[str, Any],
501
  layers: List[Image.Image],
502
+ kind: str,
503
  parent_id: Optional[str],
504
+ params: Dict[str, Any],
505
+ refine_meta: Dict[str, Any],
506
+ ) -> Node:
507
+ nid = random_str(24)
508
+ depth = 0
509
+ if parent_id and parent_id in state["nodes"]:
510
+ depth = state["nodes"][parent_id].depth + 1
511
+
512
+ n = Node(
513
+ node_id=nid,
514
+ parent_id=parent_id,
515
+ kind=kind,
516
+ depth=depth,
517
+ name=_mk_node_name(kind, depth),
518
+ created_at=_now_ts(),
519
+ params=params,
520
+ refine_meta=refine_meta,
521
+ layers=layers,
522
+ png_paths=[],
523
+ pptx_path=None,
524
+ zip_path=None,
525
+ children=[],
526
+ )
527
+ state["nodes"][nid] = n
528
+ if parent_id and parent_id in state["nodes"]:
529
+ state["nodes"][parent_id].children.append(nid)
530
+ if kind == "root" and state.get("root_id") is None:
531
+ state["root_id"] = nid
532
+ state["current_id"] = nid
533
+ state["selected_layer_idx"] = 0
534
+ return n
535
+
536
+
537
+ def _current_node(state: Dict[str, Any]) -> Optional[Node]:
538
+ cid = state.get("current_id")
539
+ if cid and cid in state["nodes"]:
540
+ return state["nodes"][cid]
541
+ return None
542
+
543
+
544
+ def _render_current(state: Dict[str, Any]):
545
+ n = _current_node(state)
546
+ if not n:
547
+ return (
548
+ [], # main gallery
549
+ [], # layer strip
550
+ gr.update(choices=[], value=None), # layer dropdown
551
+ gr.update(choices=_history_choices(state), value=None), # history dropdown
552
+ "", # chips
553
+ None, # pptx
554
+ None, # zip
555
+ gr.update(open=False), # refined accordion
556
+ [], # refined gallery
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  )
 
 
 
 
 
 
 
 
 
 
 
 
558
 
559
+ # layer dropdown
560
+ layer_choices = [(f"Layer {i+1}", str(i)) for i in range(len(n.layers))]
561
+ dd = gr.update(choices=layer_choices, value=str(min(state.get("selected_layer_idx", 0), max(0, len(n.layers) - 1))))
562
+
563
+ # ensure exports available for selected node
564
+ n = _ensure_exports(n)
565
+
566
  return (
567
+ n.layers, # main gallery
568
+ n.layers, # strip gallery
569
+ dd, # layer dropdown
570
+ gr.update(choices=_history_choices(state), value=n.node_id),
571
+ _chips_html(state, n.node_id),
572
+ n.pptx_path,
573
+ n.zip_path,
574
+ gr.update(open=False),
575
  [],
 
 
 
 
 
 
576
  )
577
 
578
+
579
+ def on_apply_fast_profile():
580
+ # Does not change defaults; only sets UI values when clicked.
 
 
 
 
 
 
 
 
 
 
 
 
581
  return (
582
+ 30, # steps
583
+ 5, # layers
584
+ 640, # resolution
 
 
585
  )
586
 
 
 
 
 
 
 
587
 
588
+ def on_toggle_autosave(val, state):
589
+ state["autosave"] = bool(val)
590
+ return state
591
+
592
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  def on_decompose_click(
594
  input_image,
595
+ seed,
596
+ randomize_seed,
597
+ prompt,
598
+ neg_prompt,
599
+ true_guidance_scale,
600
+ num_inference_steps,
601
+ layer,
602
+ cfg_norm,
603
+ use_en_prompt,
604
+ resolution,
605
+ gpu_duration,
606
+ state,
607
  ):
608
+ if state is None:
609
+ state = new_state()
610
+
611
+ pil = _safe_img_rgba(input_image)
612
+
613
+ layers_out, used_seed = gpu_run_pipeline(
614
+ pil_image_rgba=pil,
615
+ seed=int(seed),
616
+ randomize_seed=bool(randomize_seed),
617
  prompt=prompt,
618
  neg_prompt=neg_prompt,
619
+ true_guidance_scale=float(true_guidance_scale),
620
+ num_inference_steps=int(num_inference_steps),
621
+ layer=int(layer),
622
+ cfg_norm=bool(cfg_norm),
623
+ use_en_prompt=bool(use_en_prompt),
624
+ resolution=int(resolution),
625
+ gpu_duration=int(gpu_duration),
626
  )
627
 
628
+ params = {
629
+ "seed": int(used_seed),
630
+ "randomize_seed": bool(randomize_seed),
631
+ "prompt": prompt,
632
+ "neg_prompt": neg_prompt,
633
+ "true_guidance_scale": float(true_guidance_scale),
634
+ "num_inference_steps": int(num_inference_steps),
635
+ "layers": int(layer),
636
+ "cfg_norm": bool(cfg_norm),
637
+ "use_en_prompt": bool(use_en_prompt),
638
+ "resolution": int(resolution),
639
+ "gpu_duration": int(gpu_duration),
640
+ }
641
 
642
+ node = _create_node(
643
+ state=state,
644
+ layers=layers_out,
645
+ kind="root" if state.get("root_id") is None else "refine",
646
+ parent_id=None,
647
+ params=params,
648
+ refine_meta={"mode": "decompose"},
649
+ )
650
 
651
+ if state.get("autosave"):
652
+ _persist_node_to_dataset(state, node)
653
+
654
+ # reset refined UI
655
+ main_gallery, strip, layer_dd, history_dd, chips, pptx, zzip, acc, refined = _render_current(state)
656
  return (
657
+ state,
658
+ main_gallery,
659
+ strip,
660
+ layer_dd,
661
+ history_dd,
662
  chips,
663
+ pptx,
664
+ zzip,
665
+ acc,
666
+ refined,
667
+ used_seed,
668
  )
669
 
670
+
671
+ def on_layer_select_from_strip(evt: gr.SelectData, state):
672
+ # evt.index -> int
673
+ if state is None:
674
+ state = new_state()
675
+ idx = int(evt.index) if evt and evt.index is not None else 0
676
+ state["selected_layer_idx"] = idx
677
+ n = _current_node(state)
678
+ if not n:
679
+ return state, gr.update(), ""
680
+ dd = gr.update(value=str(min(idx, len(n.layers) - 1)))
681
+ return state, dd, f"Selected: Layer {idx+1}"
682
+
683
+
684
+ def on_layer_dropdown_change(val, state):
685
+ if state is None:
686
+ state = new_state()
687
+ try:
688
+ idx = int(val)
689
+ except Exception:
690
+ idx = 0
691
+ state["selected_layer_idx"] = idx
692
+ return state, f"Selected: Layer {idx+1}"
693
+
694
+
695
  def on_refine_click(
696
+ refine_sub_layers,
697
+ seed,
698
+ randomize_seed,
699
+ prompt,
700
+ neg_prompt,
701
+ true_guidance_scale,
702
+ num_inference_steps,
703
+ cfg_norm,
704
+ use_en_prompt,
705
+ resolution,
706
+ gpu_duration,
707
+ state,
708
  ):
709
+ if state is None:
710
+ state = new_state()
711
+ n = _current_node(state)
712
+ if not n:
 
 
713
  return (
714
+ state,
715
+ gr.update(open=True),
 
 
 
 
 
 
716
  [],
717
+ gr.update(choices=_history_choices(state)),
718
+ "",
719
  None,
720
  None,
721
  )
722
 
723
+ idx = int(state.get("selected_layer_idx", 0))
724
+ idx = max(0, min(idx, len(n.layers) - 1))
725
+ selected = n.layers[idx]
726
+
727
+ # refine creates new node under current as parent
728
+ layers_out, used_seed = gpu_run_pipeline(
729
+ pil_image_rgba=selected.convert("RGBA"),
730
+ seed=int(seed),
731
+ randomize_seed=bool(randomize_seed),
732
+ prompt=prompt,
733
+ neg_prompt=neg_prompt,
734
+ true_guidance_scale=float(true_guidance_scale),
735
+ num_inference_steps=int(num_inference_steps),
736
+ layer=int(refine_sub_layers),
737
+ cfg_norm=bool(cfg_norm),
738
+ use_en_prompt=bool(use_en_prompt),
739
+ resolution=int(resolution),
740
+ gpu_duration=int(gpu_duration),
741
+ )
742
+
743
+ params = {
744
+ "seed": int(used_seed),
745
+ "randomize_seed": bool(randomize_seed),
746
+ "prompt": prompt,
747
+ "neg_prompt": neg_prompt,
748
+ "true_guidance_scale": float(true_guidance_scale),
749
+ "num_inference_steps": int(num_inference_steps),
750
+ "layers": int(refine_sub_layers),
751
+ "cfg_norm": bool(cfg_norm),
752
+ "use_en_prompt": bool(use_en_prompt),
753
+ "resolution": int(resolution),
754
+ "gpu_duration": int(gpu_duration),
755
+ }
756
+ refine_meta = {
757
+ "mode": "refine",
758
+ "from_node_id": n.node_id,
759
+ "layer_idx": idx,
760
+ "sub_layers": int(refine_sub_layers),
761
+ }
762
+
763
+ node = _create_node(
764
+ state=state,
765
+ layers=layers_out,
766
+ kind="refine",
767
+ parent_id=n.node_id,
768
+ params=params,
769
+ refine_meta=refine_meta,
770
+ )
771
+
772
+ # remember for redo
773
+ state["last_source_for_redo"] = (n.node_id, idx, int(refine_sub_layers), params)
774
+
775
+ if state.get("autosave"):
776
+ _persist_node_to_dataset(state, node)
777
+
778
+ # Update UI: refined accordion opens; current node is refined node now
779
+ node = _ensure_exports(node)
780
+ chips = _chips_html(state, node.node_id)
781
+ history_dd = gr.update(choices=_history_choices(state), value=node.node_id)
782
+
783
+ return (
784
+ state,
785
+ gr.update(open=True),
786
+ node.layers,
787
+ history_dd,
788
+ chips,
789
+ node.pptx_path,
790
+ node.zip_path,
791
+ )
792
+
793
+
794
+ def on_history_select(node_id, state):
795
+ if state is None:
796
+ state = new_state()
797
+ if node_id and node_id in state["nodes"]:
798
+ state["current_id"] = node_id
799
+ state["selected_layer_idx"] = 0
800
+
801
+ main_gallery, strip, layer_dd, history_dd, chips, pptx, zzip, acc, refined = _render_current(state)
802
+ return (
803
+ state,
804
+ main_gallery,
805
+ strip,
806
+ layer_dd,
807
+ history_dd,
808
+ chips,
809
+ pptx,
810
+ zzip,
811
+ acc,
812
+ refined,
813
+ "Selected: Layer 1",
814
+ )
815
+
816
+
817
+ def on_back_to_parent(state):
818
+ if state is None:
819
+ state = new_state()
820
+ n = _current_node(state)
821
+ if n and n.parent_id and n.parent_id in state["nodes"]:
822
+ state["current_id"] = n.parent_id
823
+ state["selected_layer_idx"] = 0
824
+ return on_history_select(state.get("current_id"), state)
825
+
826
+
827
+ def on_duplicate_node(state):
828
+ if state is None:
829
+ state = new_state()
830
+ n = _current_node(state)
831
+ if not n:
832
+ return on_history_select(state.get("current_id"), state)
833
+
834
+ # clone layers
835
+ cloned_layers = [im.copy() for im in n.layers]
836
+ params = dict(n.params)
837
+ refine_meta = {"mode": "duplicate", "from_node_id": n.node_id}
838
+ newn = _create_node(
839
+ state=state,
840
+ layers=cloned_layers,
841
+ kind="refine",
842
+ parent_id=n.parent_id,
843
+ params=params,
844
+ refine_meta=refine_meta,
845
+ )
846
+ newn.name = f"{n.name} (copy)"
847
+
848
+ if state.get("autosave"):
849
+ _persist_node_to_dataset(state, newn)
850
+
851
+ return on_history_select(newn.node_id, state)
852
+
853
+
854
+ def on_rename_node(new_name, state):
855
+ if state is None:
856
+ state = new_state()
857
+ n = _current_node(state)
858
+ if n and new_name and isinstance(new_name, str):
859
+ n.name = new_name.strip()[:80] if new_name.strip() else n.name
860
+ state["nodes"][n.node_id] = n
861
+ if state.get("autosave"):
862
+ _persist_node_to_dataset(state, n)
863
+ return on_history_select(state.get("current_id"), state)
864
+
865
+
866
+ def on_redo_refine(
867
+ seed,
868
+ randomize_seed,
869
+ prompt,
870
+ neg_prompt,
871
+ true_guidance_scale,
872
+ num_inference_steps,
873
+ cfg_norm,
874
+ use_en_prompt,
875
+ resolution,
876
+ gpu_duration,
877
+ state,
878
+ ):
879
+ if state is None:
880
+ state = new_state()
881
+
882
+ info = state.get("last_source_for_redo")
883
+ if not info:
884
  return (
885
+ state,
886
+ gr.update(open=True),
 
 
 
 
 
 
887
  [],
888
+ gr.update(choices=_history_choices(state)),
889
+ _chips_html(state, state.get("current_id")),
890
  None,
891
  None,
892
  )
893
 
894
+ from_node_id, layer_idx, sub_layers, _params = info
895
+ if from_node_id not in state["nodes"]:
896
+ return (
897
+ state,
898
+ gr.update(open=True),
899
+ [],
900
+ gr.update(choices=_history_choices(state)),
901
+ _chips_html(state, state.get("current_id")),
902
+ None,
903
+ None,
904
+ )
905
 
906
+ src = state["nodes"][from_node_id]
907
+ layer_idx = max(0, min(int(layer_idx), len(src.layers) - 1))
908
+ selected = src.layers[layer_idx]
909
 
910
+ layers_out, used_seed = gpu_run_pipeline(
911
+ pil_image_rgba=selected.convert("RGBA"),
912
+ seed=int(seed),
913
+ randomize_seed=bool(randomize_seed),
914
  prompt=prompt,
915
  neg_prompt=neg_prompt,
916
+ true_guidance_scale=float(true_guidance_scale),
917
+ num_inference_steps=int(num_inference_steps),
918
+ layer=int(sub_layers),
919
+ cfg_norm=bool(cfg_norm),
920
+ use_en_prompt=bool(use_en_prompt),
921
+ resolution=int(resolution),
922
+ gpu_duration=int(gpu_duration),
923
  )
924
 
925
+ params = {
926
+ "seed": int(used_seed),
927
+ "randomize_seed": bool(randomize_seed),
928
+ "prompt": prompt,
929
+ "neg_prompt": neg_prompt,
930
+ "true_guidance_scale": float(true_guidance_scale),
931
+ "num_inference_steps": int(num_inference_steps),
932
+ "layers": int(sub_layers),
933
+ "cfg_norm": bool(cfg_norm),
934
+ "use_en_prompt": bool(use_en_prompt),
935
+ "resolution": int(resolution),
936
+ "gpu_duration": int(gpu_duration),
937
+ }
938
+
939
+ refine_meta = {
940
+ "mode": "redo_refine",
941
+ "from_node_id": from_node_id,
942
+ "layer_idx": int(layer_idx),
943
+ "sub_layers": int(sub_layers),
944
+ }
945
 
946
+ node = _create_node(
947
+ state=state,
948
+ layers=layers_out,
949
+ kind="refine",
950
+ parent_id=from_node_id,
951
+ params=params,
952
+ refine_meta=refine_meta,
953
+ )
954
+ node.name = f"Redo refine d{node.depth}"
955
+
956
+ if state.get("autosave"):
957
+ _persist_node_to_dataset(state, node)
958
+
959
+ node = _ensure_exports(node)
960
+ chips = _chips_html(state, node.node_id)
961
+ history_dd = gr.update(choices=_history_choices(state), value=node.node_id)
962
 
 
963
  return (
964
+ state,
965
+ gr.update(open=True),
966
+ node.layers,
967
+ history_dd,
 
968
  chips,
969
+ node.pptx_path,
970
+ node.zip_path,
 
 
 
971
  )
972
 
973
+
974
+ def on_save_current(state):
975
+ if state is None:
976
+ state = new_state()
977
+ n = _current_node(state)
978
+ if not n:
979
+ return "Nothing to save."
980
+ if not _ds_enabled():
981
+ return "Dataset persistence disabled. Set DATASET_REPO env var and provide hf token."
982
+ _persist_node_to_dataset(state, n)
983
+ return f"Saved node {n.node_id[:8]} to dataset session {state['session_id']}."
984
+
985
+
986
+ def on_refresh_sessions():
987
+ if not _ds_enabled():
988
+ return gr.update(choices=[], value=None), "Dataset persistence disabled."
989
+ sessions = ds_list_sessions()
990
+ choices = [(s, s) for s in sessions]
991
+ return gr.update(choices=choices, value=(choices[-1][1] if choices else None)), f"Found {len(choices)} sessions."
992
+
993
+
994
+ def on_load_session(session_id, state):
995
+ if state is None:
996
+ state = new_state()
997
+ if not session_id:
998
+ return on_history_select(state.get("current_id"), state)
999
+ if not _ds_enabled():
1000
+ return on_history_select(state.get("current_id"), state)
1001
+ st = _load_session_from_dataset(session_id)
1002
+ # preserve autosave toggle choice from current UI state (if any)
1003
+ st["autosave"] = bool(state.get("autosave", False))
1004
+ return on_history_select(st.get("current_id"), st)
1005
+
1006
+
1007
  ensure_dirname(LOG_DIR)
1008
 
1009
  examples = [
 
1022
  "assets/test_images/13.png",
1023
  ]
1024
 
1025
+
1026
  with gr.Blocks() as demo:
1027
+ state = gr.State(new_state())
1028
 
1029
  with gr.Column(elem_id="col-container"):
1030
  gr.HTML(
1031
  '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
1032
  'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
1033
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1034
  gr.Markdown(
1035
  """
1036
+ The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded (e.g., you may specify the text hidden behind a foreground object). It is not designed to control the semantic content of individual layers explicitly.
 
1037
  """
1038
  )
1039
 
 
1109
  placeholder="e.g. 60, 120, 300, 1000, 1500",
1110
  )
1111
 
1112
+ with gr.Row():
1113
+ run_button = gr.Button("Decompose!", variant="primary")
1114
+ fast_button = gr.Button("Apply fast profile", variant="secondary")
1115
 
1116
+ with gr.Accordion("Refine (Recursive Decomposition)", open=True):
1117
+ refine_sub_layers = gr.Slider(
1118
+ label="Sub-layers (Refine)",
1119
+ minimum=2,
1120
+ maximum=10,
1121
+ step=1,
1122
+ value=3,
1123
+ )
1124
 
1125
+ layer_pick_help = gr.Markdown("Pick a layer from the strip or dropdown below.")
 
 
1126
 
1127
+ layer_strip = gr.Gallery(
1128
+ label="Layer strip (click to select)",
1129
+ columns=6,
1130
+ rows=1,
1131
+ height=110,
1132
+ format="png",
1133
+ show_label=True,
1134
+ )
1135
 
1136
+ with gr.Row():
1137
+ layer_dropdown = gr.Dropdown(
1138
+ label="Selected layer",
1139
+ choices=[],
1140
+ value=None,
1141
+ interactive=True,
1142
+ )
1143
 
1144
+ selected_layer_label = gr.Markdown("Selected: Layer 1")
 
 
1145
 
1146
+ refine_button = gr.Button("Refine selected layer", variant="primary")
1147
+ redo_button = gr.Button(" redo refine", variant="secondary")
 
1148
 
1149
+ with gr.Accordion("History", open=True):
1150
+ history_chips = gr.Markdown("")
1151
+ history_dropdown = gr.Dropdown(
1152
+ label="Nodes",
1153
+ choices=[],
1154
+ value=None,
1155
+ interactive=True,
1156
+ )
1157
  with gr.Row():
1158
+ back_parent_btn = gr.Button("← back to parent", variant="secondary")
1159
+ duplicate_btn = gr.Button("Duplicate node (branch)", variant="secondary")
1160
+ with gr.Row():
1161
+ rename_text = gr.Textbox(label="Branch name", value="", lines=1, placeholder="Rename current node")
1162
+ rename_btn = gr.Button("Rename", variant="secondary")
1163
+
1164
+ autosave = gr.Checkbox(
1165
+ label=f"Auto-save to Dataset repo ({DATASET_REPO if DATASET_REPO else 'not set'})",
1166
+ value=False,
1167
+ )
1168
+ save_btn = gr.Button("Save current node now", variant="secondary")
1169
+ save_status = gr.Markdown("")
1170
 
1171
+ with gr.Accordion("Load saved sessions (Dataset)", open=False):
1172
+ refresh_sessions_btn = gr.Button("Refresh sessions list", variant="secondary")
1173
+ sessions_dropdown = gr.Dropdown(label="Saved sessions", choices=[], value=None, interactive=True)
1174
+ load_session_btn = gr.Button("Load session", variant="primary")
1175
+ sessions_status = gr.Markdown("")
1176
+
1177
+ with gr.Column(scale=2):
1178
+ gallery = gr.Gallery(label="Layers", columns=4, rows=2, format="png")
1179
 
1180
+ with gr.Row():
1181
+ export_file = gr.File(label="Download PPTX (selected node)")
1182
+ export_zip_file = gr.File(label="Download ZIP (selected node)")
1183
+
1184
+ refined_accordion = gr.Accordion("Refined layers", open=False)
1185
+ with refined_accordion:
1186
+ refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=2, format="png")
1187
+
1188
+ # Examples
1189
  gr.Examples(
1190
  examples=examples,
1191
  inputs=[input_image],
1192
  outputs=[gallery, export_file, export_zip_file],
1193
+ fn=lambda x: ([], None, None),
1194
+ examples_per_page=14,
1195
  cache_examples=False,
1196
  run_on_click=False,
1197
  )
1198
 
1199
+ # Fast profile just updates UI fields
1200
+ fast_button.click(
1201
+ fn=on_apply_fast_profile,
1202
  inputs=[],
1203
+ outputs=[num_inference_steps, layer, resolution],
 
 
 
 
1204
  )
1205
 
1206
+ # Autosave toggle
1207
+ autosave.change(
1208
+ fn=on_toggle_autosave,
1209
+ inputs=[autosave, state],
1210
+ outputs=[state],
1211
  )
1212
 
1213
  # Decompose
 
1226
  use_en_prompt,
1227
  resolution,
1228
  gpu_duration,
1229
+ state,
1230
  ],
1231
  outputs=[
1232
+ state,
 
1233
  gallery,
1234
+ layer_strip,
1235
+ layer_dropdown,
1236
+ history_dropdown,
1237
+ history_chips,
 
1238
  export_file,
1239
  export_zip_file,
1240
+ refined_accordion,
1241
+ refined_gallery,
1242
+ seed,
1243
  ],
 
 
 
 
1244
  )
1245
 
1246
+ # Layer selection by clicking the strip
1247
+ layer_strip.select(
1248
+ fn=on_layer_select_from_strip,
1249
+ inputs=[state],
1250
+ outputs=[state, layer_dropdown, selected_layer_label],
 
 
 
 
 
 
 
1251
  )
1252
 
1253
+ # Layer selection by dropdown
1254
+ layer_dropdown.change(
1255
+ fn=on_layer_dropdown_change,
1256
+ inputs=[layer_dropdown, state],
1257
+ outputs=[state, selected_layer_label],
1258
  )
1259
 
1260
+ # Refine
1261
+ refine_button.click(
1262
+ fn=on_refine_click,
1263
+ inputs=[
1264
+ refine_sub_layers,
1265
+ seed,
1266
+ randomize_seed,
1267
+ prompt,
1268
+ neg_prompt,
1269
+ true_guidance_scale,
1270
+ num_inference_steps,
1271
+ cfg_norm,
1272
+ use_en_prompt,
1273
+ resolution,
1274
+ gpu_duration,
1275
+ state,
1276
+ ],
1277
+ outputs=[
1278
+ state,
1279
+ refined_accordion,
1280
+ refined_gallery,
1281
+ history_dropdown,
1282
+ history_chips,
1283
+ export_file,
1284
+ export_zip_file,
1285
+ ],
1286
  )
1287
 
1288
+ # Redo refine
1289
+ redo_button.click(
1290
+ fn=on_redo_refine,
1291
  inputs=[
 
1292
  seed,
1293
  randomize_seed,
1294
  prompt,
 
1299
  use_en_prompt,
1300
  resolution,
1301
  gpu_duration,
1302
+ state,
1303
+ ],
1304
+ outputs=[
1305
+ state,
1306
+ refined_accordion,
1307
+ refined_gallery,
1308
+ history_dropdown,
1309
+ history_chips,
1310
+ export_file,
1311
+ export_zip_file,
1312
  ],
1313
+ )
1314
+
1315
+ # History select
1316
+ history_dropdown.change(
1317
+ fn=on_history_select,
1318
+ inputs=[history_dropdown, state],
1319
  outputs=[
1320
+ state,
 
1321
  gallery,
1322
+ layer_strip,
1323
+ layer_dropdown,
1324
+ history_dropdown,
1325
+ history_chips,
1326
+ export_file,
1327
+ export_zip_file,
1328
+ refined_accordion,
1329
  refined_gallery,
1330
+ selected_layer_label,
1331
+ ],
1332
+ )
1333
+
1334
+ # Back to parent
1335
+ back_parent_btn.click(
1336
+ fn=on_back_to_parent,
1337
+ inputs=[state],
1338
+ outputs=[
1339
+ state,
1340
+ gallery,
1341
+ layer_strip,
1342
+ layer_dropdown,
1343
+ history_dropdown,
1344
+ history_chips,
1345
+ export_file,
1346
+ export_zip_file,
1347
+ refined_accordion,
1348
+ refined_gallery,
1349
+ selected_layer_label,
1350
+ ],
1351
+ )
1352
+
1353
+ # Duplicate node (branch)
1354
+ duplicate_btn.click(
1355
+ fn=on_duplicate_node,
1356
+ inputs=[state],
1357
+ outputs=[
1358
+ state,
1359
+ gallery,
1360
+ layer_strip,
1361
+ layer_dropdown,
1362
+ history_dropdown,
1363
+ history_chips,
1364
  export_file,
1365
  export_zip_file,
1366
+ refined_accordion,
1367
+ refined_gallery,
1368
+ selected_layer_label,
1369
  ],
1370
  )
1371
 
1372
+ # Rename node
1373
+ rename_btn.click(
1374
+ fn=on_rename_node,
1375
+ inputs=[rename_text, state],
1376
+ outputs=[
1377
+ state,
1378
+ gallery,
1379
+ layer_strip,
1380
+ layer_dropdown,
1381
+ history_dropdown,
1382
+ history_chips,
1383
+ export_file,
1384
+ export_zip_file,
1385
+ refined_accordion,
1386
+ refined_gallery,
1387
+ selected_layer_label,
1388
+ ],
1389
  )
1390
 
1391
+ # Save node
1392
  save_btn.click(
1393
+ fn=on_save_current,
1394
+ inputs=[state],
1395
+ outputs=[save_status],
1396
+ )
1397
+
1398
+ # Refresh sessions list
1399
+ refresh_sessions_btn.click(
1400
+ fn=on_refresh_sessions,
1401
+ inputs=[],
1402
+ outputs=[sessions_dropdown, sessions_status],
1403
  )
1404
 
1405
  # Load session
1406
+ load_session_btn.click(
1407
  fn=on_load_session,
1408
+ inputs=[sessions_dropdown, state],
1409
+ outputs=[
1410
+ state,
1411
+ gallery,
1412
+ layer_strip,
1413
+ layer_dropdown,
1414
+ history_dropdown,
1415
+ history_chips,
1416
+ export_file,
1417
+ export_zip_file,
1418
+ refined_accordion,
1419
+ refined_gallery,
1420
+ selected_layer_label,
1421
+ ],
1422
  )
1423
 
1424
+ # Serialize GPU tasks; helps stability on some ZeroGPU envs
1425
+ demo.queue(concurrency_count=1, max_size=20)
1426
+
1427
  if __name__ == "__main__":
1428
  demo.launch()