brian4dwell commited on
Commit
4c075ec
·
1 Parent(s): 6805b8e

add saving and reloading of session

Browse files
app.py CHANGED
@@ -64,7 +64,13 @@ def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
64
  if ext not in ALLOWED_IMG_EXT:
65
  continue
66
  # Construct final path safely
67
- dest_path = os.path.join(outdir, os.path.basename(name))
 
 
 
 
 
 
68
  # Zip-slip guard (in case filename has ../ etc.)
69
  if not _is_within_dir(outdir, dest_path):
70
  continue
@@ -74,19 +80,82 @@ def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
74
  return extracted
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # -------------------------------------------------------------------------
78
  # 1) Core model inference
79
  # -------------------------------------------------------------------------
80
  @spaces.GPU(duration=180) # triggers ZeroGPU allocation for this call
81
- def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> dict:
82
  """
83
- Run the STream3R model on images in the 'target_dir/images' folder and return predictions.
84
 
85
  Args:
86
  target_dir: Directory containing the images subfolder
87
  model: STream3R model instance
88
  mode: Processing mode ("causal", "window", or "full")
89
  streaming: If True, use StreamSession for sequential processing; if False, use batch processing
 
 
 
90
  """
91
  print(f"Processing images from {target_dir}")
92
 
@@ -113,6 +182,8 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
113
  print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
114
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
115
 
 
 
116
  with torch.no_grad():
117
  with torch.amp.autocast(dtype=dtype, device_type=device):
118
  if streaming:
@@ -123,12 +194,34 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
123
 
124
  session = StreamSession(model, mode=mode)
125
 
126
- # Process images one by one to simulate streaming inference
127
- for i in range(images.shape[0]):
128
- image = images[i : i + 1]
129
- predictions = session.forward_stream(image)
130
-
131
- session.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  else:
133
  # Use batch processing (original behavior)
134
  predictions = model(images, mode=mode)
@@ -153,19 +246,20 @@ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: b
153
 
154
  # Clean up
155
  torch.cuda.empty_cache()
156
- return predictions
157
 
158
 
159
  # -------------------------------------------------------------------------
160
  # 2) Handle uploaded video/images --> produce target_dir + images
161
  # -------------------------------------------------------------------------
162
- def handle_uploads(input_video, input_images, input_zip=None):
163
  """
164
  Create a new 'target_dir' + 'images' subfolder.
165
  - Copies uploaded images
166
  - Optionally extracts images from a ZIP
167
  - Optionally extracts frames from a video (1 fps)
168
- Returns (target_dir, image_paths).
 
169
  """
170
  start_time = time.time()
171
  gc.collect()
@@ -173,11 +267,23 @@ def handle_uploads(input_video, input_images, input_zip=None):
173
 
174
  # Create a unique folder name
175
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
176
- target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
177
- target_dir_images = os.path.join(target_dir, "images")
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
- if os.path.exists(target_dir):
180
- shutil.rmtree(target_dir)
181
  os.makedirs(target_dir_images, exist_ok=True)
182
 
183
  image_paths: list[str] = []
@@ -186,9 +292,8 @@ def handle_uploads(input_video, input_images, input_zip=None):
186
  if input_images:
187
  for file_data in input_images:
188
  file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data
189
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
190
- shutil.copy(file_path, dst_path)
191
- image_paths.append(dst_path)
192
 
193
  # --- Handle ZIP (extract images) ---
194
  if input_zip:
@@ -203,7 +308,7 @@ def handle_uploads(input_video, input_images, input_zip=None):
203
  fps = vs.get(cv2.CAP_PROP_FPS) or 30.0
204
  frame_interval = max(1, int(fps * 1)) # 1 frame/sec
205
  count = 0
206
- video_frame_num = 0
207
  while True:
208
  gotit, frame = vs.read()
209
  if not gotit:
@@ -218,23 +323,44 @@ def handle_uploads(input_video, input_images, input_zip=None):
218
 
219
  image_paths = sorted(set(image_paths)) # de-dupe + sort
220
 
 
 
 
 
221
  end_time = time.time()
222
  print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s")
223
- return target_dir, image_paths
224
 
225
 
226
 
227
  # -------------------------------------------------------------------------
228
  # 3) Update gallery on upload
229
  # -------------------------------------------------------------------------
230
- def update_gallery_on_upload(input_video, input_images, input_zip):
231
  """
232
  Handle any new uploads (video, images, or zip) and render preview.
233
  """
234
- if not input_video and not input_images and not input_zip:
235
- return None, None, None, None
236
- target_dir, image_paths = handle_uploads(input_video, input_images, input_zip)
237
- return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
 
240
 
@@ -271,12 +397,19 @@ def gradio_demo(
271
 
272
  print("Running run_model...")
273
  with torch.no_grad():
274
- predictions = run_model(target_dir, model, mode=mode, streaming=streaming)
275
 
276
  # Save predictions
277
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
278
  np.savez(prediction_save_path, **predictions)
279
 
 
 
 
 
 
 
 
280
  # Handle None frame_filter
281
  if frame_filter is None:
282
  frame_filter = "All"
@@ -310,7 +443,12 @@ def gradio_demo(
310
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
311
  log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
312
 
313
- return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
 
 
 
 
 
314
 
315
 
316
  # -------------------------------------------------------------------------
@@ -331,7 +469,16 @@ def update_log():
331
 
332
 
333
  def update_visualization(
334
- target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
 
 
 
 
 
 
 
 
 
335
  ):
336
  """
337
  Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
@@ -364,9 +511,10 @@ def update_visualization(
364
  loaded = np.load(predictions_path)
365
  predictions = {key: np.array(loaded[key]) for key in key_list}
366
 
 
367
  glbfile = os.path.join(
368
  target_dir,
369
- f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb",
370
  )
371
 
372
  if not os.path.exists(glbfile):
@@ -504,6 +652,7 @@ with gr.Blocks(
504
  input_video = gr.Video(label="Upload Video", interactive=True)
505
  input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
506
  input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
 
507
 
508
  image_gallery = gr.Gallery(
509
  label="Preview",
@@ -521,11 +670,22 @@ with gr.Blocks(
521
  "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
522
  )
523
  reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
 
524
 
525
  with gr.Row():
526
  submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
527
  clear_btn = gr.ClearButton(
528
- [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
 
 
 
 
 
 
 
 
 
 
529
  scale=1,
530
  )
531
 
@@ -626,13 +786,22 @@ with gr.Blocks(
626
  3) Return model3D + logs + new_dir + updated dropdown + gallery
627
  We do NOT return is_example. It's just an input.
628
  """
629
- target_dir, image_paths = handle_uploads(input_video, input_images)
630
  # Always use "All" for frame_filter in examples
631
  frame_filter = "All"
632
- glbfile, log_msg, dropdown = gradio_demo(
633
- target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, mode
 
 
 
 
 
 
 
 
 
634
  )
635
- return glbfile, log_msg, target_dir, dropdown, image_paths
636
 
637
  gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
638
 
@@ -652,7 +821,14 @@ with gr.Blocks(
652
  is_example,
653
  mode,
654
  ],
655
- outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
 
 
 
 
 
 
 
656
  fn=example_pipeline,
657
  cache_examples=False,
658
  examples_per_page=50,
@@ -681,7 +857,7 @@ with gr.Blocks(
681
  mode,
682
  streaming,
683
  ],
684
- outputs=[reconstruction_output, log_output, frame_filter],
685
  ).then(
686
  fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
687
  )
@@ -700,6 +876,7 @@ with gr.Blocks(
700
  show_cam,
701
  mask_sky,
702
  prediction_mode,
 
703
  is_example,
704
  ],
705
  [reconstruction_output, log_output],
@@ -715,6 +892,7 @@ with gr.Blocks(
715
  show_cam,
716
  mask_sky,
717
  prediction_mode,
 
718
  is_example,
719
  ],
720
  [reconstruction_output, log_output],
@@ -730,6 +908,7 @@ with gr.Blocks(
730
  show_cam,
731
  mask_sky,
732
  prediction_mode,
 
733
  is_example,
734
  ],
735
  [reconstruction_output, log_output],
@@ -745,6 +924,7 @@ with gr.Blocks(
745
  show_cam,
746
  mask_sky,
747
  prediction_mode,
 
748
  is_example,
749
  ],
750
  [reconstruction_output, log_output],
@@ -760,6 +940,7 @@ with gr.Blocks(
760
  show_cam,
761
  mask_sky,
762
  prediction_mode,
 
763
  is_example,
764
  ],
765
  [reconstruction_output, log_output],
@@ -775,6 +956,7 @@ with gr.Blocks(
775
  show_cam,
776
  mask_sky,
777
  prediction_mode,
 
778
  is_example,
779
  ],
780
  [reconstruction_output, log_output],
@@ -790,6 +972,7 @@ with gr.Blocks(
790
  show_cam,
791
  mask_sky,
792
  prediction_mode,
 
793
  is_example,
794
  ],
795
  [reconstruction_output, log_output],
@@ -798,20 +981,16 @@ with gr.Blocks(
798
  # -------------------------------------------------------------------------
799
  # Auto-update gallery whenever user uploads or changes their files
800
  # -------------------------------------------------------------------------
801
- input_video.change(
802
- fn=update_gallery_on_upload,
803
- inputs=[input_video, input_images, input_zip],
804
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
805
- )
806
- input_images.change(
807
- fn=update_gallery_on_upload,
808
- inputs=[input_video, input_images, input_zip],
809
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
810
- )
811
- input_zip.change(
812
  fn=update_gallery_on_upload,
813
- inputs=[input_video, input_images, input_zip],
814
- outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
815
  )
816
 
817
  demo.queue(max_size=20).launch(show_error=True, share=False)
 
64
  if ext not in ALLOWED_IMG_EXT:
65
  continue
66
  # Construct final path safely
67
+ base_name = os.path.basename(name)
68
+ name_root, name_ext = os.path.splitext(base_name)
69
+ dest_path = os.path.join(outdir, base_name)
70
+ counter = 1
71
+ while os.path.exists(dest_path):
72
+ dest_path = os.path.join(outdir, f"{name_root}_{counter}{name_ext}")
73
+ counter += 1
74
  # Zip-slip guard (in case filename has ../ etc.)
75
  if not _is_within_dir(outdir, dest_path):
76
  continue
 
80
  return extracted
81
 
82
 
83
+ def extract_session_state(zip_path: str, extract_root: str) -> str:
84
+ """Extract a previously saved session archive into *extract_root*.
85
+
86
+ Returns the directory that contains the restored session data.
87
+ """
88
+ if os.path.exists(extract_root):
89
+ shutil.rmtree(extract_root)
90
+ os.makedirs(extract_root, exist_ok=True)
91
+
92
+ with zipfile.ZipFile(zip_path, "r") as zf:
93
+ zf.extractall(extract_root)
94
+
95
+ entries = [os.path.join(extract_root, entry) for entry in os.listdir(extract_root)]
96
+ dirs = [entry for entry in entries if os.path.isdir(entry)]
97
+ files = [entry for entry in entries if os.path.isfile(entry)]
98
+
99
+ if len(dirs) == 1 and not files:
100
+ return dirs[0]
101
+ return extract_root
102
+
103
+
104
+ def package_session_state(target_dir: str) -> str:
105
+ """Create a zip archive containing the entire session directory."""
106
+ if not os.path.isdir(target_dir):
107
+ raise ValueError(f"Target directory does not exist: {target_dir}")
108
+
109
+ os.makedirs("demo_cache", exist_ok=True)
110
+ archive_name = f"{os.path.basename(os.path.normpath(target_dir))}_session.zip"
111
+ archive_path = os.path.join("demo_cache", archive_name)
112
+
113
+ if os.path.exists(archive_path):
114
+ os.remove(archive_path)
115
+
116
+ with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
117
+ for root, _, files in os.walk(target_dir):
118
+ for fname in files:
119
+ file_path = os.path.join(root, fname)
120
+ if os.path.abspath(file_path) == os.path.abspath(archive_path):
121
+ continue
122
+ arcname = os.path.join(os.path.basename(target_dir), os.path.relpath(file_path, target_dir))
123
+ zf.write(file_path, arcname)
124
+
125
+ return archive_path
126
+
127
+
128
+ def _copy_with_unique_name(src_path: str, dst_dir: str) -> str:
129
+ """Copy *src_path* into *dst_dir*, avoiding filename collisions."""
130
+ base_name = os.path.basename(src_path)
131
+ name, ext = os.path.splitext(base_name)
132
+ candidate = base_name
133
+ counter = 1
134
+ dest_path = os.path.join(dst_dir, candidate)
135
+ while os.path.exists(dest_path):
136
+ candidate = f"{name}_{counter}{ext}"
137
+ dest_path = os.path.join(dst_dir, candidate)
138
+ counter += 1
139
+ shutil.copy(src_path, dest_path)
140
+ return dest_path
141
+
142
+
143
  # -------------------------------------------------------------------------
144
  # 1) Core model inference
145
  # -------------------------------------------------------------------------
146
  @spaces.GPU(duration=180) # triggers ZeroGPU allocation for this call
147
+ def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> tuple[dict, str | None]:
148
  """
149
+ Run the STream3R model on images in the 'target_dir/images' folder.
150
 
151
  Args:
152
  target_dir: Directory containing the images subfolder
153
  model: STream3R model instance
154
  mode: Processing mode ("causal", "window", or "full")
155
  streaming: If True, use StreamSession for sequential processing; if False, use batch processing
156
+ Returns:
157
+ tuple[dict, str | None]: Predictions dictionary and optional path to the saved session cache when
158
+ streaming mode is used.
159
  """
160
  print(f"Processing images from {target_dir}")
161
 
 
182
  print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
183
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
184
 
185
+ session_cache_path: str | None = None
186
+
187
  with torch.no_grad():
188
  with torch.amp.autocast(dtype=dtype, device_type=device):
189
  if streaming:
 
194
 
195
  session = StreamSession(model, mode=mode)
196
 
197
+ kv_cache_path = os.path.join(target_dir, "kv_cache.pt")
198
+ if os.path.exists(kv_cache_path):
199
+ print(f"Loading existing session cache from {kv_cache_path}")
200
+ session.load_cache(kv_cache_path, device=images.device)
201
+
202
+ existing_predictions = session.get_all_predictions()
203
+ existing_frames = 0
204
+ for value in existing_predictions.values():
205
+ if isinstance(value, torch.Tensor) and value.dim() >= 2:
206
+ existing_frames = max(existing_frames, value.shape[1])
207
+
208
+ total_frames = images.shape[0]
209
+ if existing_frames > total_frames:
210
+ raise ValueError(
211
+ "Session cache contains more frames than available images. Please ensure the images folder "
212
+ "matches the saved session state."
213
+ )
214
+
215
+ if existing_frames == total_frames:
216
+ print("No new frames detected; reusing cached predictions.")
217
+ else:
218
+ for i in range(existing_frames, total_frames):
219
+ image = images[i : i + 1]
220
+ session.forward_stream(image)
221
+
222
+ predictions = session.get_all_predictions()
223
+ session.save_cache(kv_cache_path)
224
+ session_cache_path = kv_cache_path
225
  else:
226
  # Use batch processing (original behavior)
227
  predictions = model(images, mode=mode)
 
246
 
247
  # Clean up
248
  torch.cuda.empty_cache()
249
+ return predictions, session_cache_path
250
 
251
 
252
  # -------------------------------------------------------------------------
253
  # 2) Handle uploaded video/images --> produce target_dir + images
254
  # -------------------------------------------------------------------------
255
+ def handle_uploads(input_video, input_images, input_zip=None, session_state=None, current_target_dir: str | None = None):
256
  """
257
  Create a new 'target_dir' + 'images' subfolder.
258
  - Copies uploaded images
259
  - Optionally extracts images from a ZIP
260
  - Optionally extracts frames from a video (1 fps)
261
+ - Optionally loads a previously saved session archive
262
+ Returns (target_dir, image_paths, session_loaded).
263
  """
264
  start_time = time.time()
265
  gc.collect()
 
267
 
268
  # Create a unique folder name
269
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
270
+ session_loaded = False
271
+
272
+ if session_state:
273
+ session_path = session_state.get("name") if isinstance(session_state, dict) and "name" in session_state else getattr(session_state, "name", None)
274
+ session_path = session_path or session_state
275
+ extract_root = os.path.join("demo_cache", f"session_{timestamp}")
276
+ target_dir = extract_session_state(session_path, extract_root)
277
+ session_loaded = True
278
+ elif current_target_dir and os.path.isdir(current_target_dir):
279
+ target_dir = current_target_dir
280
+ else:
281
+ target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
282
+ if os.path.exists(target_dir):
283
+ shutil.rmtree(target_dir)
284
+ os.makedirs(target_dir, exist_ok=True)
285
 
286
+ target_dir_images = os.path.join(target_dir, "images")
 
287
  os.makedirs(target_dir_images, exist_ok=True)
288
 
289
  image_paths: list[str] = []
 
292
  if input_images:
293
  for file_data in input_images:
294
  file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data
295
+ copied_path = _copy_with_unique_name(file_path, target_dir_images)
296
+ image_paths.append(copied_path)
 
297
 
298
  # --- Handle ZIP (extract images) ---
299
  if input_zip:
 
308
  fps = vs.get(cv2.CAP_PROP_FPS) or 30.0
309
  frame_interval = max(1, int(fps * 1)) # 1 frame/sec
310
  count = 0
311
+ video_frame_num = len(os.listdir(target_dir_images))
312
  while True:
313
  gotit, frame = vs.read()
314
  if not gotit:
 
323
 
324
  image_paths = sorted(set(image_paths)) # de-dupe + sort
325
 
326
+ # Ensure gallery reflects existing files in the images directory
327
+ existing_images = sorted(glob.glob(os.path.join(target_dir_images, "*")))
328
+ image_paths = existing_images
329
+
330
  end_time = time.time()
331
  print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s")
332
+ return target_dir, image_paths, session_loaded
333
 
334
 
335
 
336
  # -------------------------------------------------------------------------
337
  # 3) Update gallery on upload
338
  # -------------------------------------------------------------------------
339
+ def update_gallery_on_upload(input_video, input_images, input_zip, session_state, current_target_dir):
340
  """
341
  Handle any new uploads (video, images, or zip) and render preview.
342
  """
343
+ if not input_video and not input_images and not input_zip and not session_state:
344
+ return None, current_target_dir, None, None, None
345
+
346
+ target_dir, image_paths, session_loaded = handle_uploads(
347
+ input_video,
348
+ input_images,
349
+ input_zip,
350
+ session_state=session_state,
351
+ current_target_dir=current_target_dir,
352
+ )
353
+
354
+ if session_loaded:
355
+ message = "Session state loaded. Add new frames and click 'Reconstruct' to continue."
356
+ else:
357
+ message = "Upload complete. Click 'Reconstruct' to begin 3D processing."
358
+
359
+ return None, target_dir, image_paths, message, None
360
+
361
+
362
+ def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir):
363
+ return update_gallery_on_upload(input_video, input_images, input_zip, None, current_target_dir)
364
 
365
 
366
 
 
397
 
398
  print("Running run_model...")
399
  with torch.no_grad():
400
+ predictions, session_cache_path = run_model(target_dir, model, mode=mode, streaming=streaming)
401
 
402
  # Save predictions
403
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
404
  np.savez(prediction_save_path, **predictions)
405
 
406
+ session_state_file = None
407
+ if streaming:
408
+ if session_cache_path is None:
409
+ session_cache_path = os.path.join(target_dir, "kv_cache.pt")
410
+ if os.path.exists(session_cache_path):
411
+ session_state_file = package_session_state(target_dir)
412
+
413
  # Handle None frame_filter
414
  if frame_filter is None:
415
  frame_filter = "All"
 
443
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
444
  log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
445
 
446
+ return (
447
+ glbfile,
448
+ log_msg,
449
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
450
+ session_state_file,
451
+ )
452
 
453
 
454
  # -------------------------------------------------------------------------
 
469
 
470
 
471
  def update_visualization(
472
+ target_dir,
473
+ conf_thres,
474
+ frame_filter,
475
+ mask_black_bg,
476
+ mask_white_bg,
477
+ show_cam,
478
+ mask_sky,
479
+ prediction_mode,
480
+ mode_value,
481
+ is_example,
482
  ):
483
  """
484
  Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
 
511
  loaded = np.load(predictions_path)
512
  predictions = {key: np.array(loaded[key]) for key in key_list}
513
 
514
+ sanitized_frame = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_') if frame_filter else "All"
515
  glbfile = os.path.join(
516
  target_dir,
517
+ f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb",
518
  )
519
 
520
  if not os.path.exists(glbfile):
 
652
  input_video = gr.Video(label="Upload Video", interactive=True)
653
  input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
654
  input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
655
+ session_state_input = gr.File(file_types=[".zip"], label="Load Session State", interactive=True)
656
 
657
  image_gallery = gr.Gallery(
658
  label="Preview",
 
670
  "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
671
  )
672
  reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
673
+ session_state_output = gr.File(label="Download Session State", interactive=False)
674
 
675
  with gr.Row():
676
  submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
677
  clear_btn = gr.ClearButton(
678
+ [
679
+ input_video,
680
+ input_images,
681
+ input_zip,
682
+ session_state_input,
683
+ reconstruction_output,
684
+ log_output,
685
+ target_dir_output,
686
+ image_gallery,
687
+ session_state_output,
688
+ ],
689
  scale=1,
690
  )
691
 
 
786
  3) Return model3D + logs + new_dir + updated dropdown + gallery
787
  We do NOT return is_example. It's just an input.
788
  """
789
+ target_dir, image_paths, _ = handle_uploads(input_video, input_images)
790
  # Always use "All" for frame_filter in examples
791
  frame_filter = "All"
792
+ glbfile, log_msg, dropdown, session_file = gradio_demo(
793
+ target_dir,
794
+ conf_thres,
795
+ frame_filter,
796
+ mask_black_bg,
797
+ mask_white_bg,
798
+ show_cam,
799
+ mask_sky,
800
+ prediction_mode,
801
+ mode,
802
+ False,
803
  )
804
+ return glbfile, log_msg, target_dir, dropdown, image_paths, session_file
805
 
806
  gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
807
 
 
821
  is_example,
822
  mode,
823
  ],
824
+ outputs=[
825
+ reconstruction_output,
826
+ log_output,
827
+ target_dir_output,
828
+ frame_filter,
829
+ image_gallery,
830
+ session_state_output,
831
+ ],
832
  fn=example_pipeline,
833
  cache_examples=False,
834
  examples_per_page=50,
 
857
  mode,
858
  streaming,
859
  ],
860
+ outputs=[reconstruction_output, log_output, frame_filter, session_state_output],
861
  ).then(
862
  fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
863
  )
 
876
  show_cam,
877
  mask_sky,
878
  prediction_mode,
879
+ mode,
880
  is_example,
881
  ],
882
  [reconstruction_output, log_output],
 
892
  show_cam,
893
  mask_sky,
894
  prediction_mode,
895
+ mode,
896
  is_example,
897
  ],
898
  [reconstruction_output, log_output],
 
908
  show_cam,
909
  mask_sky,
910
  prediction_mode,
911
+ mode,
912
  is_example,
913
  ],
914
  [reconstruction_output, log_output],
 
924
  show_cam,
925
  mask_sky,
926
  prediction_mode,
927
+ mode,
928
  is_example,
929
  ],
930
  [reconstruction_output, log_output],
 
940
  show_cam,
941
  mask_sky,
942
  prediction_mode,
943
+ mode,
944
  is_example,
945
  ],
946
  [reconstruction_output, log_output],
 
956
  show_cam,
957
  mask_sky,
958
  prediction_mode,
959
+ mode,
960
  is_example,
961
  ],
962
  [reconstruction_output, log_output],
 
972
  show_cam,
973
  mask_sky,
974
  prediction_mode,
975
+ mode,
976
  is_example,
977
  ],
978
  [reconstruction_output, log_output],
 
981
  # -------------------------------------------------------------------------
982
  # Auto-update gallery whenever user uploads or changes their files
983
  # -------------------------------------------------------------------------
984
+ upload_outputs = [reconstruction_output, target_dir_output, image_gallery, log_output, session_state_output]
985
+ no_session_inputs = [input_video, input_images, input_zip, target_dir_output]
986
+
987
+ input_video.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
988
+ input_images.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
989
+ input_zip.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs)
990
+ session_state_input.change(
 
 
 
 
991
  fn=update_gallery_on_upload,
992
+ inputs=[input_video, input_images, input_zip, session_state_input, target_dir_output],
993
+ outputs=upload_outputs,
994
  )
995
 
996
  demo.queue(max_size=20).launch(show_error=True, share=False)
configs/stream_session.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "window_size": 25
3
+ }
stream3r/stream_session.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from typing import Any, Dict, Optional
3
 
@@ -9,12 +10,13 @@ class StreamSession:
9
  """
10
  A causal streaming inference session with KV cache management for STream3R.
11
  """
12
- def __init__(self, model: STream3R, mode: str):
13
  self.model = model
14
  self.mode = mode
15
  self.aggregator_kv_cache_depth = model.aggregator.depth
16
  self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
17
  self.camera_head_iterations = 4
 
18
 
19
  if self.mode not in ["causal", "window"]:
20
  raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
@@ -41,13 +43,12 @@ class StreamSession:
41
  self.aggregator_kv_cache_list = aggregator_kv_cache_list
42
  self.camera_head_kv_cache_list = camera_head_kv_cache_list
43
  elif self.mode == "window":
44
- window_size = 25
45
  for k in range(2):
46
  for i in range(self.aggregator_kv_cache_depth):
47
  h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
48
  P = h * w // self.model.aggregator.patch_size // self.model.aggregator.patch_size + self.model.aggregator.patch_start_idx
49
  anchor_token = aggregator_kv_cache_list[i][k][:, :, :P]
50
- window_tokens = aggregator_kv_cache_list[i][k][:, :, max(P, aggregator_kv_cache_list[i][k].size(2)-window_size*P):]
51
  self.aggregator_kv_cache_list[i][k] = torch.cat(
52
  [
53
  anchor_token,
@@ -58,7 +59,7 @@ class StreamSession:
58
  for i in range(self.camera_head_iterations):
59
  for j in range(self.camera_head_kv_cache_depth):
60
  anchor_token = camera_head_kv_cache_list[i][j][k][:, :, :1]
61
- window_tokens = camera_head_kv_cache_list[i][j][k][:, :, max(1, camera_head_kv_cache_list[i][j][k].size(2)-window_size):]
62
  self.camera_head_kv_cache_list[i][j][k] = torch.cat(
63
  [
64
  anchor_token,
@@ -112,6 +113,32 @@ class StreamSession:
112
  except StopIteration:
113
  return torch.device("cpu")
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def save_cache(self, file_path: str) -> None:
116
  aggregator_cache, camera_cache = self._get_cache()
117
 
@@ -121,6 +148,7 @@ class StreamSession:
121
  "aggregator_depth": self.aggregator_kv_cache_depth,
122
  "camera_head_depth": self.camera_head_kv_cache_depth,
123
  "camera_head_iterations": self.camera_head_iterations,
 
124
  "patch_size": getattr(self.model.aggregator, "patch_size", None),
125
  "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
126
  },
@@ -148,6 +176,7 @@ class StreamSession:
148
  "aggregator_depth": self.aggregator_kv_cache_depth,
149
  "camera_head_depth": self.camera_head_kv_cache_depth,
150
  "camera_head_iterations": self.camera_head_iterations,
 
151
  }
152
 
153
  for key, expected_value in expected_metadata.items():
 
1
+ import json
2
  import os
3
  from typing import Any, Dict, Optional
4
 
 
10
  """
11
  A causal streaming inference session with KV cache management for STream3R.
12
  """
13
+ def __init__(self, model: STream3R, mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
14
  self.model = model
15
  self.mode = mode
16
  self.aggregator_kv_cache_depth = model.aggregator.depth
17
  self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
18
  self.camera_head_iterations = 4
19
+ self.window_size = self._resolve_window_size(window_size, config_path)
20
 
21
  if self.mode not in ["causal", "window"]:
22
  raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
 
43
  self.aggregator_kv_cache_list = aggregator_kv_cache_list
44
  self.camera_head_kv_cache_list = camera_head_kv_cache_list
45
  elif self.mode == "window":
 
46
  for k in range(2):
47
  for i in range(self.aggregator_kv_cache_depth):
48
  h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
49
  P = h * w // self.model.aggregator.patch_size // self.model.aggregator.patch_size + self.model.aggregator.patch_start_idx
50
  anchor_token = aggregator_kv_cache_list[i][k][:, :, :P]
51
+ window_tokens = aggregator_kv_cache_list[i][k][:, :, max(P, aggregator_kv_cache_list[i][k].size(2)-self.window_size*P):]
52
  self.aggregator_kv_cache_list[i][k] = torch.cat(
53
  [
54
  anchor_token,
 
59
  for i in range(self.camera_head_iterations):
60
  for j in range(self.camera_head_kv_cache_depth):
61
  anchor_token = camera_head_kv_cache_list[i][j][k][:, :, :1]
62
+ window_tokens = camera_head_kv_cache_list[i][j][k][:, :, max(1, camera_head_kv_cache_list[i][j][k].size(2)-self.window_size):]
63
  self.camera_head_kv_cache_list[i][j][k] = torch.cat(
64
  [
65
  anchor_token,
 
113
  except StopIteration:
114
  return torch.device("cpu")
115
 
116
+ def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
117
+ if override is not None:
118
+ return override
119
+
120
+ config_path = config_path or os.path.abspath(
121
+ os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
122
+ )
123
+
124
+ default_window_size = 25
125
+
126
+ if not os.path.exists(config_path):
127
+ return default_window_size
128
+
129
+ try:
130
+ with open(config_path, "r", encoding="utf-8") as handle:
131
+ data = json.load(handle)
132
+ except (json.JSONDecodeError, OSError):
133
+ return default_window_size
134
+
135
+ window_size = data.get("window_size")
136
+
137
+ if isinstance(window_size, int) and window_size > 0:
138
+ return window_size
139
+
140
+ return default_window_size
141
+
142
  def save_cache(self, file_path: str) -> None:
143
  aggregator_cache, camera_cache = self._get_cache()
144
 
 
148
  "aggregator_depth": self.aggregator_kv_cache_depth,
149
  "camera_head_depth": self.camera_head_kv_cache_depth,
150
  "camera_head_iterations": self.camera_head_iterations,
151
+ "window_size": self.window_size,
152
  "patch_size": getattr(self.model.aggregator, "patch_size", None),
153
  "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
154
  },
 
176
  "aggregator_depth": self.aggregator_kv_cache_depth,
177
  "camera_head_depth": self.camera_head_kv_cache_depth,
178
  "camera_head_iterations": self.camera_head_iterations,
179
+ "window_size": self.window_size,
180
  }
181
 
182
  for key, expected_value in expected_metadata.items():
tests/test_stream_session_cache.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import tempfile
3
  import unittest
@@ -101,6 +102,22 @@ else:
101
  restored_tensor = restored_session.predictions[key]
102
  self.assertTrue(torch.equal(original_tensor, restored_tensor))
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  if __name__ == "__main__": # pragma: no cover - manual execution
106
  unittest.main()
 
1
+ import json
2
  import os
3
  import tempfile
4
  import unittest
 
102
  restored_tensor = restored_session.predictions[key]
103
  self.assertTrue(torch.equal(original_tensor, restored_tensor))
104
 
105
+ def test_window_size_from_config(self):
106
+ model = _DummyModel()
107
+ with tempfile.TemporaryDirectory() as tmpdir:
108
+ config_path = os.path.join(tmpdir, "stream_session.json")
109
+ with open(config_path, "w", encoding="utf-8") as handle:
110
+ json.dump({"window_size": 7}, handle)
111
+
112
+ session = StreamSession(model, mode="window", config_path=config_path)
113
+
114
+ self.assertEqual(session.window_size, 7)
115
+
116
+ def test_window_size_override(self):
117
+ model = _DummyModel()
118
+ session = StreamSession(model, mode="window", window_size=11)
119
+ self.assertEqual(session.window_size, 11)
120
+
121
 
122
  if __name__ == "__main__": # pragma: no cover - manual execution
123
  unittest.main()