brian4dwell committed on
Commit
fbd0580
·
1 Parent(s): 4c075ec

load and save of ui settings

Browse files
.gitignore CHANGED
@@ -68,3 +68,4 @@ db.sqlite3-journal
68
  # Flask stuff:
69
  instance/
70
  .webassets-cache
 
 
68
  # Flask stuff:
69
  instance/
70
  .webassets-cache
71
+ stream3r/__pycache__/stream_session.cpython-311.pyc
app.py CHANGED
@@ -4,6 +4,7 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
7
  import os
8
  import cv2
9
  import torch
@@ -15,6 +16,7 @@ import glob
15
  import gc
16
  import time
17
  import zipfile
 
18
  from stream3r.models.stream3r import STream3R
19
  from stream3r.stream_session import StreamSession
20
  from stream3r.models.components.utils.load_fn import load_and_preprocess_images
@@ -140,6 +142,33 @@ def _copy_with_unique_name(src_path: str, dst_dir: str) -> str:
140
  return dest_path
141
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # -------------------------------------------------------------------------
144
  # 1) Core model inference
145
  # -------------------------------------------------------------------------
@@ -341,7 +370,8 @@ def update_gallery_on_upload(input_video, input_images, input_zip, session_state
341
  Handle any new uploads (video, images, or zip) and render preview.
342
  """
343
  if not input_video and not input_images and not input_zip and not session_state:
344
- return None, current_target_dir, None, None, None
 
345
 
346
  target_dir, image_paths, session_loaded = handle_uploads(
347
  input_video,
@@ -356,7 +386,90 @@ def update_gallery_on_upload(input_video, input_images, input_zip, session_state
356
  else:
357
  message = "Upload complete. Click 'Reconstruct' to begin 3D processing."
358
 
359
- return None, target_dir, image_paths, message, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
 
362
  def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir):
@@ -391,9 +504,7 @@ def gradio_demo(
391
 
392
  # Prepare frame_filter dropdown
393
  target_dir_images = os.path.join(target_dir, "images")
394
- all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
395
- all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
396
- frame_filter_choices = ["All"] + all_files
397
 
398
  print("Running run_model...")
399
  with torch.no_grad():
@@ -403,6 +514,25 @@ def gradio_demo(
403
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
404
  np.savez(prediction_save_path, **predictions)
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  session_state_file = None
407
  if streaming:
408
  if session_cache_path is None:
@@ -417,7 +547,7 @@ def gradio_demo(
417
  # Build a GLB file name
418
  glbfile = os.path.join(
419
  target_dir,
420
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb",
421
  )
422
 
423
  # Convert predictions to GLB
@@ -441,7 +571,8 @@ def gradio_demo(
441
 
442
  end_time = time.time()
443
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
444
- log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
 
445
 
446
  return (
447
  glbfile,
@@ -511,7 +642,7 @@ def update_visualization(
511
  loaded = np.load(predictions_path)
512
  predictions = {key: np.array(loaded[key]) for key in key_list}
513
 
514
- sanitized_frame = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_') if frame_filter else "All"
515
  glbfile = os.path.join(
516
  target_dir,
517
  f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb",
@@ -702,7 +833,7 @@ with gr.Blocks(
702
  streaming = gr.Radio(
703
  [('stream', True), ('batch', False)],
704
  label="Streaming or Batch Mode",
705
- value=False,
706
  scale=1,
707
  )
708
 
@@ -710,7 +841,7 @@ with gr.Blocks(
710
  mode = gr.Radio(
711
  ["causal", "window", "full"],
712
  label="Select Processing Mode",
713
- value="causal",
714
  scale=1,
715
  )
716
 
@@ -801,7 +932,22 @@ with gr.Blocks(
801
  mode,
802
  False,
803
  )
804
- return glbfile, log_msg, target_dir, dropdown, image_paths, session_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
805
 
806
  gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
807
 
@@ -828,6 +974,14 @@ with gr.Blocks(
828
  frame_filter,
829
  image_gallery,
830
  session_state_output,
 
 
 
 
 
 
 
 
831
  ],
832
  fn=example_pipeline,
833
  cache_examples=False,
@@ -981,7 +1135,22 @@ with gr.Blocks(
981
  # -------------------------------------------------------------------------
982
  # Auto-update gallery whenever user uploads or changes their files
983
  # -------------------------------------------------------------------------
984
- upload_outputs = [reconstruction_output, target_dir_output, image_gallery, log_output, session_state_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
  no_session_inputs = [input_video, input_images, input_zip, target_dir_output]
986
 
987
  input_video.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ import json
8
  import os
9
  import cv2
10
  import torch
 
16
  import gc
17
  import time
18
  import zipfile
19
+ from typing import Any, Dict, Optional
20
  from stream3r.models.stream3r import STream3R
21
  from stream3r.stream_session import StreamSession
22
  from stream3r.models.components.utils.load_fn import load_and_preprocess_images
 
142
  return dest_path
143
 
144
 
145
def load_session_settings(target_dir: str) -> Dict[str, Any]:
    """Load the saved UI settings for a session directory.

    Reads ``session_settings.json`` inside ``target_dir`` and returns its
    contents when the top-level JSON value is an object.

    Args:
        target_dir: Directory expected to contain ``session_settings.json``.

    Returns:
        The parsed settings dict, or an empty dict when the file is missing,
        cannot be read, is not valid UTF-8 / JSON, or does not hold a JSON
        object at the top level.
    """
    settings_path = os.path.join(target_dir, "session_settings.json")
    try:
        with open(settings_path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError:
        # A missing file is the normal "nothing saved yet" case: stay quiet.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError is the common base of json.JSONDecodeError and
        # UnicodeDecodeError, so a corrupt or mis-encoded file degrades to
        # "no settings" instead of raising out of the caller.
        print(f"Failed to load session settings from {settings_path}: {exc}")
        return {}
    # Valid JSON that is not an object (list, number, ...) is not usable.
    return data if isinstance(data, dict) else {}
157
+
158
+
159
def build_frame_filter_choices(target_dir_images: str) -> list[str]:
    """Build the choice list for the frame-filter dropdown.

    The list always begins with ``"All"``. When ``target_dir_images`` is an
    existing directory, one ``"<index>: <name>"`` entry follows for every
    directory entry, numbered from 0 in sorted-name order.

    Args:
        target_dir_images: Path of the directory whose entries are listed.

    Returns:
        ``["All"]`` plus the indexed entry labels, or just ``["All"]`` when
        the path is not a directory.
    """
    choices = ["All"]
    if os.path.isdir(target_dir_images):
        entries = sorted(os.listdir(target_dir_images))
        for position, entry in enumerate(entries):
            choices.append(f"{position}: {entry}")
    return choices
164
+
165
+
166
def sanitize_frame_filter_label(label: Optional[str]) -> str:
    """Turn a frame-filter label into a token usable inside a file name.

    Dots and spaces become underscores and colons are removed. A missing or
    empty label maps to ``"All"``.

    Args:
        label: The frame-filter label, or ``None`` / empty.

    Returns:
        The sanitized label, or ``"All"`` when ``label`` is falsy.
    """
    if label:
        # Single pass: '.' and ' ' -> '_', ':' deleted (None means "drop").
        substitutions = str.maketrans({".": "_", ":": None, " ": "_"})
        return label.translate(substitutions)
    return "All"
170
+
171
+
172
  # -------------------------------------------------------------------------
173
  # 1) Core model inference
174
  # -------------------------------------------------------------------------
 
370
  Handle any new uploads (video, images, or zip) and render preview.
371
  """
372
  if not input_video and not input_images and not input_zip and not session_state:
373
+ default_updates = [gr.update()] * 9
374
+ return (None, current_target_dir, None, None, None, *default_updates)
375
 
376
  target_dir, image_paths, session_loaded = handle_uploads(
377
  input_video,
 
386
  else:
387
  message = "Upload complete. Click 'Reconstruct' to begin 3D processing."
388
 
389
+ target_dir_images = os.path.join(target_dir, "images")
390
+ frame_filter_choices = build_frame_filter_choices(target_dir_images)
391
+ frame_value = "All"
392
+ frame_update = gr.update(choices=frame_filter_choices, value=frame_value)
393
+
394
+ streaming_update = gr.update()
395
+ mode_update = gr.update()
396
+ conf_update = gr.update()
397
+ mask_black_update = gr.update()
398
+ mask_white_update = gr.update()
399
+ show_cam_update = gr.update()
400
+ mask_sky_update = gr.update()
401
+ prediction_mode_update = gr.update()
402
+
403
+ reconstruction_value = None
404
+
405
+ if session_loaded:
406
+ settings = load_session_settings(target_dir)
407
+
408
+ if settings:
409
+ if "frame_filter" in settings:
410
+ potential_value = settings.get("frame_filter", "All")
411
+ if potential_value in frame_filter_choices:
412
+ frame_value = potential_value
413
+ frame_update = gr.update(choices=frame_filter_choices, value=frame_value)
414
+
415
+ if "streaming" in settings:
416
+ streaming_update = gr.update(value=bool(settings.get("streaming", True)))
417
+
418
+ if settings.get("mode") in {"causal", "window", "full"}:
419
+ mode_update = gr.update(value=settings["mode"])
420
+
421
+ if "conf_thres" in settings:
422
+ try:
423
+ conf_update = gr.update(value=float(settings["conf_thres"]))
424
+ except (TypeError, ValueError):
425
+ pass
426
+
427
+ if "mask_black_bg" in settings:
428
+ mask_black_update = gr.update(value=bool(settings.get("mask_black_bg", False)))
429
+ if "mask_white_bg" in settings:
430
+ mask_white_update = gr.update(value=bool(settings.get("mask_white_bg", False)))
431
+ if "show_cam" in settings:
432
+ show_cam_update = gr.update(value=bool(settings.get("show_cam", True)))
433
+ if "mask_sky" in settings:
434
+ mask_sky_update = gr.update(value=bool(settings.get("mask_sky", False)))
435
+
436
+ pred_mode_value = settings.get("prediction_mode")
437
+ if pred_mode_value in {"Depthmap and Camera Branch", "Pointmap Branch"}:
438
+ prediction_mode_update = gr.update(value=pred_mode_value)
439
+
440
+ try:
441
+ conf_val = settings["conf_thres"]
442
+ mode_val = settings["mode"]
443
+ pred_mode_val = settings["prediction_mode"]
444
+ mask_black_val = bool(settings.get("mask_black_bg", False))
445
+ mask_white_val = bool(settings.get("mask_white_bg", False))
446
+ show_cam_val = bool(settings.get("show_cam", True))
447
+ mask_sky_val = bool(settings.get("mask_sky", False))
448
+ glb_candidate = os.path.join(
449
+ target_dir,
450
+ f"glbscene_{conf_val}_{sanitize_frame_filter_label(frame_value)}_maskb{mask_black_val}_maskw{mask_white_val}_cam{show_cam_val}_sky{mask_sky_val}_pred{pred_mode_val.replace(' ', '_')}_mode{mode_val}.glb",
451
+ )
452
+ if os.path.exists(glb_candidate):
453
+ reconstruction_value = glb_candidate
454
+ except (KeyError, AttributeError):
455
+ pass
456
+
457
+ return (
458
+ reconstruction_value,
459
+ target_dir,
460
+ image_paths,
461
+ message,
462
+ None,
463
+ streaming_update,
464
+ mode_update,
465
+ conf_update,
466
+ frame_update,
467
+ mask_black_update,
468
+ mask_white_update,
469
+ show_cam_update,
470
+ mask_sky_update,
471
+ prediction_mode_update,
472
+ )
473
 
474
 
475
  def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir):
 
504
 
505
  # Prepare frame_filter dropdown
506
  target_dir_images = os.path.join(target_dir, "images")
507
+ frame_filter_choices = build_frame_filter_choices(target_dir_images)
 
 
508
 
509
  print("Running run_model...")
510
  with torch.no_grad():
 
514
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
515
  np.savez(prediction_save_path, **predictions)
516
 
517
+ frame_filter_value = frame_filter if frame_filter is not None else "All"
518
+
519
+ session_settings = {
520
+ "streaming": bool(streaming),
521
+ "mode": mode,
522
+ "conf_thres": float(conf_thres),
523
+ "frame_filter": frame_filter_value,
524
+ "mask_black_bg": bool(mask_black_bg),
525
+ "mask_white_bg": bool(mask_white_bg),
526
+ "show_cam": bool(show_cam),
527
+ "mask_sky": bool(mask_sky),
528
+ "prediction_mode": prediction_mode,
529
+ }
530
+ try:
531
+ with open(os.path.join(target_dir, "session_settings.json"), "w", encoding="utf-8") as handle:
532
+ json.dump(session_settings, handle, indent=2)
533
+ except OSError as exc:
534
+ print(f"Failed to write session settings: {exc}")
535
+
536
  session_state_file = None
537
  if streaming:
538
  if session_cache_path is None:
 
547
  # Build a GLB file name
548
  glbfile = os.path.join(
549
  target_dir,
550
+ f"glbscene_{conf_thres}_{sanitize_frame_filter_label(frame_filter)}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb",
551
  )
552
 
553
  # Convert predictions to GLB
 
571
 
572
  end_time = time.time()
573
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
574
+ num_frames = max(0, len(frame_filter_choices) - 1)
575
+ log_msg = f"Reconstruction Success ({num_frames} frames). Waiting for visualization."
576
 
577
  return (
578
  glbfile,
 
642
  loaded = np.load(predictions_path)
643
  predictions = {key: np.array(loaded[key]) for key in key_list}
644
 
645
+ sanitized_frame = sanitize_frame_filter_label(frame_filter)
646
  glbfile = os.path.join(
647
  target_dir,
648
  f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb",
 
833
  streaming = gr.Radio(
834
  [('stream', True), ('batch', False)],
835
  label="Streaming or Batch Mode",
836
+ value=True,
837
  scale=1,
838
  )
839
 
 
841
  mode = gr.Radio(
842
  ["causal", "window", "full"],
843
  label="Select Processing Mode",
844
+ value="window",
845
  scale=1,
846
  )
847
 
 
932
  mode,
933
  False,
934
  )
935
+ return (
936
+ glbfile,
937
+ log_msg,
938
+ target_dir,
939
+ dropdown,
940
+ image_paths,
941
+ session_file,
942
+ False,
943
+ mode,
944
+ conf_thres,
945
+ mask_black_bg,
946
+ mask_white_bg,
947
+ show_cam,
948
+ mask_sky,
949
+ prediction_mode,
950
+ )
951
 
952
  gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
953
 
 
974
  frame_filter,
975
  image_gallery,
976
  session_state_output,
977
+ streaming,
978
+ mode,
979
+ conf_thres,
980
+ mask_black_bg,
981
+ mask_white_bg,
982
+ show_cam,
983
+ mask_sky,
984
+ prediction_mode,
985
  ],
986
  fn=example_pipeline,
987
  cache_examples=False,
 
1135
  # -------------------------------------------------------------------------
1136
  # Auto-update gallery whenever user uploads or changes their files
1137
  # -------------------------------------------------------------------------
1138
+ upload_outputs = [
1139
+ reconstruction_output,
1140
+ target_dir_output,
1141
+ image_gallery,
1142
+ log_output,
1143
+ session_state_output,
1144
+ streaming,
1145
+ mode,
1146
+ conf_thres,
1147
+ frame_filter,
1148
+ mask_black_bg,
1149
+ mask_white_bg,
1150
+ show_cam,
1151
+ mask_sky,
1152
+ prediction_mode,
1153
+ ]
1154
  no_session_inputs = [input_video, input_images, input_zip, target_dir_output]
1155
 
1156
  input_video.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
requirements.txt CHANGED
@@ -41,6 +41,7 @@ scipy
41
  seaborn
42
  pyglet<2
43
  huggingface-hub[torch]>=0.22
 
44
 
45
  # --------- eval --------- #
46
  accelerate
 
41
  seaborn
42
  pyglet<2
43
  huggingface-hub[torch]>=0.22
44
+ spaces
45
 
46
  # --------- eval --------- #
47
  accelerate
stream3r/__pycache__/stream_session.cpython-311.pyc CHANGED
Binary files a/stream3r/__pycache__/stream_session.cpython-311.pyc and b/stream3r/__pycache__/stream_session.cpython-311.pyc differ