hexware commited on
Commit
578fc7f
·
verified ·
1 Parent(s): 5b7062a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -40
app.py CHANGED
@@ -226,12 +226,8 @@ def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
226
  repo_id = ds_repo_id()
227
  try:
228
  files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
229
- # sessions/<session_id>/session.json
230
  sess = set()
231
  for p in files:
232
- if p.startswith("sessions/") and p.count("/") >= 2 and p.endswith("/session.json"):
233
- # rarely list_repo_files returns directories without trailing slash, but for files it's path
234
- pass
235
  if p.startswith("sessions/") and p.endswith("session.json"):
236
  parts = p.split("/")
237
  if len(parts) >= 3:
@@ -241,6 +237,44 @@ def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
241
  except Exception as e:
242
  return [], f"List sessions failed: {repr(e)}"
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  # -------------------------
246
  # Node / History model
@@ -504,7 +538,16 @@ def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]:
504
  }
505
  b = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")
506
  path = f"{_session_base(session_id)}/session.json"
507
- return ds_upload_bytes(path, b, f"save session manifest {session_id}")
 
 
 
 
 
 
 
 
 
508
 
509
  def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
510
  b, msg = ds_download_bytes(f"{_session_base(session_id)}/session.json")
@@ -551,7 +594,6 @@ def _persistence_status_text() -> str:
551
 
552
  def on_refresh_sessions():
553
  sessions, msg = ds_list_sessions()
554
- # For dropdown choices, use sessions as both label+value
555
  return gr.update(choices=sessions, value=(sessions[0] if sessions else None)), msg
556
 
557
  def on_init_dataset():
@@ -683,7 +725,6 @@ def on_decompose_click(
683
  )
684
 
685
  def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
686
- # layer_name = "Layer k"
687
  node_id = state.get("selected_node_id")
688
  imgs = _get_node_images(state, node_id) if node_id else []
689
  n = len(imgs)
@@ -703,7 +744,6 @@ def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData):
703
  n = len(imgs)
704
  idx = int(evt.index) if evt and evt.index is not None else 0
705
  idx = max(0, min(n - 1, idx)) if n > 0 else 0
706
- # also update dropdown value
707
  dd_value = f"Layer {idx+1}" if n > 0 else None
708
  return gr.update(value=idx), gr.update(value=dd_value), _layer_label(idx, n)
709
 
@@ -739,7 +779,7 @@ def _refine_from_source(
739
  neg_prompt=neg_prompt,
740
  true_guidance_scale=true_guidance_scale,
741
  num_inference_steps=num_inference_steps,
742
- layer=sub_layers, # << refine layers count
743
  cfg_norm=cfg_norm,
744
  use_en_prompt=use_en_prompt,
745
  resolution=resolution,
@@ -790,6 +830,7 @@ def on_refine_click(
790
  gr.update(),
791
  gr.update(),
792
  gr.update(),
 
793
  gr.update(visible=False),
794
  [],
795
  None,
@@ -824,7 +865,6 @@ def on_refine_click(
824
 
825
  sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)
826
 
827
- # Run refine
828
  layers_out, used_seed, settings_snapshot = _refine_from_source(
829
  state,
830
  source_node_id=source_node_id,
@@ -842,7 +882,6 @@ def on_refine_click(
842
  randomize_seed=randomize_seed,
843
  )
844
 
845
- # Create child node and select it
846
  child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}"
847
  child_id = _add_node(
848
  state,
@@ -858,37 +897,32 @@ def on_refine_click(
858
  state["selected_node_id"] = child_id
859
  state["last_refined_node_id"] = child_id
860
 
861
- # UI update
862
  n_layers = len(layers_out)
863
  layer_choices, layer_value = _build_layer_dropdown(n_layers)
864
  hist_choices = _history_choices(state)
865
  chips = _make_chips(state)
866
  selected_label = _layer_label(0, n_layers)
867
 
868
- # exports for current node (child)
869
  pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
870
 
871
  status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {exp_msg}"
872
-
873
- # refined block visible + show refined gallery
874
  refined_visible = gr.update(visible=True)
875
 
876
- # Also update dropdown that selects nodes (history)
877
  return (
878
  state,
879
- layers_out, # main gallery shows child node now
880
- layers_out, # mini picker gallery
881
  gr.update(choices=layer_choices, value=layer_value),
882
  gr.update(value=0),
883
  selected_label,
884
  gr.update(choices=[c[1] for c in hist_choices], value=child_id),
885
  chips,
886
  refined_visible,
887
- layers_out, # refined gallery
888
  pptx_path,
889
  zip_path,
890
  status,
891
- gr.update(), # no change to seed_used textbox here
892
  )
893
 
894
  def on_history_select(state: Dict[str, Any], node_id: str):
@@ -919,7 +953,6 @@ def on_history_select(state: Dict[str, Any], node_id: str):
919
 
920
  pptx_path, zip_path, exp_msg = _current_node_export(state, node_id)
921
 
922
- # hide refined block when user manually switches nodes (keeps UX clean)
923
  return (
924
  state,
925
  imgs,
@@ -952,7 +985,6 @@ def on_duplicate_node(state: Dict[str, Any]):
952
  new_id = _duplicate_node(state, node_id)
953
  if not new_id:
954
  return state, gr.update(), "Duplicate failed."
955
- # select duplicated
956
  return on_history_select(state, new_id)
957
 
958
  def on_rename_node(state: Dict[str, Any], new_name: str):
@@ -960,10 +992,7 @@ def on_rename_node(state: Dict[str, Any], new_name: str):
960
  if not node_id:
961
  return state, gr.update(), "No selected node."
962
  _rename_node(state, node_id, new_name)
963
- # update history dropdown labels
964
  hist_choices = _history_choices(state)
965
- # gr.Dropdown choices are values; we keep node_id list, but users see ids.
966
- # We'll show names in a markdown below.
967
  chips = _make_chips(state)
968
  return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed."
969
 
@@ -1054,7 +1083,6 @@ def on_redo_refine(
1054
  randomize_seed=randomize_seed,
1055
  )
1056
 
1057
- # create new child under same source (new branch)
1058
  child_name = f"redo refine ({state['nodes'][source_node_id]['meta']['name']}) L{int(source_layer_idx)+1}"
1059
  child_id = _add_node(
1060
  state,
@@ -1070,7 +1098,6 @@ def on_redo_refine(
1070
  state["selected_node_id"] = child_id
1071
  state["last_refined_node_id"] = child_id
1072
 
1073
- # update UI (same as refine)
1074
  n_layers = len(layers_out)
1075
  layer_choices, layer_value = _build_layer_dropdown(n_layers)
1076
  hist_choices = _history_choices(state)
@@ -1111,7 +1138,7 @@ def on_save_current(state: Dict[str, Any]):
1111
  ok2, msg2 = _persist_session_manifest(state)
1112
  if not ok2:
1113
  return msg2
1114
- return f"✅ Saved node + session manifest. {msg1}"
1115
 
1116
  def on_load_session(state: Dict[str, Any], session_id: str):
1117
  if not session_id:
@@ -1121,7 +1148,6 @@ def on_load_session(state: Dict[str, Any], session_id: str):
1121
  if manifest is None:
1122
  return state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, msg
1123
 
1124
- # rebuild state from manifest (lazy-load images only for selected)
1125
  new_state = _init_state()
1126
  new_state["session_id"] = manifest.get("session_id") or session_id
1127
  new_state["created_at"] = manifest.get("created_at")
@@ -1146,7 +1172,6 @@ def on_load_session(state: Dict[str, Any], session_id: str):
1146
  else:
1147
  return new_state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, f"Loaded manifest but failed to load node images: {msg2}"
1148
 
1149
- # Also pre-load root images if root != selected (nice UX)
1150
  root = new_state.get("root_node_id")
1151
  if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
1152
  rl = int(nodes_meta[root].get("num_layers", 0))
@@ -1180,20 +1205,14 @@ def on_load_session(state: Dict[str, Any], session_id: str):
1180
  )
1181
 
1182
  def on_history_need_images(state: Dict[str, Any], node_id: str):
1183
- """
1184
- When switching nodes after loading manifest, images may not be loaded yet.
1185
- We'll lazy-load from dataset if missing.
1186
- """
1187
  if not node_id or node_id not in state.get("nodes", {}):
1188
  return state, "Unknown node."
1189
  imgs = state["nodes"][node_id].get("images", [])
1190
  if imgs:
1191
  return state, "OK"
1192
- # try load from dataset
1193
  session_id = state.get("session_id")
1194
  if not session_id:
1195
  return state, "No session_id."
1196
- # get num_layers from meta? we stored only in manifest; not guaranteed here. fallback: try read session.json again
1197
  manifest, msg = _load_session_manifest(session_id)
1198
  if not manifest:
1199
  return state, f"Cannot load manifest: {msg}"
@@ -1207,6 +1226,45 @@ def on_history_need_images(state: Dict[str, Any], node_id: str):
1207
  state["nodes"][node_id]["images"] = imgs2
1208
  return state, "Loaded images."
1209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1210
  # -------------------------
1211
  # Build UI
1212
  # -------------------------
@@ -1336,7 +1394,6 @@ It is not designed to control the semantic content of individual layers explicit
1336
 
1337
  with gr.Group():
1338
  gr.Markdown("### Refine (Recursive Decomposition)")
1339
- # refine params are NOT in Advanced Settings (per your UX request)
1340
  sub_layers = gr.Slider(
1341
  label="Sub-layers (Refine)",
1342
  minimum=2,
@@ -1384,7 +1441,6 @@ It is not designed to control the semantic content of individual layers explicit
1384
  export_zip = gr.File(label="Download ZIP")
1385
 
1386
  status = gr.Markdown("")
1387
-
1388
  seed_used = gr.Textbox(label="Seed used", value="", interactive=False)
1389
 
1390
  # Examples
@@ -1422,6 +1478,27 @@ It is not designed to control the semantic content of individual layers explicit
1422
  ],
1423
  )
1424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1425
  # Decompose
1426
  btn_decompose.click(
1427
  fn=on_decompose_click,
@@ -1510,8 +1587,7 @@ It is not designed to control the semantic content of individual layers explicit
1510
 
1511
  # History select (lazy load images first)
1512
  def _history_select_with_lazy(state, node_id):
1513
- state, msg = on_history_need_images(state, node_id)
1514
- # ignore msg for now; on_history_select will error if still missing
1515
  return on_history_select(state, node_id)
1516
 
1517
  history_dd.change(
 
226
  repo_id = ds_repo_id()
227
  try:
228
  files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
 
229
  sess = set()
230
  for p in files:
 
 
 
231
  if p.startswith("sessions/") and p.endswith("session.json"):
232
  parts = p.split("/")
233
  if len(parts) >= 3:
 
237
  except Exception as e:
238
  return [], f"List sessions failed: {repr(e)}"
239
 
240
+ def ds_get_root_index() -> Tuple[Optional[Dict[str, Any]], str]:
241
+ """
242
+ Read dataset root index.json (used for "last session").
243
+ Expected keys: id or last_session_id or session_id.
244
+ """
245
+ b, msg = ds_download_bytes("index.json")
246
+ if b is None:
247
+ return None, msg
248
+ try:
249
+ obj = json.loads(b.decode("utf-8"))
250
+ if not isinstance(obj, dict):
251
+ return None, "index.json is not an object"
252
+ return obj, "OK"
253
+ except Exception as e:
254
+ return None, f"Failed to parse index.json: {repr(e)}"
255
+
256
+ def ds_get_last_session_id() -> Tuple[Optional[str], str]:
257
+ idx, msg = ds_get_root_index()
258
+ if not idx:
259
+ return None, msg
260
+ for k in ("id", "last_session_id", "session_id"):
261
+ v = idx.get(k)
262
+ if isinstance(v, str) and v.strip():
263
+ return v.strip(), "OK"
264
+ return None, "index.json missing id/last_session_id/session_id"
265
+
266
+ def ds_set_root_index(session_id: str) -> Tuple[bool, str]:
267
+ """
268
+ Update dataset root index.json so we can auto-load the last session after refresh/restart.
269
+ """
270
+ payload = {
271
+ "id": session_id,
272
+ "last_session_id": session_id,
273
+ "updated_at": time.time(),
274
+ }
275
+ b = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8")
276
+ return ds_upload_bytes("index.json", b, f"update index.json {session_id}")
277
+
278
 
279
  # -------------------------
280
  # Node / History model
 
538
  }
539
  b = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")
540
  path = f"{_session_base(session_id)}/session.json"
541
+ ok_m, msg_m = ds_upload_bytes(path, b, f"save session manifest {session_id}")
542
+ if not ok_m:
543
+ return False, msg_m
544
+
545
+ # Update root index.json (best-effort, do not fail save if it can't be updated)
546
+ ok_i, msg_i = ds_set_root_index(session_id)
547
+ if not ok_i:
548
+ return True, f"{msg_m} (warning: index.json update failed: {msg_i})"
549
+
550
+ return True, f"{msg_m} + updated index.json"
551
 
552
  def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
553
  b, msg = ds_download_bytes(f"{_session_base(session_id)}/session.json")
 
594
 
595
  def on_refresh_sessions():
596
  sessions, msg = ds_list_sessions()
 
597
  return gr.update(choices=sessions, value=(sessions[0] if sessions else None)), msg
598
 
599
  def on_init_dataset():
 
725
  )
726
 
727
  def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
 
728
  node_id = state.get("selected_node_id")
729
  imgs = _get_node_images(state, node_id) if node_id else []
730
  n = len(imgs)
 
744
  n = len(imgs)
745
  idx = int(evt.index) if evt and evt.index is not None else 0
746
  idx = max(0, min(n - 1, idx)) if n > 0 else 0
 
747
  dd_value = f"Layer {idx+1}" if n > 0 else None
748
  return gr.update(value=idx), gr.update(value=dd_value), _layer_label(idx, n)
749
 
 
779
  neg_prompt=neg_prompt,
780
  true_guidance_scale=true_guidance_scale,
781
  num_inference_steps=num_inference_steps,
782
+ layer=sub_layers,
783
  cfg_norm=cfg_norm,
784
  use_en_prompt=use_en_prompt,
785
  resolution=resolution,
 
830
  gr.update(),
831
  gr.update(),
832
  gr.update(),
833
+ gr.update(),
834
  gr.update(visible=False),
835
  [],
836
  None,
 
865
 
866
  sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)
867
 
 
868
  layers_out, used_seed, settings_snapshot = _refine_from_source(
869
  state,
870
  source_node_id=source_node_id,
 
882
  randomize_seed=randomize_seed,
883
  )
884
 
 
885
  child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}"
886
  child_id = _add_node(
887
  state,
 
897
  state["selected_node_id"] = child_id
898
  state["last_refined_node_id"] = child_id
899
 
 
900
  n_layers = len(layers_out)
901
  layer_choices, layer_value = _build_layer_dropdown(n_layers)
902
  hist_choices = _history_choices(state)
903
  chips = _make_chips(state)
904
  selected_label = _layer_label(0, n_layers)
905
 
 
906
  pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
907
 
908
  status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {exp_msg}"
 
 
909
  refined_visible = gr.update(visible=True)
910
 
 
911
  return (
912
  state,
913
+ layers_out,
914
+ layers_out,
915
  gr.update(choices=layer_choices, value=layer_value),
916
  gr.update(value=0),
917
  selected_label,
918
  gr.update(choices=[c[1] for c in hist_choices], value=child_id),
919
  chips,
920
  refined_visible,
921
+ layers_out,
922
  pptx_path,
923
  zip_path,
924
  status,
925
+ gr.update(),
926
  )
927
 
928
  def on_history_select(state: Dict[str, Any], node_id: str):
 
953
 
954
  pptx_path, zip_path, exp_msg = _current_node_export(state, node_id)
955
 
 
956
  return (
957
  state,
958
  imgs,
 
985
  new_id = _duplicate_node(state, node_id)
986
  if not new_id:
987
  return state, gr.update(), "Duplicate failed."
 
988
  return on_history_select(state, new_id)
989
 
990
  def on_rename_node(state: Dict[str, Any], new_name: str):
 
992
  if not node_id:
993
  return state, gr.update(), "No selected node."
994
  _rename_node(state, node_id, new_name)
 
995
  hist_choices = _history_choices(state)
 
 
996
  chips = _make_chips(state)
997
  return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed."
998
 
 
1083
  randomize_seed=randomize_seed,
1084
  )
1085
 
 
1086
  child_name = f"redo refine ({state['nodes'][source_node_id]['meta']['name']}) L{int(source_layer_idx)+1}"
1087
  child_id = _add_node(
1088
  state,
 
1098
  state["selected_node_id"] = child_id
1099
  state["last_refined_node_id"] = child_id
1100
 
 
1101
  n_layers = len(layers_out)
1102
  layer_choices, layer_value = _build_layer_dropdown(n_layers)
1103
  hist_choices = _history_choices(state)
 
1138
  ok2, msg2 = _persist_session_manifest(state)
1139
  if not ok2:
1140
  return msg2
1141
+ return f"✅ Saved node + session manifest. {msg1} | {msg2}"
1142
 
1143
  def on_load_session(state: Dict[str, Any], session_id: str):
1144
  if not session_id:
 
1148
  if manifest is None:
1149
  return state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, msg
1150
 
 
1151
  new_state = _init_state()
1152
  new_state["session_id"] = manifest.get("session_id") or session_id
1153
  new_state["created_at"] = manifest.get("created_at")
 
1172
  else:
1173
  return new_state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=False), [], None, None, f"Loaded manifest but failed to load node images: {msg2}"
1174
 
 
1175
  root = new_state.get("root_node_id")
1176
  if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
1177
  rl = int(nodes_meta[root].get("num_layers", 0))
 
1205
  )
1206
 
1207
  def on_history_need_images(state: Dict[str, Any], node_id: str):
 
 
 
 
1208
  if not node_id or node_id not in state.get("nodes", {}):
1209
  return state, "Unknown node."
1210
  imgs = state["nodes"][node_id].get("images", [])
1211
  if imgs:
1212
  return state, "OK"
 
1213
  session_id = state.get("session_id")
1214
  if not session_id:
1215
  return state, "No session_id."
 
1216
  manifest, msg = _load_session_manifest(session_id)
1217
  if not manifest:
1218
  return state, f"Cannot load manifest: {msg}"
 
1226
  state["nodes"][node_id]["images"] = imgs2
1227
  return state, "Loaded images."
1228
 
1229
+ def on_autoload_last_session(state: Dict[str, Any]):
1230
+ if not isinstance(state, dict):
1231
+ state = _init_state()
1232
+ if not ds_enabled():
1233
+ return (
1234
+ state,
1235
+ gr.update(),
1236
+ gr.update(),
1237
+ gr.update(),
1238
+ gr.update(),
1239
+ gr.update(),
1240
+ gr.update(),
1241
+ gr.update(),
1242
+ gr.update(visible=False),
1243
+ [],
1244
+ None,
1245
+ None,
1246
+ "Dataset persistence disabled (no autoload).",
1247
+ )
1248
+ sid, msg = ds_get_last_session_id()
1249
+ if not sid:
1250
+ return (
1251
+ state,
1252
+ gr.update(),
1253
+ gr.update(),
1254
+ gr.update(),
1255
+ gr.update(),
1256
+ gr.update(),
1257
+ gr.update(),
1258
+ gr.update(),
1259
+ gr.update(visible=False),
1260
+ [],
1261
+ None,
1262
+ None,
1263
+ f"No last session to autoload: {msg}",
1264
+ )
1265
+ return on_load_session(state, sid)
1266
+
1267
+
1268
  # -------------------------
1269
  # Build UI
1270
  # -------------------------
 
1394
 
1395
  with gr.Group():
1396
  gr.Markdown("### Refine (Recursive Decomposition)")
 
1397
  sub_layers = gr.Slider(
1398
  label="Sub-layers (Refine)",
1399
  minimum=2,
 
1441
  export_zip = gr.File(label="Download ZIP")
1442
 
1443
  status = gr.Markdown("")
 
1444
  seed_used = gr.Textbox(label="Seed used", value="", interactive=False)
1445
 
1446
  # Examples
 
1478
  ],
1479
  )
1480
 
1481
+ # Auto-load last session (reads root index.json; supports id/last_session_id/session_id)
1482
+ demo.load(
1483
+ fn=on_autoload_last_session,
1484
+ inputs=[state],
1485
+ outputs=[
1486
+ state,
1487
+ gallery,
1488
+ layer_picker,
1489
+ layer_dropdown,
1490
+ selected_layer_idx,
1491
+ selected_layer_label,
1492
+ history_dd,
1493
+ chips_md,
1494
+ refined_block,
1495
+ refined_gallery,
1496
+ export_pptx,
1497
+ export_zip,
1498
+ status,
1499
+ ],
1500
+ )
1501
+
1502
  # Decompose
1503
  btn_decompose.click(
1504
  fn=on_decompose_click,
 
1587
 
1588
  # History select (lazy load images first)
1589
  def _history_select_with_lazy(state, node_id):
1590
+ state, _ = on_history_need_images(state, node_id)
 
1591
  return on_history_select(state, node_id)
1592
 
1593
  history_dd.change(