brian4dwell committed on
Commit
79b8fec
·
1 Parent(s): 50cb28d

top_k selection

Browse files
Files changed (3) hide show
  1. .vscode/launch.json +2 -2
  2. app.py +181 -20
  3. configs/stream_session.json +2 -1
.vscode/launch.json CHANGED
@@ -3,10 +3,10 @@
3
  "configurations": [
4
 
5
  {
6
- "name": "Python: Current File (venv)",
7
  "type": "debugpy",
8
  "request": "launch",
9
- "program": "${file}",
10
  "console": "integratedTerminal",
11
  "cwd": "${workspaceFolder}",
12
  "envFile": "${workspaceFolder}/.env",
 
3
  "configurations": [
4
 
5
  {
6
+ "name": "Python: UI (conda)",
7
  "type": "debugpy",
8
  "request": "launch",
9
+ "program": "${workspaceFolder}/app.py",
10
  "console": "integratedTerminal",
11
  "cwd": "${workspaceFolder}",
12
  "envFile": "${workspaceFolder}/.env",
app.py CHANGED
@@ -17,6 +17,7 @@ import glob
17
  import gc
18
  import time
19
  import zipfile
 
20
  from typing import Any, Dict, Optional
21
  from stream3r.models.stream3r import STream3R
22
  from stream3r.stream_session import StreamSession
@@ -153,6 +154,21 @@ def _resolve_path(file_data) -> Optional[str]:
153
  return str(file_data)
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def load_session_settings(target_dir: str) -> Dict[str, Any]:
157
  settings_path = os.path.join(target_dir, "session_settings.json")
158
  if not os.path.exists(settings_path):
@@ -180,6 +196,130 @@ def sanitize_frame_filter_label(label: Optional[str]) -> str:
180
  return label.replace('.', '_').replace(':', '').replace(' ', '_')
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # -------------------------------------------------------------------------
184
  # 1) Core model inference
185
  # -------------------------------------------------------------------------
@@ -551,22 +691,22 @@ def localize_new_image(
551
  session.clear()
552
 
553
  try:
554
- with torch.no_grad():
555
- with torch.amp.autocast(dtype=image_tensor.dtype, device_type=image_tensor.device.type):
556
- session.load_cache(kv_cache_path, device=image_tensor.device)
557
 
558
- existing_predictions = session.get_all_predictions()
559
- existing_frames = 0
560
- for value in existing_predictions.values():
561
- if isinstance(value, torch.Tensor) and value.dim() >= 2:
562
- existing_frames = max(existing_frames, value.shape[1])
563
 
564
- session.forward_stream(image_tensor)
 
565
 
566
- localized_predictions = session.get_all_predictions()
567
  except Exception as exc:
568
  session.clear()
569
- torch.cuda.empty_cache()
 
570
  return (f"Localization failed: {exc}", gr.update())
571
 
572
  def _extract_frame(tensor: torch.Tensor, index: int) -> np.ndarray:
@@ -706,7 +846,8 @@ def localize_new_image(
706
  summary_lines.append(f"Warning: failed to update GLB preview ({exc})")
707
 
708
  session.clear()
709
- torch.cuda.empty_cache()
 
710
 
711
  return ("\n".join(summary_lines), localization_glb_path if localization_glb_path else gr.update())
712
 
@@ -740,6 +881,8 @@ def gradio_demo(
740
  # Prepare frame_filter dropdown
741
  target_dir_images = os.path.join(target_dir, "images")
742
  frame_filter_choices = build_frame_filter_choices(target_dir_images)
 
 
743
 
744
  print("Running run_model...")
745
  with torch.no_grad():
@@ -749,6 +892,20 @@ def gradio_demo(
749
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
750
  np.savez(prediction_save_path, **predictions)
751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  frame_filter_value = frame_filter if frame_filter is not None else "All"
753
 
754
  session_settings = {
@@ -762,6 +919,9 @@ def gradio_demo(
762
  "mask_sky": bool(mask_sky),
763
  "prediction_mode": prediction_mode,
764
  }
 
 
 
765
  try:
766
  with open(os.path.join(target_dir, "session_settings.json"), "w", encoding="utf-8") as handle:
767
  json.dump(session_settings, handle, indent=2)
@@ -1039,6 +1199,14 @@ with gr.Blocks(
1039
  session_state_output = gr.File(label="Download Session State", interactive=False)
1040
  localization_output = gr.Textbox(label="Localization Result", lines=8, interactive=False)
1041
 
 
 
 
 
 
 
 
 
1042
  with gr.Row():
1043
  submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
1044
  clear_btn = gr.ClearButton(
@@ -1047,6 +1215,7 @@ with gr.Blocks(
1047
  input_images,
1048
  input_zip,
1049
  session_state_input,
 
1050
  reconstruction_output,
1051
  log_output,
1052
  target_dir_output,
@@ -1091,14 +1260,6 @@ with gr.Blocks(
1091
  mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
1092
  mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
1093
 
1094
- with gr.Row():
1095
- localization_image_input = gr.File(
1096
- label="Localize Single Image",
1097
- file_types=[".png", ".jpg", ".jpeg", ".bmp", ".webp"],
1098
- interactive=True,
1099
- )
1100
- localize_button = gr.Button("Localize Image", variant="secondary")
1101
-
1102
  # ---------------------- Examples section ----------------------
1103
  def build_examples_from_folder():
1104
  examples_root = "examples"
 
17
  import gc
18
  import time
19
  import zipfile
20
+ import functools
21
  from typing import Any, Dict, Optional
22
  from stream3r.models.stream3r import STream3R
23
  from stream3r.stream_session import StreamSession
 
154
  return str(file_data)
155
 
156
 
157
# Location of the optional stream-session settings file, next to this module.
STREAM_SESSION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "stream_session.json")


@functools.lru_cache(maxsize=1)
def load_stream_session_config() -> Dict[str, Any]:
    """Load the stream-session configuration from ``STREAM_SESSION_CONFIG_PATH``.

    The result is memoised with ``lru_cache``, so the file is read at most once
    per process (call ``load_stream_session_config.cache_clear()`` to re-read).
    Because the cached dict object is handed to every caller, treat it as
    read-only.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, not valid UTF-8/JSON, or does not contain a JSON object.
    """
    try:
        with open(STREAM_SESSION_CONFIG_PATH, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError:
        # A missing config file is the normal "use defaults" case; stay quiet.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and the UnicodeDecodeError
        # raised while decoding a file that is not valid UTF-8.
        print(f"Ignoring unreadable stream session config {STREAM_SESSION_CONFIG_PATH}: {exc}")
        return {}

    if isinstance(data, dict):
        return data

    print(
        f"Ignoring stream session config {STREAM_SESSION_CONFIG_PATH}: "
        f"expected a JSON object, got {type(data).__name__}"
    )
    return {}
170
+
171
+
172
  def load_session_settings(target_dir: str) -> Dict[str, Any]:
173
  settings_path = os.path.join(target_dir, "session_settings.json")
174
  if not os.path.exists(settings_path):
 
196
  return label.replace('.', '_').replace(':', '').replace(' ', '_')
197
 
198
 
199
def select_top_k_frames(predictions: Dict[str, np.ndarray], images_dir: str, top_k: int) -> list[Dict[str, Any]]:
    """Pick up to ``top_k`` frames that are confident and spatially diverse.

    Frames are ranked by mean confidence (weighted by a coverage term) and then
    greedily accepted when they differ enough in camera position or viewing
    direction from the frames already chosen. If the diversity pass yields fewer
    than ``top_k`` frames, the remainder is filled in ranking order.

    Args:
        predictions: Model outputs. Reads ``"extrinsic"`` (indexed as
            ``[frame, row, col]`` with at least 3 rows and 4 columns) and, when
            present, ``"world_points_conf"`` or else ``"depth_conf"`` (indexed
            by frame on the first axis).
        images_dir: Directory whose sorted, non-hidden entries provide the
            filename for each frame index.
        top_k: Maximum number of frames to return; clamped to the frame count.

    Returns:
        One record per selected frame, ordered by frame index, with keys
        ``index``, ``filename``, ``score``, ``mean_confidence`` and
        ``coverage_ratio``. An empty list is returned when ``top_k <= 0``, the
        directory does not exist, or no usable extrinsics are available.
    """
    if top_k <= 0 or not os.path.isdir(images_dir):
        return []

    image_files = sorted(name for name in os.listdir(images_dir) if not name.startswith('.'))

    extrinsics = predictions.get("extrinsic")
    if extrinsics is None:
        return []

    extrinsics = np.asarray(extrinsics, dtype=np.float64)
    # Reject layouts the pose maths below cannot index (e.g. a leftover batch axis).
    if extrinsics.ndim != 3 or extrinsics.shape[1] < 3 or extrinsics.shape[2] < 4:
        return []

    num_frames = extrinsics.shape[0]
    if num_frames == 0:
        return []

    top_k = min(top_k, num_frames)

    # Per-frame -R^T t; this is the camera centre if the extrinsics are
    # world-to-camera [R|t] matrices -- TODO confirm against the model's convention.
    rotations = extrinsics[:, :3, :3]
    translations = extrinsics[:, :3, 3]
    positions = -np.einsum("nji,nj->ni", rotations, translations)

    # Third rotation row, normalised; zero-length rows stay zero instead of NaN.
    forward_vectors = extrinsics[:, 2, :3]
    forward_norms = np.linalg.norm(forward_vectors, axis=1, keepdims=True)
    forward_vectors = np.divide(
        forward_vectors,
        forward_norms,
        out=np.zeros_like(forward_vectors),
        where=forward_norms > 0,
    )

    conf_tensor = predictions.get("world_points_conf")
    if conf_tensor is None:
        conf_tensor = predictions.get("depth_conf")

    if conf_tensor is None:
        # No confidence available: every frame is considered equally good.
        quality_scores = np.ones(num_frames, dtype=np.float64)
        coverage_scores = np.ones(num_frames, dtype=np.float64)
    else:
        quality_scores = np.zeros(num_frames, dtype=np.float64)
        coverage_scores = np.zeros(num_frames, dtype=np.float64)
        conf_tensor = np.asarray(conf_tensor)
        conf_frames = conf_tensor.shape[0] if conf_tensor.ndim else 0
        # Frames without a confidence entry keep a score of 0 rather than raising.
        for idx in range(min(num_frames, conf_frames)):
            conf = conf_tensor[idx].reshape(-1)
            # Drop NaN *and* +/-inf so one bad pixel cannot make the score
            # non-finite (which would also serialise as invalid JSON).
            conf = conf[np.isfinite(conf)]
            if conf.size:
                quality_scores[idx] = float(np.mean(conf))
                # NOTE(review): the threshold is this frame's own 75th percentile,
                # so the ratio is ~0.25 for most frames and adds little signal;
                # a threshold shared across frames may have been intended.
                high_thresh = np.percentile(conf, 75)
                coverage_scores[idx] = float(np.mean(conf >= high_thresh))

    max_cov = coverage_scores.max()
    if max_cov > 0:
        coverage_scores = coverage_scores / max_cov
    else:
        coverage_scores = np.ones_like(coverage_scores)

    base_scores = quality_scores * (0.5 + 0.5 * coverage_scores)

    # Best-scoring frames first; the sort is stable so ties keep frame order.
    indices = sorted(range(num_frames), key=lambda frame: base_scores[frame], reverse=True)

    bbox_min = positions.min(axis=0)
    bbox_max = positions.max(axis=0)
    scene_scale = float(np.linalg.norm(bbox_max - bbox_min))
    pos_threshold = max(0.1, 0.1 * scene_scale)
    ori_threshold = 15.0  # degrees

    # Until this many frames are chosen, candidates are accepted regardless of diversity.
    unconditional_quota = max(1, top_k // 3)

    selected: list[int] = []
    for idx in indices:
        if selected:
            chosen = np.asarray(selected)
            min_dist = np.linalg.norm(positions[chosen] - positions[idx], axis=1).min()
            cosines = np.clip(forward_vectors[chosen] @ forward_vectors[idx], -1.0, 1.0)
            # NOTE(review): this uses the LARGEST angle to any chosen frame, whereas
            # the distance test uses the nearest chosen frame -- confirm intended.
            max_angle = np.degrees(np.arccos(cosines)).max()
            is_diverse = min_dist >= pos_threshold or max_angle >= ori_threshold
            if not is_diverse and len(selected) >= unconditional_quota:
                continue

        selected.append(idx)
        if len(selected) >= top_k:
            break

    if len(selected) < top_k:
        # Diversity filtering left gaps: top up with the best remaining frames.
        taken = set(selected)
        for idx in indices:
            if idx in taken:
                continue
            selected.append(idx)
            taken.add(idx)
            if len(selected) >= top_k:
                break

    selected = sorted(selected[:top_k])

    return [
        {
            "index": int(idx),
            # Fall back to a synthetic name when the folder has fewer files than frames.
            "filename": image_files[idx] if idx < len(image_files) else f"frame_{idx:06d}",
            "score": float(base_scores[idx]),
            "mean_confidence": float(quality_scores[idx]),
            "coverage_ratio": float(coverage_scores[idx]),
        }
        for idx in selected
    ]
321
+
322
+
323
  # -------------------------------------------------------------------------
324
  # 1) Core model inference
325
  # -------------------------------------------------------------------------
 
691
  session.clear()
692
 
693
  try:
694
+ session.load_cache(kv_cache_path, device=image_tensor.device)
 
 
695
 
696
+ existing_predictions = session.get_all_predictions()
697
+ existing_frames = 0
698
+ for value in existing_predictions.values():
699
+ if isinstance(value, torch.Tensor) and value.dim() >= 2:
700
+ existing_frames = max(existing_frames, value.shape[1])
701
 
702
+ with torch.no_grad():
703
+ session.forward_stream(image_tensor)
704
 
705
+ localized_predictions = session.get_all_predictions()
706
  except Exception as exc:
707
  session.clear()
708
+ if image_tensor.device.type == "cuda":
709
+ torch.cuda.empty_cache()
710
  return (f"Localization failed: {exc}", gr.update())
711
 
712
  def _extract_frame(tensor: torch.Tensor, index: int) -> np.ndarray:
 
846
  summary_lines.append(f"Warning: failed to update GLB preview ({exc})")
847
 
848
  session.clear()
849
+ if image_tensor.device.type == "cuda":
850
+ torch.cuda.empty_cache()
851
 
852
  return ("\n".join(summary_lines), localization_glb_path if localization_glb_path else gr.update())
853
 
 
881
  # Prepare frame_filter dropdown
882
  target_dir_images = os.path.join(target_dir, "images")
883
  frame_filter_choices = build_frame_filter_choices(target_dir_images)
884
+ config = load_stream_session_config()
885
+ top_k_frames = int(config.get("top_k_frames", 0) or 0)
886
 
887
  print("Running run_model...")
888
  with torch.no_grad():
 
892
  prediction_save_path = os.path.join(target_dir, "predictions.npz")
893
  np.savez(prediction_save_path, **predictions)
894
 
895
+ selected_frames = select_top_k_frames(predictions, target_dir_images, top_k_frames)
896
+ selected_frames_path = os.path.join(target_dir, "selected_frames.json")
897
+ if selected_frames:
898
+ try:
899
+ with open(selected_frames_path, "w", encoding="utf-8") as handle:
900
+ json.dump({"top_k": top_k_frames, "frames": selected_frames}, handle, indent=2)
901
+ except OSError as exc:
902
+ print(f"Failed to write selected frames: {exc}")
903
+ elif os.path.exists(selected_frames_path):
904
+ try:
905
+ os.remove(selected_frames_path)
906
+ except OSError:
907
+ pass
908
+
909
  frame_filter_value = frame_filter if frame_filter is not None else "All"
910
 
911
  session_settings = {
 
919
  "mask_sky": bool(mask_sky),
920
  "prediction_mode": prediction_mode,
921
  }
922
+ session_settings["top_k_frames"] = top_k_frames
923
+ if selected_frames:
924
+ session_settings["selected_frames"] = [frame["filename"] for frame in selected_frames]
925
  try:
926
  with open(os.path.join(target_dir, "session_settings.json"), "w", encoding="utf-8") as handle:
927
  json.dump(session_settings, handle, indent=2)
 
1199
  session_state_output = gr.File(label="Download Session State", interactive=False)
1200
  localization_output = gr.Textbox(label="Localization Result", lines=8, interactive=False)
1201
 
1202
+ with gr.Row():
1203
+ localization_image_input = gr.File(
1204
+ label="Localize Single Image",
1205
+ file_types=[".png", ".jpg", ".jpeg", ".bmp", ".webp"],
1206
+ interactive=True,
1207
+ )
1208
+ localize_button = gr.Button("Localize Image", variant="secondary")
1209
+
1210
  with gr.Row():
1211
  submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
1212
  clear_btn = gr.ClearButton(
 
1215
  input_images,
1216
  input_zip,
1217
  session_state_input,
1218
+ localization_image_input,
1219
  reconstruction_output,
1220
  log_output,
1221
  target_dir_output,
 
1260
  mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
1261
  mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
1262
 
 
 
 
 
 
 
 
 
1263
  # ---------------------- Examples section ----------------------
1264
  def build_examples_from_folder():
1265
  examples_root = "examples"
configs/stream_session.json CHANGED
@@ -1,3 +1,4 @@
1
  {
2
- "window_size": 25
 
3
  }
 
1
  {
2
+ "window_size": 5,
3
+ "top_k_frames": 18
4
  }