brian4dwell committed
Commit de1bede · 1 Parent(s): 1aa7485

keyframe selection
design_docs/keyframe_selection_motion_coverage.md ADDED
@@ -0,0 +1,195 @@
# Design Doc: Motion- & Coverage-Aware Key Frame Selection

**Author:** Brian Clark
**Last Updated:** 2025-11-07
**Target Components:** `_compute_selected_frames`, Stream3R inference outputs
**Goal:** Replace naive FPS sampling with a strategy that keeps only frames providing new camera poses and meaningful scene coverage, reducing point-cloud clutter and improving 2D scene graphs.

---

## 1. Overview

We combine two complementary signals:

1. **Motion-aware downsampling (Option A):** ensure key frames are spaced by actual camera movement (SE(3) distance), not just time.
2. **Coverage-driven selection (Option B):** prefer frames that contribute new high-confidence geometry after Stream3R processing.

The final key-frame list is built by enforcing motion diversity first, then greedily adding frames with the largest uncovered-coverage gain until we reach a target budget.

---
## 2. Inputs & Prerequisites

- Per-frame camera extrinsics (`extrinsic`) from Stream3R.
- Optional per-frame quality metrics (blur/confidence) from the camera head.
- Stream3R `world_points` and `world_points_conf` (or post-voxel-reduction point maps) to evaluate coverage.
- Library support: NumPy + SciPy (for SE(3) distances); optional Open3D or a custom KD-tree for point coverage.

---
## 3. Motion Metrics (Option A)

### 3.1 Pose difference
- Translation delta: `||t_i - t_j||`.
- Rotation delta: the angle of `R_i * R_j^{-1}`, via `acos((trace - 1) / 2)`.
- Combine with weights (e.g., `motion = w_t * Δpos + w_r * Δrot`), with defaults `w_t = 1.0`, `w_r = 0.5 m/rad`.
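The pose-difference terms above reduce to a few lines of NumPy. A minimal sketch (function name hypothetical; the shipped helper may differ):

```python
import numpy as np

def pose_distance(R_i, t_i, R_j, t_j, w_t=1.0, w_r=0.5):
    """Weighted SE(3) distance between two camera poses."""
    d_pos = np.linalg.norm(t_i - t_j)            # translation delta (m)
    rel = R_i @ R_j.T                            # relative rotation R_i * R_j^{-1}
    cos_angle = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0)
    d_rot = np.arccos(cos_angle)                 # rotation angle (rad)
    return w_t * d_pos + w_r * d_rot
```

The `clip` guards against floating-point traces slightly outside `[-1, 1]`, which would otherwise make `arccos` return NaN.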
### 3.2 Greedy spacing (temporal pass)
1. Initialize with the first frame as a key.
2. For each subsequent frame:
   - Accumulate motion distance from the last key (sum of per-frame deltas).
   - If the distance ≥ `motion_threshold` OR the time since the last key ≥ `max_gap`, mark it as a key.
   - Optional: enforce a minimum gap (`min_gap_time`) to avoid bursty picks.
3. Result: `motion_keys` – a baseline set with adequate pose coverage.
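The steps above can be sketched as a single pass, with gaps counted in frames rather than seconds (names and defaults hypothetical):

```python
def greedy_motion_selection(deltas, motion_threshold=0.4, min_gap=2, max_gap=45):
    """Pick key-frame indices spaced by accumulated motion.

    deltas[i] is the motion distance between frame i-1 and frame i
    (deltas[0] == 0). Always keeps the first frame.
    """
    keys, cum, gap = [0], 0.0, 0
    for i in range(1, len(deltas)):
        cum += deltas[i]
        gap += 1
        if gap < min_gap:          # suppress bursty picks
            continue
        if cum >= motion_threshold or gap >= max_gap:
            keys.append(i)
            cum, gap = 0.0, 0      # reset accumulators after each key
    return keys
```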
### 3.3 Quality gating (optional)
- Discard frames with low focus / brightness (if metadata is available).
- Use a confidence summary (mean `world_points_conf`) to veto the worst frames before motion selection.

---
## 4. Coverage Metrics (Option B)

### 4.1 Coverage data
- For each frame, gather the subset of point-cloud indices it contributes above a confidence threshold.
- Option 1: Use the raw `world_points_conf` mask per frame.
- Option 2: After voxel reduction, store the voxel IDs touched by each frame (during the inference loop).

### 4.2 Greedy coverage selection
1. Start with `coverage_keys = []`, `covered = set()`.
2. For each candidate frame (ordered by motion selection or confidence):
   - Compute `gain = new_points / total_points`, where `new_points = {points not in covered}`.
   - Keep a priority queue sorted by gain (breaking ties via motion distance or confidence).
3. While the size of `coverage_keys` < the desired target (`top_k` or auto budget):
   - Pop the frame with the highest gain.
   - Add it to `coverage_keys` and update `covered`.
   - Recompute gains lazily or maintain stored values (since coverage shrinks).
4. Merge with `motion_keys`: `selected = sorted(motion_keys ∪ coverage_keys)`, preserving chronological order.
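A simplified, non-lazy version of this greedy loop over per-frame voxel-ID sets (names hypothetical; the real implementation recomputes gains lazily):

```python
def greedy_coverage_selection(voxel_sets, budget, min_gain_ratio=0.01):
    """Greedily pick frames whose voxel sets add the most uncovered voxels.

    voxel_sets[i] is the set of voxel IDs that frame i contributes.
    """
    total = len(set().union(*voxel_sets)) or 1   # avoid division by zero
    covered, keys = set(), []
    remaining = list(range(len(voxel_sets)))
    while remaining and len(keys) < budget:
        gains = [(len(voxel_sets[i] - covered), i) for i in remaining]
        best_gain, best = max(gains)
        if best_gain / total < min_gain_ratio:
            break                                # marginal gain too small; stop
        keys.append(best)
        covered |= voxel_sets[best]
        remaining.remove(best)
    return sorted(keys)
```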
### 4.3 Parameters
| Parameter | Purpose | Default |
|-----------|---------|---------|
| `coverage_conf_thres` | Minimum confidence per point | 0.3 |
| `top_k` | Max key frames (if > 0) | Provided by payload |
| `auto_budget_seconds` | If `top_k` is not set, derive a frame budget from the scene duration | 0.4 fps (≈ 12 frames for 30 s) |
| `min_gain_ratio` | Stop if marginal gain < threshold | 0.01 |
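The auto-budget row above is simple arithmetic; a sketch of the rule (the floor of 4 frames is an assumption, not in the table):

```python
def auto_budget(duration_s, rate_fps=0.4, minimum=4):
    """Derive a key-frame budget from scene duration when top_k is unset."""
    return max(minimum, round(duration_s * rate_fps))
```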
---

## 5. Algorithm Outline

```text
1. Precompute per-frame metadata:
   - Motion deltas & cumulative distance
   - Frame quality/confidence
   - Coverage contributions (voxel IDs or hashed points)

2. Motion pass:
   motion_keys = greedy_motion_selection(frames, motion_threshold, min_gap, max_gap)

3. Coverage pass:
   candidates = frames filtered by quality & (for large scenes) downsampled using motion_keys as seeds
   coverage_keys = greedy_coverage_selection(candidates, contributions, budget)

4. Combine & finalize:
   selected = sort(unique(motion_keys ∪ coverage_keys))
   if len(selected) > budget: prune lowest coverage gain while keeping motion anchors
   collect metadata (confidence, motion distance, coverage gain) for diagnostics

5. Optional reinflation pass (if enabled) to restore splat density for the selected frames only.

6. Emit diagnostics in `selected_frames.json`.
```

---
## 6. Integration Points

### 6.1 `_compute_selected_frames`
- Extend the signature to accept:
  - `frame_records` (already present)
  - `extrinsics`, `world_points`, `world_points_conf`
  - optional `confidence_summary`, `frame_timestamps`
- Return a list of dicts with fields `frame_id`, `motion_score`, `coverage_gain`, `cum_motion`, etc., so the artifact can explain the reasoning.

### 6.2 Inference loop
- While iterating frames, record:
  - Pose deltas (stored to arrays for later).
  - Coverage bitsets: e.g., hashed voxel indices (`np.floor(world_points / voxel_size)`).
  - Quality metrics (mean confidence, brightness).

### 6.3 Job artifacts
- Include selection diagnostics in `selected_frames.json`:
  ```json
  {
    "frame_id": "...",
    "motion_distance": 0.45,
    "coverage_gain": 0.12,
    "decision": "coverage"
  }
  ```
- Enables auditing the chosen frames.

### 6.4 Two-pass pipeline hook
- Add a config flag (e.g., `STREAM3R_KEYFRAME_PREPASS`) to toggle a lightweight pre-pass.
- **Pre-pass steps:**
  1. Collect frames as usual.
  2. Run a reduced inference loop (camera head only, or full Stream3R with artifact generation disabled) to gather motion and coverage metadata.
  3. Execute the key-frame selection algorithm to produce the selected indices.
- **Main pass:**
  1. Filter `frame_records` to the selected indices.
  2. If the batch size is below a configured maximum, switch inference to full attention; otherwise remain in window mode.
  3. Run the full artifact pipeline (pointmaps, GLB, reinflation) on the reduced set.
  4. Persist selection diagnostics alongside the artifacts.
- Provide a fallback path: if the pre-pass fails or returns too few frames, revert to the original sampling strategy so the job still succeeds.

---
## 7. Configuration & Defaults

| Setting | Description | Default |
|---------|-------------|---------|
| `STREAM3R_KEYFRAME_MOTION_THRESH` | Motion distance (m) to trigger a new key | 0.3 |
| `STREAM3R_KEYFRAME_ROT_WEIGHT` | Weight (m/rad) applied to the rotation angle | 0.5 |
| `STREAM3R_KEYFRAME_MIN_GAP` | Minimum time gap (s) | 0.25 |
| `STREAM3R_KEYFRAME_MAX_GAP` | Max time between keys (s) | 2.0 |
| `STREAM3R_KEYFRAME_TOP_K` | Max number of key frames | 18 (overridable per payload) |
| `STREAM3R_KEYFRAME_MIN_GAIN` | Coverage-gain stop threshold | 0.01 |
| `STREAM3R_KEYFRAME_CONF_THRESH` | Confidence threshold for coverage | 0.3 |

---
## 8. Validation Plan

1. **Quantitative**
   - Compare key-frame counts vs. the baseline (2 fps sampling).
   - Measure point-coverage retention (% of original points represented by the key frames).
   - Evaluate overlap with heuristic linear sampling (should be reduced).
2. **Qualitative**
   - Visual inspection: reduced point-cloud clutter, clearer 2D scene graphs.
   - Spot-check key-frame artifacts (diagnostic metadata) to ensure decisions align with expectations.
3. **Performance**
   - Ensure coverage computations remain efficient (hash-based; track memory usage).
   - Add timing logs in `_compute_selected_frames`.
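The coverage-retention check can reuse the same per-frame voxel sets built during selection; a sketch (names hypothetical):

```python
def coverage_retention(selected_voxels, all_voxels):
    """Fraction of the full scene's voxels represented by the selected frames."""
    union_all = set().union(*all_voxels)
    union_sel = set().union(*selected_voxels) if selected_voxels else set()
    return len(union_sel & union_all) / max(1, len(union_all))
```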
---

## 9. Future Extensions

- Integrate image-content heuristics (entropy, saliency) into coverage scoring.
- Multi-pass selection: first ensure 360° orientation coverage, then fill gaps.
- Adaptive budgets based on room size / path length (use total motion distance).
- Optionally, trigger reinflation of the selected frames only for visualization.

---

**Deliverables**
1. Updated `_compute_selected_frames` with motion + coverage logic.
2. Supporting utilities for pose distance and coverage hashing.
3. Config hooks & optional environment variables.
4. Tests covering edge cases (no motion, tiny coverage gains, payload `top_k` override).
5. Documentation updates describing the new behavior and tuning knobs.

---
stream3r/utils/__pycache__/visual_utils.cpython-311.pyc CHANGED
Binary files a/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc and b/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc differ
 
stream3r/utils/visual_utils.py CHANGED
@@ -325,6 +325,8 @@ def predictions_to_glb(
     reinflate_jitter_mode: str = "cube",
     reinflate_jitter_sigma: float = 0.35,
     reinflate_seed: int | None = None,
+    ceiling_percentile: float | None = None,
+    ceiling_margin: float = 0.05,
 ) -> trimesh.Scene:
     """
     Converts predictions to a 3D scene represented as a GLB file.
@@ -360,6 +362,8 @@ def predictions_to_glb(
         reinflate_jitter_mode (str): "cube" (uniform jitter) or "gaussian".
         reinflate_jitter_sigma (float): Jitter strength as a fraction of voxel size.
         reinflate_seed (Optional[int]): RNG seed for deterministic reinflation.
+        ceiling_percentile (Optional[float]): Remove points above this Z percentile (0-100).
+        ceiling_margin (float): Margin subtracted from the percentile cutoff (meters).

     Returns:
         trimesh.Scene: Processed 3D scene containing point cloud and cameras
@@ -523,6 +527,23 @@ def predictions_to_glb(
         colors_rgb = colors_rgb[conf_mask]
         conf_used = conf[conf_mask]

+    if ceiling_percentile is not None and vertices_3d.size:
+        try:
+            percentile_value = float(ceiling_percentile)
+        except (TypeError, ValueError):
+            percentile_value = None
+        if percentile_value is not None and 0.0 < percentile_value < 100.0:
+            cutoff = float(np.percentile(vertices_3d[:, 2], percentile_value))
+            margin = float(max(0.0, ceiling_margin))
+            threshold = cutoff - margin
+            keep_mask = vertices_3d[:, 2] < threshold
+            if not np.any(keep_mask):
+                keep_mask = vertices_3d[:, 2] <= cutoff
+            if np.any(keep_mask) and np.count_nonzero(keep_mask) < vertices_3d.shape[0]:
+                vertices_3d = vertices_3d[keep_mask]
+                colors_rgb = colors_rgb[keep_mask]
+                conf_used = conf_used[keep_mask]
+
     if effective_voxel_size is not None and voxel_after_conf and vertices_3d.size:
         before_count = vertices_3d.shape[0]
         vertices_3d, colors_rgb, conf_used = voxel_reduce(
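The new ceiling filter can be exercised in isolation. A minimal NumPy sketch of the same cutoff logic (function name hypothetical; the shipped code operates on `vertices_3d` inside `predictions_to_glb`):

```python
import numpy as np

def ceiling_filter(points, percentile=90.0, margin=0.05):
    """Drop points above (percentile cutoff - margin) on the Z axis."""
    cutoff = float(np.percentile(points[:, 2], percentile))
    keep = points[:, 2] < cutoff - margin
    if not np.any(keep):               # never drop everything
        keep = points[:, 2] <= cutoff
    return points[keep]
```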
stream3r/worker/config.py CHANGED
@@ -118,6 +118,17 @@ class WorkerSettings:
     max_frames_per_job: int = 0
     default_job_timeout: int = 45 * 60
     upload_session_cache: bool = True
+    keyframe_prepass_enabled: bool = True
+    keyframe_motion_threshold: float = 0.4
+    keyframe_rotation_weight: float = 0.5
+    keyframe_min_gap_frames: int = 2
+    keyframe_max_gap_frames: int = 45
+    keyframe_default_top_k: int = 16
+    keyframe_coverage_confidence: float = 0.3
+    keyframe_coverage_voxel_size: float = 0.05
+    keyframe_coverage_max_points: int = 5000
+    keyframe_min_gain_ratio: float = 0.01
+    keyframe_full_mode_max_frames: int = 16

     @classmethod
     def from_env(cls) -> "WorkerSettings":
@@ -216,6 +227,39 @@ class WorkerSettings:
             "upload_session_cache": _env_bool(
                 "STREAM3R_UPLOAD_CACHE", base.upload_session_cache
             ),
+            "keyframe_prepass_enabled": _env_bool(
+                "STREAM3R_KEYFRAME_PREPASS", base.keyframe_prepass_enabled
+            ),
+            "keyframe_motion_threshold": float(
+                os.getenv("STREAM3R_KEYFRAME_MOTION_THRESH", base.keyframe_motion_threshold)
+            ),
+            "keyframe_rotation_weight": float(
+                os.getenv("STREAM3R_KEYFRAME_ROT_WEIGHT", base.keyframe_rotation_weight)
+            ),
+            "keyframe_min_gap_frames": _env_int(
+                "STREAM3R_KEYFRAME_MIN_GAP_FRAMES", base.keyframe_min_gap_frames
+            ),
+            "keyframe_max_gap_frames": _env_int(
+                "STREAM3R_KEYFRAME_MAX_GAP_FRAMES", base.keyframe_max_gap_frames
+            ),
+            "keyframe_default_top_k": _env_int(
+                "STREAM3R_KEYFRAME_TOP_K", base.keyframe_default_top_k
+            ),
+            "keyframe_coverage_confidence": float(
+                os.getenv("STREAM3R_KEYFRAME_CONF_THRESH", base.keyframe_coverage_confidence)
+            ),
+            "keyframe_coverage_voxel_size": float(
+                os.getenv("STREAM3R_KEYFRAME_VOXEL_SIZE", base.keyframe_coverage_voxel_size)
+            ),
+            "keyframe_coverage_max_points": _env_int(
+                "STREAM3R_KEYFRAME_MAX_POINTS", base.keyframe_coverage_max_points
+            ),
+            "keyframe_min_gain_ratio": float(
+                os.getenv("STREAM3R_KEYFRAME_MIN_GAIN", base.keyframe_min_gain_ratio)
+            ),
+            "keyframe_full_mode_max_frames": _env_int(
+                "STREAM3R_KEYFRAME_FULL_MAX_FRAMES", base.keyframe_full_mode_max_frames
+            ),
         }

         return cls(**kwargs)
stream3r/worker/tasks.py CHANGED
@@ -63,6 +63,13 @@ class FrameRecord:
     metadata: dict[str, Any] = field(default_factory=dict)


+@dataclass(slots=True)
+class KeyframeSelectionResult:
+    indices: list[int]
+    diagnostics: list[dict[str, Any]]
+    top_k: int
+
+
 class ProgressTracker:
     """Aggregates frame progress to percentage updates."""

@@ -542,6 +549,206 @@ def _write_selected_frames(
     return runtime.storage.upload_file(local_file, key, content_type="application/json")


+def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    matrices = np.asarray(extrinsic, dtype=np.float64)
+    if matrices.ndim != 3 or matrices.shape[1:] != (3, 4):
+        raise ValueError("Extrinsic array must have shape (N, 3, 4)")
+    count = matrices.shape[0]
+    rotations = np.empty((count, 3, 3), dtype=np.float64)
+    translations = np.empty((count, 3), dtype=np.float64)
+    for idx in range(count):
+        mat = np.eye(4, dtype=np.float64)
+        mat[:3, :4] = matrices[idx]
+        cam_to_world = np.linalg.inv(mat)
+        rotations[idx] = cam_to_world[:3, :3]
+        translations[idx] = cam_to_world[:3, 3]
+    return rotations, translations
+
+
+def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray:
+    count = rotations.shape[0]
+    deltas = np.zeros(count, dtype=np.float64)
+    if count <= 1:
+        return deltas
+    for idx in range(1, count):
+        delta_t = np.linalg.norm(translations[idx] - translations[idx - 1])
+        rel = rotations[idx - 1].T @ rotations[idx]
+        trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0)
+        delta_r = float(np.arccos(trace))
+        deltas[idx] = delta_t + rot_weight * delta_r
+    return deltas
+
+
+def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray:
+    coords = coords.astype(np.int64, copy=False)
+    primes = np.array([73856093, 19349663, 83492791], dtype=np.int64)
+    return coords @ primes
+
+
+def _frame_voxel_sets(
+    world_points: np.ndarray,
+    confidence: np.ndarray,
+    *,
+    threshold: float,
+    voxel_size: float,
+    max_points: int,
+) -> tuple[list[set[int]], int]:
+    rng = np.random.default_rng(42)
+    frames = world_points.shape[0]
+    voxel_sets: list[set[int]] = []
+    global_union: set[int] = set()
+    if voxel_size <= 0.0:
+        return [set() for _ in range(frames)], 0
+    for idx in range(frames):
+        conf_frame = confidence[idx]
+        mask = conf_frame >= threshold
+        if not np.any(mask):
+            voxel_sets.append(set())
+            continue
+        points = world_points[idx][mask]
+        if points.shape[0] > max_points:
+            sample_idx = rng.choice(points.shape[0], max_points, replace=False)
+            points = points[sample_idx]
+        quantized = np.floor(points / voxel_size).astype(np.int64, copy=False)
+        hashes = np.unique(_hash_quantized_voxels(quantized))
+        voxel_set = set(int(v) for v in hashes.tolist())
+        voxel_sets.append(voxel_set)
+        global_union.update(voxel_set)
+    return voxel_sets, len(global_union)
+
+
+def _select_motion_indices(
+    motion_deltas: np.ndarray,
+    *,
+    threshold: float,
+    min_gap: int,
+    max_gap: int,
+) -> tuple[list[int], dict[int, dict[str, float]]]:
+    total_frames = motion_deltas.shape[0]
+    if total_frames == 0:
+        return [], {}
+    selected = [0]
+    diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, "cum_motion": 0.0}}
+    cumulative = 0.0
+    gap = 0
+    for idx in range(1, total_frames):
+        delta = float(motion_deltas[idx])
+        cumulative += delta
+        gap += 1
+        if gap < max(1, min_gap):
+            continue
+        should_select = cumulative >= threshold
+        if max_gap > 0 and gap >= max_gap:
+            should_select = True
+        if should_select:
+            selected.append(idx)
+            diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative}
+            cumulative = 0.0
+            gap = 0
+    if selected[-1] != total_frames - 1:
+        selected.append(total_frames - 1)
+        diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative})
+    return selected, diagnostics
+
+
+def _select_keyframes_motion_coverage(
+    frame_records: list[FrameRecord],
+    predictions: Mapping[str, np.ndarray],
+    settings: WorkerSettings,
+    requested_top_k: int,
+) -> KeyframeSelectionResult | None:
+    extrinsic = np.asarray(predictions.get("extrinsic"))
+    if extrinsic.size == 0:
+        return None
+    rotations, translations = _camera_poses(extrinsic)
+    motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight)
+    motion_indices, motion_diag = _select_motion_indices(
+        motion_deltas,
+        threshold=settings.keyframe_motion_threshold,
+        min_gap=max(1, settings.keyframe_min_gap_frames),
+        max_gap=max(0, settings.keyframe_max_gap_frames),
+    )
+    total_frames = len(frame_records)
+    confidence = _pose_confidence(predictions)
+    world_points = predictions.get("world_points")
+    if world_points is None:
+        world_points = predictions.get("world_points_from_depth")
+    voxel_sets: list[set[int]] = [set() for _ in range(total_frames)]
+    total_voxels = 0
+    mean_conf = np.zeros(total_frames, dtype=np.float32)
+    if confidence is not None:
+        mean_conf = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
+    if confidence is not None and world_points is not None:
+        voxel_sets, total_voxels = _frame_voxel_sets(
+            np.asarray(world_points),
+            np.asarray(confidence),
+            threshold=settings.keyframe_coverage_confidence,
+            voxel_size=settings.keyframe_coverage_voxel_size,
+            max_points=max(1000, settings.keyframe_coverage_max_points),
+        )
+    total_voxels = max(total_voxels, 1)
+    top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k
+    top_k = max(min(top_k, total_frames), len(motion_indices))
+    selected_set: set[int] = set(motion_indices)
+    diagnostics: dict[int, dict[str, Any]] = {}
+    covered: set[int] = set()
+    for idx in motion_indices:
+        gain_count = len(voxel_sets[idx] - covered) if voxel_sets[idx] else 0
+        gain_ratio = gain_count / total_voxels
+        covered.update(voxel_sets[idx])
+        diagnostics[idx] = {
+            "frame_id": frame_records[idx].frame_id,
+            "frame_index": frame_records[idx].index,
+            "reason": "motion",
+            "motion_delta": float(motion_deltas[idx]),
+            "cum_motion": float(motion_diag.get(idx, {}).get("cum_motion", 0.0)),
+            "coverage_gain_ratio": float(gain_ratio),
+            "coverage_gain_count": int(gain_count),
+            "mean_confidence": float(mean_conf[idx]) if confidence is not None else None,
+        }
+    if len(selected_set) < top_k and total_voxels > 0:
+        min_gain_ratio = settings.keyframe_min_gain_ratio
+        remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]]
+        while remaining and len(selected_set) < top_k:
+            best_idx = -1
+            best_gain = -1
+            best_ratio = -1.0
+            for idx in remaining:
+                gain = len(voxel_sets[idx] - covered)
+                if gain <= 0:
+                    continue
+                ratio = gain / total_voxels
+                if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain):
+                    best_idx = idx
+                    best_gain = gain
+                    best_ratio = ratio
+            if best_idx == -1 or best_ratio < min_gain_ratio:
+                break
+            selected_set.add(best_idx)
+            covered.update(voxel_sets[best_idx])
+            diagnostics[best_idx] = {
+                "frame_id": frame_records[best_idx].frame_id,
+                "frame_index": frame_records[best_idx].index,
+                "reason": "coverage",
+                "motion_delta": float(motion_deltas[best_idx]),
+                "cum_motion": float(motion_diag.get(best_idx, {}).get("cum_motion", 0.0)),
+                "coverage_gain_ratio": float(best_ratio),
+                "coverage_gain_count": int(best_gain),
+                "mean_confidence": float(mean_conf[best_idx]) if confidence is not None else None,
+            }
+            remaining.remove(best_idx)
+    if requested_top_k > 0 and len(selected_set) > requested_top_k:
+        coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"]
+        coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0))
+        while len(selected_set) > requested_top_k and coverage_candidates:
+            drop_idx = coverage_candidates.pop(0)
+            selected_set.remove(drop_idx)
+            diagnostics.pop(drop_idx, None)
+    final_indices = sorted(selected_set)
+    final_diags = [diagnostics[idx] for idx in final_indices]
+    return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices))
+
+
 def _compute_selected_frames(
     predictions: Mapping[str, np.ndarray],
     frame_records: list[FrameRecord],
@@ -567,6 +774,44 @@ def _compute_selected_frames(
     return result


+def _run_keyframe_prepass(
+    *,
+    runtime: WorkerRuntime,
+    payload: Mapping[str, Any],
+    frame_records: list[FrameRecord],
+    mode: str,
+    streaming: bool,
+    window_size: int | None,
+) -> KeyframeSelectionResult | None:
+    if len(frame_records) <= 1:
+        return None
+    settings = runtime.settings
+    top_k_payload = _as_int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k"), 0)
+    try:
+        inference = run_stream3r_inference(
+            runtime=runtime,
+            image_paths=[record.path for record in frame_records],
+            mode=mode,
+            streaming=streaming,
+            cache_output_path=None,
+            progress_cb=None,
+            window_size=window_size if streaming and mode == "window" else None,
+        )
+    except Exception:
+        logger.exception("Keyframe pre-pass inference failed")
+        return None
+    try:
+        selection = _select_keyframes_motion_coverage(
+            frame_records,
+            inference.predictions,
+            settings,
+            requested_top_k=top_k_payload,
+        )
+    finally:
+        del inference
+    return selection
+
+
 def _save_scene_glb(
     *,
     runtime: WorkerRuntime,
@@ -576,6 +821,17 @@ def _save_scene_glb(
     payload: Mapping[str, Any],
 ) -> str:
     local_file = temp_dir / runtime.settings.scene_glb_filename
+    ceiling_percentile = payload.get("ceiling_percentile")
+    try:
+        ceiling_percentile_value = float(ceiling_percentile) if ceiling_percentile is not None else None
+    except (TypeError, ValueError):
+        ceiling_percentile_value = None
+    ceiling_margin_value = payload.get("ceiling_margin")
+    try:
+        ceiling_margin_value = float(ceiling_margin_value) if ceiling_margin_value is not None else 0.05
+    except (TypeError, ValueError):
+        ceiling_margin_value = 0.05
+
     scene = predictions_to_glb(
         dict(predictions),
         conf_thres=float(payload.get("conf_thres", 3.0)),
@@ -586,6 +842,8 @@ def _save_scene_glb(
         mask_sky=_as_bool(payload.get("mask_sky"), False),
         target_dir=str(temp_dir),
         prediction_mode=payload.get("prediction_mode", "Predicted Pointmap"),
+        ceiling_percentile=ceiling_percentile_value,
+        ceiling_margin=ceiling_margin_value,
     )
     scene.export(file_obj=str(local_file))
     key = runtime.storage.build_key(
@@ -760,6 +1018,22 @@ def _handle_pose_pointmap(
         "frames": core["frames"],
     }

+    selected_frames_payload = payload.get("_selected_frames_info")
+    if selected_frames_payload:
+        result_payload["selected_frames"] = list(selected_frames_payload)
+        try:
+            selected_frames_url = _write_selected_frames(
+                runtime=runtime,
+                scene_id=scene_id,
+                selected_frames=list(selected_frames_payload),
+                top_k=_as_int(payload.get("_selected_top_k"), len(selected_frames_payload)),
+                temp_dir=temp_dir,
+            )
+            if selected_frames_url:
+                result_payload["artifacts"]["selected_frames_url"] = selected_frames_url
+        except Exception:
+            logger.exception("Failed to persist selected frames artifact for pose_pointmap job")
+
     result_url = _upload_result_record(
         runtime=runtime,
         scene_id=scene_id,
@@ -894,6 +1168,40 @@ def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler)
         temp_path = Path(tmp_dir)
         frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
         log_progress(f"collected frames ({len(frame_records)} items)")
+        selection_result: KeyframeSelectionResult | None = None
+        if runtime.settings.keyframe_prepass_enabled and len(frame_records) > 1:
+            log_progress("starting keyframe pre-pass")
+            try:
+                selection_result = _run_keyframe_prepass(
+                    runtime=runtime,
+                    payload=payload,
+                    frame_records=frame_records,
+                    mode=mode,
+                    streaming=streaming,
+                    window_size=window_size,
+                )
+            except Exception:
+                selection_result = None
+                logger.exception("Keyframe pre-pass failed; falling back to full frame set")
+            if selection_result and selection_result.indices:
+                log_progress(
+                    f"pre-pass selected {len(selection_result.indices)} frames from {len(frame_records)}"
+                )
+                frame_records = [frame_records[i] for i in selection_result.indices]
+                for new_idx, record in enumerate(frame_records):
+                    record.index = new_idx
+                payload["_selected_frames_info"] = selection_result.diagnostics
+                payload["_selected_top_k"] = selection_result.top_k
+                payload["_selected_frame_indices"] = selection_result.indices
+                if len(frame_records) <= runtime.settings.keyframe_full_mode_max_frames:
+                    mode = "full"
+                    streaming = False
+                    window_size = None
+                    payload["mode"] = mode
+                    payload["streaming"] = streaming
+            else:
+                selection_result = None
+
         cache_path = temp_path / runtime.settings.session_cache_filename if streaming else None

         tracker = ProgressTracker(runtime, job_meta)
@@ -1022,8 +1330,13 @@ def _handle_model_build(

     artifacts = dict(core["artifacts"])

-    top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
-    selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
+    selected_frames_payload = payload.get("_selected_frames_info")
+    if selected_frames_payload:
+        top_k = _as_int(payload.get("_selected_top_k"), len(selected_frames_payload))
+        selected_frames = list(selected_frames_payload)
+    else:
+        top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
+        selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
     selected_frames_url = _write_selected_frames(
         runtime=runtime,
         scene_id=scene_id,
tests/test_voxel_reduction.py CHANGED
@@ -177,3 +177,48 @@ def test_density_filter_points_removes_isolated_samples():

     assert filtered_points.shape[0] < points.shape[0]
     assert np.all(filtered_points.max(axis=0) < 0.2)
+
+
+def test_predictions_to_glb_ceiling_filter():
+    world_points = np.array(
+        [
+            [
+                [[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]],
+                [[0.0, 0.0, 1.6], [0.0, 0.0, 1.7]],
+            ]
+        ],
+        dtype=np.float32,
+    )
+    predictions = {
+        "world_points": world_points,
+        "world_points_conf": np.ones((1, 2, 2), dtype=np.float32),
+        "world_points_from_depth": world_points,
+        "depth_conf": np.ones((1, 2, 2), dtype=np.float32),
+        "images": np.ones((1, 2, 2, 3), dtype=np.float32) * 0.5,
+        "extrinsic": np.array(
+            [
+                [
+                    [1.0, 0.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0, 0.0],
+                    [0.0, 0.0, 1.0, 0.0],
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    }
+
+    scene = predictions_to_glb(
+        predictions,
+        conf_thres=0.0,
+        voxel_size=None,
+        o3d_denoise=False,
+        density_filter=False,
+        reinflate_enabled=False,
+        ceiling_percentile=90.0,
+        ceiling_margin=0.05,
+    )
+
+    assert isinstance(scene, trimesh.Scene)
+    point_cloud = next(iter(scene.geometry.values()))
+    max_z = point_cloud.vertices[:, 2].max()
+    assert max_z < 1.6