Cc commited on
Commit
e340a84
·
1 Parent(s): 4d8122b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +35 -9
  2. app.py +5 -0
  3. configs/longstream_infer.yaml +84 -0
  4. demo_gradio.py +332 -0
  5. longstream/.DS_Store +0 -0
  6. longstream/__init__.py +1 -0
  7. longstream/core/__init__.py +0 -0
  8. longstream/core/cli.py +213 -0
  9. longstream/core/infer.py +451 -0
  10. longstream/core/model.py +69 -0
  11. longstream/data/__init__.py +3 -0
  12. longstream/data/dataloader.py +422 -0
  13. longstream/demo/__init__.py +11 -0
  14. longstream/demo/backend.py +495 -0
  15. longstream/demo/common.py +84 -0
  16. longstream/demo/export.py +85 -0
  17. longstream/demo/geometry.py +211 -0
  18. longstream/demo/viewer.py +134 -0
  19. longstream/eval/__init__.py +3 -0
  20. longstream/eval/evaluate.py +551 -0
  21. longstream/eval/io.py +156 -0
  22. longstream/eval/metrics.py +116 -0
  23. longstream/io/__init__.py +0 -0
  24. longstream/io/save_images.py +38 -0
  25. longstream/io/save_points.py +71 -0
  26. longstream/io/save_poses_txt.py +43 -0
  27. longstream/models/__init__.py +3 -0
  28. longstream/models/longstream.py +370 -0
  29. longstream/streaming/__init__.py +0 -0
  30. longstream/streaming/keyframe_selector.py +80 -0
  31. longstream/streaming/refresh.py +217 -0
  32. longstream/streaming/stream_session.py +294 -0
  33. longstream/utils/__init__.py +0 -0
  34. longstream/utils/camera.py +50 -0
  35. longstream/utils/depth.py +36 -0
  36. longstream/utils/hub.py +42 -0
  37. longstream/utils/sky_mask.py +100 -0
  38. longstream/utils/vendor/__init__.py +2 -0
  39. longstream/utils/vendor/croco/LICENSE +52 -0
  40. longstream/utils/vendor/croco/NOTICE +21 -0
  41. longstream/utils/vendor/croco/README.MD +124 -0
  42. longstream/utils/vendor/croco/assets/arch.jpg +0 -0
  43. longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb +182 -0
  44. longstream/utils/vendor/croco/datasets/__init__.py +2 -0
  45. longstream/utils/vendor/croco/datasets/crops/README.MD +104 -0
  46. longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py +175 -0
  47. longstream/utils/vendor/croco/datasets/habitat_sim/README.MD +76 -0
  48. longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py +2 -0
  49. longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py +121 -0
  50. longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py +34 -0
README.md CHANGED
@@ -1,14 +1,40 @@
1
  ---
2
- title: LongStream
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Demo of LongStream
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: LongStream Demo
 
 
 
3
  sdk: gradio
4
+ sdk_version: 5.44.0
5
  app_file: app.py
6
+ python_version: "3.10"
7
+ startup_duration_timeout: 1h
 
8
  ---
9
 
10
+ # LongStream Demo
11
+
12
+ This repository is the Hugging Face Space package for LongStream.
13
+
14
+ Project page: `https://3dagentworld.github.io/longstream/`
15
+
16
+ ## Space Settings
17
+
18
+ Set these variables in the Space settings before the first run:
19
+
20
+ - `LONGSTREAM_HF_REPO=NicolasCC/LongStream`
21
+ - `LONGSTREAM_HF_FILE=50_longstream.pt`
22
+ - `LONGSTREAM_HF_LOCAL_DIR=checkpoints`
23
+
24
+ Optional:
25
+
26
+ - `LONGSTREAM_HF_REVISION=v0.1.0`
27
+ - `HF_TOKEN=<token>` if the model repo is private
28
+
29
+ ## Entrypoints
30
+
31
+ - `app.py`: stable demo
32
+
33
+
34
+ ## Included Files
35
+
36
+ - `demo_gradio.py`
37
+ - `demo_gradio_interactive.py`
38
+ - `longstream/`
39
+ - `configs/longstream_infer.yaml`
40
+
app.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from demo_gradio import main
2
+
3
+
4
+ if __name__ == "__main__":
5
+ main()
configs/longstream_infer.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ device: cuda
2
+
3
+ model:
4
+ checkpoint: checkpoints/50_longstream.pt
5
+ strict_load: false
6
+ hf:
7
+ repo_id: null
8
+ filename: null
9
+ revision: null
10
+ local_dir: checkpoints
11
+ longstream_cfg:
12
+ img_size: 518
13
+ patch_size: 14
14
+ embed_dim: 1024
15
+ window_size: 48
16
+ use_role_embedding: false
17
+ enable_scale_token: true
18
+ disable_keyframe_distinction: true
19
+ use_segment_mask: false
20
+ enable_camera_head: false
21
+ freeze: none
22
+ use_rel_pose_head: true
23
+ rel_pose_head_cfg:
24
+ enabled: true
25
+ keyframe_mode: fixed
26
+ keyframe_stride: 8
27
+ reference_source: pred
28
+ detach_reference: false
29
+ trunk_depth: 4
30
+ pose_mode: SE3
31
+ num_heads: 16
32
+ mlp_ratio: 4
33
+ init_values: 0.01
34
+ trans_act: linear
35
+ quat_act: linear
36
+ use_pair_cross_attn: false
37
+ xattn_temperature: 1.0
38
+ use_precat: false
39
+ use_kf_role_embed: false
40
+ kf_role_embed_init_std: 0.02
41
+ fl_act: relu
42
+ use_global_scale: false
43
+ reinit_camera_head: false
44
+
45
+ inference:
46
+ mode: batch_refresh
47
+ streaming_mode: causal
48
+ window_size: 48
49
+ keyframe_mode: fixed
50
+ keyframe_stride: 8
51
+ refresh: 4
52
+ rel_pose_head_cfg:
53
+ num_iterations: 4
54
+
55
+ data:
56
+ format: generalizable
57
+ data_roots_file: data_roots.txt
58
+ camera: null
59
+ img_path: "path/to/your/image/directory"
60
+ stride: 1
61
+ max_frames: null
62
+ size: 518
63
+ crop: false
64
+ patch_size: 14
65
+
66
+ output:
67
+ root: outputs
68
+ save_videos: true
69
+ save_points: true
70
+ save_frame_points: true
71
+ save_depth: true
72
+ save_images: true
73
+ mask_sky: true
74
+ max_full_pointcloud_points: 2000000
75
+ max_frame_pointcloud_points: 200000
76
+ skyseg_path: skyseg.onnx
77
+
78
+ evaluation:
79
+ align_scale: true
80
+ depth_rel_delta_threshold: 1.25
81
+ point_f1_threshold: 0.25
82
+ point_eval_max_points: 100000
83
+ point_eval_voxel_size: null
84
+ point_eval_oversample_factor: 4
demo_gradio.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ from longstream.demo import BRANCH_OPTIONS, create_demo_session, load_metadata
6
+ from longstream.demo.backend import load_frame_previews
7
+ from longstream.demo.export import export_glb
8
+ from longstream.demo.viewer import build_interactive_figure
9
+
10
+
11
+ DEFAULT_KEYFRAME_STRIDE = 8
12
+ DEFAULT_REFRESH = 3
13
+ DEFAULT_WINDOW_SIZE = 48
14
+ DEFAULT_CHECKPOINT = os.getenv("LONGSTREAM_CHECKPOINT", "checkpoints/50_longstream.pt")
15
+
16
+
17
+ def _run_stable_demo(
18
+ image_dir,
19
+ uploaded_files,
20
+ uploaded_video,
21
+ checkpoint,
22
+ device,
23
+ mode,
24
+ streaming_mode,
25
+ refresh,
26
+ window_size,
27
+ compute_sky,
28
+ branch_label,
29
+ show_cameras,
30
+ mask_sky,
31
+ camera_scale,
32
+ point_size,
33
+ opacity,
34
+ preview_max_points,
35
+ glb_max_points,
36
+ ):
37
+ if not image_dir and not uploaded_files and not uploaded_video:
38
+ raise gr.Error("Provide an image folder, upload images, or upload a video.")
39
+ session_dir = create_demo_session(
40
+ image_dir=image_dir or "",
41
+ uploaded_files=uploaded_files,
42
+ uploaded_video=uploaded_video,
43
+ checkpoint=checkpoint,
44
+ device=device,
45
+ mode=mode,
46
+ streaming_mode=streaming_mode,
47
+ keyframe_stride=DEFAULT_KEYFRAME_STRIDE,
48
+ refresh=int(refresh),
49
+ window_size=int(window_size),
50
+ compute_sky=bool(compute_sky),
51
+ )
52
+ fig = build_interactive_figure(
53
+ session_dir=session_dir,
54
+ branch=branch_label,
55
+ display_mode="All Frames",
56
+ frame_index=0,
57
+ point_size=float(point_size),
58
+ opacity=float(opacity),
59
+ preview_max_points=int(preview_max_points),
60
+ show_cameras=bool(show_cameras),
61
+ camera_scale=float(camera_scale),
62
+ mask_sky=bool(mask_sky),
63
+ )
64
+ glb_path = export_glb(
65
+ session_dir=session_dir,
66
+ branch=branch_label,
67
+ display_mode="All Frames",
68
+ frame_index=0,
69
+ mask_sky=bool(mask_sky),
70
+ show_cameras=bool(show_cameras),
71
+ camera_scale=float(camera_scale),
72
+ max_points=int(glb_max_points),
73
+ )
74
+ rgb, depth, frame_label = load_frame_previews(session_dir, 0)
75
+ meta = load_metadata(session_dir)
76
+ slider = gr.update(
77
+ minimum=0,
78
+ maximum=max(meta["num_frames"] - 1, 0),
79
+ value=0,
80
+ step=1,
81
+ interactive=meta["num_frames"] > 1,
82
+ )
83
+ sky_msg = ""
84
+ if meta.get("has_sky_masks"):
85
+ removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
86
+ sky_msg = f" | sky_removed={removed:.1f}%"
87
+ status = f"Ready: {meta['num_frames']} frames | branch={branch_label}{sky_msg}"
88
+ return (
89
+ fig,
90
+ glb_path,
91
+ session_dir,
92
+ rgb,
93
+ depth,
94
+ frame_label,
95
+ slider,
96
+ status,
97
+ )
98
+
99
+
100
+ def _update_stable_scene(
101
+ session_dir,
102
+ branch_label,
103
+ show_cameras,
104
+ mask_sky,
105
+ camera_scale,
106
+ point_size,
107
+ opacity,
108
+ preview_max_points,
109
+ glb_max_points,
110
+ ):
111
+ if not session_dir or not os.path.isdir(session_dir):
112
+ return None, None, "Run reconstruction first."
113
+ fig = build_interactive_figure(
114
+ session_dir=session_dir,
115
+ branch=branch_label,
116
+ display_mode="All Frames",
117
+ frame_index=0,
118
+ point_size=float(point_size),
119
+ opacity=float(opacity),
120
+ preview_max_points=int(preview_max_points),
121
+ show_cameras=bool(show_cameras),
122
+ camera_scale=float(camera_scale),
123
+ mask_sky=bool(mask_sky),
124
+ )
125
+ glb_path = export_glb(
126
+ session_dir=session_dir,
127
+ branch=branch_label,
128
+ display_mode="All Frames",
129
+ frame_index=0,
130
+ mask_sky=bool(mask_sky),
131
+ show_cameras=bool(show_cameras),
132
+ camera_scale=float(camera_scale),
133
+ max_points=int(glb_max_points),
134
+ )
135
+ meta = load_metadata(session_dir)
136
+ sky_msg = ""
137
+ if meta.get("has_sky_masks"):
138
+ removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
139
+ sky_msg = f" | sky_removed={removed:.1f}%"
140
+ return fig, glb_path, f"Updated preview: {branch_label}{sky_msg}"
141
+
142
+
143
+ def _update_frame_preview(session_dir, frame_index):
144
+ if not session_dir or not os.path.isdir(session_dir):
145
+ return None, None, ""
146
+ rgb, depth, label = load_frame_previews(session_dir, int(frame_index))
147
+ return rgb, depth, label
148
+
149
+
150
+ def main():
151
+ with gr.Blocks(title="LongStream Demo") as demo:
152
+ session_dir = gr.Textbox(visible=False)
153
+
154
+ gr.Markdown("# LongStream Demo")
155
+
156
+ with gr.Row():
157
+ image_dir = gr.Textbox(
158
+ label="Image Folder", placeholder="/path/to/sequence"
159
+ )
160
+ uploaded_files = gr.File(
161
+ label="Upload Images", file_count="multiple", file_types=["image"]
162
+ )
163
+ uploaded_video = gr.File(
164
+ label="Upload Video", file_count="single", file_types=["video"]
165
+ )
166
+
167
+ with gr.Row():
168
+ checkpoint = gr.Textbox(label="Checkpoint", value=DEFAULT_CHECKPOINT)
169
+ device = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cuda")
170
+
171
+ with gr.Accordion("Inference", open=False):
172
+ with gr.Row():
173
+ mode = gr.Dropdown(
174
+ label="Mode",
175
+ choices=["streaming_refresh", "batch_refresh"],
176
+ value="batch_refresh",
177
+ )
178
+ streaming_mode = gr.Dropdown(
179
+ label="Streaming Mode", choices=["causal", "window"], value="causal"
180
+ )
181
+ with gr.Row():
182
+ refresh = gr.Slider(
183
+ label="Refresh", minimum=2, maximum=9, step=1, value=DEFAULT_REFRESH
184
+ )
185
+ window_size = gr.Slider(
186
+ label="Window Size",
187
+ minimum=1,
188
+ maximum=64,
189
+ step=1,
190
+ value=DEFAULT_WINDOW_SIZE,
191
+ )
192
+ compute_sky = gr.Checkbox(label="Compute Sky Masks", value=True)
193
+
194
+ with gr.Accordion("GLB Settings", open=True):
195
+ with gr.Row():
196
+ branch_label = gr.Dropdown(
197
+ label="Point Cloud Branch",
198
+ choices=BRANCH_OPTIONS,
199
+ value="Point Head + Pose",
200
+ )
201
+ show_cameras = gr.Checkbox(label="Show Cameras", value=True)
202
+ mask_sky = gr.Checkbox(label="Mask Sky", value=True)
203
+ with gr.Row():
204
+ point_size = gr.Slider(
205
+ label="Point Size",
206
+ minimum=0.05,
207
+ maximum=2.0,
208
+ step=0.05,
209
+ value=0.3,
210
+ )
211
+ opacity = gr.Slider(
212
+ label="Opacity",
213
+ minimum=0.1,
214
+ maximum=1.0,
215
+ step=0.05,
216
+ value=0.75,
217
+ )
218
+ preview_max_points = gr.Slider(
219
+ label="Preview Max Points",
220
+ minimum=5000,
221
+ maximum=1000000,
222
+ step=10000,
223
+ value=100000,
224
+ )
225
+ with gr.Row():
226
+ camera_scale = gr.Slider(
227
+ label="Camera Scale",
228
+ minimum=0.001,
229
+ maximum=0.05,
230
+ step=0.001,
231
+ value=0.01,
232
+ )
233
+ glb_max_points = gr.Slider(
234
+ label="GLB Max Points",
235
+ minimum=20000,
236
+ maximum=1000000,
237
+ step=10000,
238
+ value=400000,
239
+ )
240
+
241
+ run_btn = gr.Button("Run Stable Demo", variant="primary")
242
+ status = gr.Markdown("Provide input images, then run reconstruction.")
243
+
244
+ plot = gr.Plot(label="Scene Preview")
245
+
246
+ glb_file = gr.File(label="Download GLB")
247
+
248
+ with gr.Row():
249
+ frame_slider = gr.Slider(
250
+ label="Preview Frame",
251
+ minimum=0,
252
+ maximum=0,
253
+ step=1,
254
+ value=0,
255
+ interactive=False,
256
+ )
257
+ frame_label = gr.Textbox(label="Frame")
258
+ with gr.Row():
259
+ rgb_preview = gr.Image(label="RGB", type="numpy")
260
+ depth_preview = gr.Image(label="Depth Plasma", type="numpy")
261
+
262
+ run_btn.click(
263
+ _run_stable_demo,
264
+ inputs=[
265
+ image_dir,
266
+ uploaded_files,
267
+ uploaded_video,
268
+ checkpoint,
269
+ device,
270
+ mode,
271
+ streaming_mode,
272
+ refresh,
273
+ window_size,
274
+ compute_sky,
275
+ branch_label,
276
+ show_cameras,
277
+ mask_sky,
278
+ camera_scale,
279
+ point_size,
280
+ opacity,
281
+ preview_max_points,
282
+ glb_max_points,
283
+ ],
284
+ outputs=[
285
+ plot,
286
+ glb_file,
287
+ session_dir,
288
+ rgb_preview,
289
+ depth_preview,
290
+ frame_label,
291
+ frame_slider,
292
+ status,
293
+ ],
294
+ )
295
+
296
+ for component in [
297
+ branch_label,
298
+ show_cameras,
299
+ mask_sky,
300
+ camera_scale,
301
+ point_size,
302
+ opacity,
303
+ preview_max_points,
304
+ glb_max_points,
305
+ ]:
306
+ component.change(
307
+ _update_stable_scene,
308
+ inputs=[
309
+ session_dir,
310
+ branch_label,
311
+ show_cameras,
312
+ mask_sky,
313
+ camera_scale,
314
+ point_size,
315
+ opacity,
316
+ preview_max_points,
317
+ glb_max_points,
318
+ ],
319
+ outputs=[plot, glb_file, status],
320
+ )
321
+
322
+ frame_slider.change(
323
+ _update_frame_preview,
324
+ inputs=[session_dir, frame_slider],
325
+ outputs=[rgb_preview, depth_preview, frame_label],
326
+ )
327
+
328
+ demo.launch()
329
+
330
+
331
+ if __name__ == "__main__":
332
+ main()
longstream/.DS_Store ADDED
Binary file (6.15 kB). View file
 
longstream/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __all__ = []
longstream/core/__init__.py ADDED
File without changes
longstream/core/cli.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import yaml
6
+
7
+
8
+ def default_config_path() -> str:
9
+ return os.path.join(
10
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
11
+ "configs",
12
+ "longstream_infer.yaml",
13
+ )
14
+
15
+
16
+ def add_runtime_arguments(parser):
17
+ parser.add_argument(
18
+ "--config",
19
+ default=default_config_path(),
20
+ help="Path to longstream config yaml.",
21
+ )
22
+ parser.add_argument(
23
+ "--dataset",
24
+ default=None,
25
+ help="Optional dataset hint. Generic format works without it.",
26
+ )
27
+ parser.add_argument("--img-path", default=None)
28
+ parser.add_argument(
29
+ "--seq-list",
30
+ default=None,
31
+ help="Comma-separated sequence names. Default: auto-detect all sequences.",
32
+ )
33
+ parser.add_argument("--format", default=None, help="generalizable")
34
+ parser.add_argument("--data-roots-file", default=None)
35
+ parser.add_argument("--camera", default=None)
36
+ parser.add_argument("--output-root", default=None)
37
+ parser.add_argument("--device", default=None)
38
+ parser.add_argument("--checkpoint", default=None)
39
+ parser.add_argument("--hf-repo", default=None)
40
+ parser.add_argument("--hf-file", default=None)
41
+ parser.add_argument(
42
+ "--mode", default=None, help="batch_refresh | streaming_refresh"
43
+ )
44
+ parser.add_argument("--streaming-mode", default=None, help="causal | window")
45
+ parser.add_argument("--window-size", type=int, default=None)
46
+ parser.add_argument("--keyframe-stride", type=int, default=None)
47
+ parser.add_argument(
48
+ "--refresh",
49
+ type=int,
50
+ default=None,
51
+ help="Number of keyframes per refresh span, inclusive of both ends and including the segment start keyframe.",
52
+ )
53
+ parser.add_argument(
54
+ "--keyframes-per-batch",
55
+ dest="keyframes_per_batch_legacy",
56
+ type=int,
57
+ default=None,
58
+ help=argparse.SUPPRESS,
59
+ )
60
+ parser.add_argument("--max-frames", type=int, default=None)
61
+ parser.add_argument("--depth-rel-delta-threshold", type=float, default=None)
62
+ parser.add_argument("--point-f1-threshold", type=float, default=None)
63
+ parser.add_argument("--eval-max-points", type=int, default=None)
64
+ parser.add_argument("--eval-voxel-size", type=float, default=None)
65
+ parser.add_argument("--max-full-pointcloud-points", type=int, default=None)
66
+ parser.add_argument("--max-frame-pointcloud-points", type=int, default=None)
67
+ parser.add_argument("--save-frame-points", action="store_true")
68
+ parser.add_argument("--no-save-frame-points", action="store_true")
69
+ parser.add_argument("--no-align-scale", action="store_true")
70
+ parser.add_argument("--mask-sky", action="store_true")
71
+ parser.add_argument("--no-mask-sky", action="store_true")
72
+ return parser
73
+
74
+
75
+ def parse_runtime_args(parser):
76
+ argv = [arg for arg in sys.argv[1:] if arg.strip()]
77
+ return parser.parse_args(argv)
78
+
79
+
80
+ def load_config_with_overrides(args):
81
+ with open(args.config, "r") as f:
82
+ cfg = yaml.safe_load(f) or {}
83
+ cfg.setdefault("model", {})
84
+
85
+ if args.device is not None:
86
+ cfg["device"] = args.device
87
+
88
+ if args.output_root is not None:
89
+ cfg.setdefault("output", {})
90
+ cfg["output"]["root"] = args.output_root
91
+
92
+ if args.dataset is not None:
93
+ cfg.setdefault("data", {})
94
+ cfg["data"]["dataset"] = args.dataset
95
+
96
+ if args.img_path is not None:
97
+ cfg.setdefault("data", {})
98
+ cfg["data"]["img_path"] = args.img_path
99
+
100
+ if args.seq_list is not None:
101
+ seqs = [s.strip() for s in args.seq_list.split(",") if s.strip()]
102
+ cfg.setdefault("data", {})
103
+ cfg["data"]["seq_list"] = seqs
104
+
105
+ if args.format is not None:
106
+ cfg.setdefault("data", {})
107
+ cfg["data"]["format"] = args.format
108
+
109
+ if args.data_roots_file is not None:
110
+ cfg.setdefault("data", {})
111
+ cfg["data"]["data_roots_file"] = args.data_roots_file
112
+
113
+ if args.camera is not None:
114
+ cfg.setdefault("data", {})
115
+ cfg["data"]["camera"] = args.camera
116
+
117
+ if args.max_frames is not None:
118
+ cfg.setdefault("data", {})
119
+ cfg["data"]["max_frames"] = args.max_frames
120
+
121
+ if args.checkpoint is not None:
122
+ cfg.setdefault("model", {})
123
+ cfg["model"]["checkpoint"] = args.checkpoint
124
+
125
+ if args.hf_repo is not None or args.hf_file is not None:
126
+ cfg.setdefault("model", {})
127
+ cfg["model"].setdefault("hf", {})
128
+ if args.hf_repo is not None:
129
+ cfg["model"]["hf"]["repo_id"] = args.hf_repo
130
+ if args.hf_file is not None:
131
+ cfg["model"]["hf"]["filename"] = args.hf_file
132
+ if cfg["model"].get("checkpoint") is None:
133
+ cfg["model"]["checkpoint"] = None
134
+
135
+ if args.mode is not None:
136
+ cfg.setdefault("inference", {})
137
+ cfg["inference"]["mode"] = args.mode
138
+
139
+ if args.streaming_mode is not None:
140
+ cfg.setdefault("inference", {})
141
+ cfg["inference"]["streaming_mode"] = args.streaming_mode
142
+
143
+ if args.window_size is not None:
144
+ cfg.setdefault("inference", {})
145
+ cfg["inference"]["window_size"] = args.window_size
146
+ cfg["model"].setdefault("longstream_cfg", {})
147
+ cfg["model"]["longstream_cfg"]["window_size"] = args.window_size
148
+
149
+ if args.keyframe_stride is not None:
150
+ cfg.setdefault("inference", {})
151
+ cfg["inference"]["keyframe_stride"] = args.keyframe_stride
152
+ cfg["model"].setdefault("longstream_cfg", {})
153
+ cfg["model"]["longstream_cfg"].setdefault("rel_pose_head_cfg", {})
154
+ cfg["model"]["longstream_cfg"]["rel_pose_head_cfg"][
155
+ "keyframe_stride"
156
+ ] = args.keyframe_stride
157
+
158
+ refresh = args.refresh
159
+ if refresh is None and args.keyframes_per_batch_legacy is not None:
160
+ refresh = args.keyframes_per_batch_legacy + 1
161
+ if refresh is not None:
162
+ cfg.setdefault("inference", {})
163
+ cfg["inference"]["refresh"] = refresh
164
+
165
+ if args.depth_rel_delta_threshold is not None:
166
+ cfg.setdefault("evaluation", {})
167
+ cfg["evaluation"]["depth_rel_delta_threshold"] = args.depth_rel_delta_threshold
168
+
169
+ if args.point_f1_threshold is not None:
170
+ cfg.setdefault("evaluation", {})
171
+ cfg["evaluation"]["point_f1_threshold"] = args.point_f1_threshold
172
+
173
+ if args.eval_max_points is not None:
174
+ cfg.setdefault("evaluation", {})
175
+ cfg["evaluation"]["point_eval_max_points"] = args.eval_max_points
176
+
177
+ if args.eval_voxel_size is not None:
178
+ cfg.setdefault("evaluation", {})
179
+ cfg["evaluation"]["point_eval_voxel_size"] = args.eval_voxel_size
180
+
181
+ if args.max_full_pointcloud_points is not None:
182
+ cfg.setdefault("output", {})
183
+ cfg["output"]["max_full_pointcloud_points"] = args.max_full_pointcloud_points
184
+
185
+ if args.max_frame_pointcloud_points is not None:
186
+ cfg.setdefault("output", {})
187
+ cfg["output"]["max_frame_pointcloud_points"] = args.max_frame_pointcloud_points
188
+
189
+ if args.save_frame_points:
190
+ cfg.setdefault("output", {})
191
+ cfg["output"]["save_frame_points"] = True
192
+ if args.no_save_frame_points:
193
+ cfg.setdefault("output", {})
194
+ cfg["output"]["save_frame_points"] = False
195
+
196
+ if args.no_align_scale:
197
+ cfg.setdefault("evaluation", {})
198
+ cfg["evaluation"]["align_scale"] = False
199
+
200
+ if args.mask_sky:
201
+ cfg.setdefault("output", {})
202
+ cfg["output"]["mask_sky"] = True
203
+ if args.no_mask_sky:
204
+ cfg.setdefault("output", {})
205
+ cfg["output"]["mask_sky"] = False
206
+
207
+ infer_cfg = cfg.setdefault("inference", {})
208
+ if "refresh" not in infer_cfg and "keyframes_per_batch" in infer_cfg:
209
+ infer_cfg["refresh"] = int(infer_cfg["keyframes_per_batch"]) + 1
210
+
211
+ cfg.setdefault("data", {})
212
+ cfg["data"]["format"] = "generalizable"
213
+ return cfg
longstream/core/infer.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import yaml
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from longstream.core.model import LongStreamModel
10
+ from longstream.data.dataloader import LongStreamDataLoader
11
+ from longstream.streaming.keyframe_selector import KeyframeSelector
12
+ from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh
13
+ from longstream.utils.vendor.models.components.utils.pose_enc import (
14
+ pose_encoding_to_extri_intri,
15
+ )
16
+ from longstream.utils.camera import compose_abs_from_rel
17
+ from longstream.utils.depth import colorize_depth, unproject_depth_to_points
18
+ from longstream.utils.sky_mask import compute_sky_mask
19
+ from longstream.io.save_points import save_pointcloud
20
+ from longstream.io.save_poses_txt import save_w2c_txt, save_intri_txt, save_rel_pose_txt
21
+ from longstream.io.save_images import save_image_sequence, save_video
22
+
23
+
24
+ def _to_uint8_rgb(images):
25
+ imgs = images.detach().cpu().numpy()
26
+ imgs = np.clip(imgs, 0.0, 1.0)
27
+ imgs = (imgs * 255.0).astype(np.uint8)
28
+ return imgs
29
+
30
+
31
+ def _ensure_dir(path):
32
+ os.makedirs(path, exist_ok=True)
33
+
34
+
35
+ def _apply_sky_mask(depth, mask):
36
+ if mask is None:
37
+ return depth
38
+ m = (mask > 0).astype(np.float32)
39
+ return depth * m
40
+
41
+
42
+ def _camera_points_to_world(points, extri):
43
+ pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
44
+ R = np.asarray(extri[:3, :3], dtype=np.float64)
45
+ t = np.asarray(extri[:3, 3], dtype=np.float64)
46
+ world = (R.T @ (pts.T - t[:, None])).T
47
+ return world.astype(np.float32, copy=False)
48
+
49
+
50
+ def _mask_points_and_colors(points, colors, mask):
51
+ pts = points.reshape(-1, 3)
52
+ cols = None if colors is None else colors.reshape(-1, 3)
53
+ if mask is None:
54
+ return pts, cols
55
+ valid = mask.reshape(-1) > 0
56
+ pts = pts[valid]
57
+ if cols is not None:
58
+ cols = cols[valid]
59
+ return pts, cols
60
+
61
+
62
+ def _resize_long_edge(arr, long_edge_size, interpolation):
63
+ h, w = arr.shape[:2]
64
+ scale = float(long_edge_size) / float(max(h, w))
65
+ new_w = int(round(w * scale))
66
+ new_h = int(round(h * scale))
67
+ return cv2.resize(arr, (new_w, new_h), interpolation=interpolation)
68
+
69
+
70
+ def _prepare_mask_for_model(
71
+ mask, size, crop, patch_size, target_shape, square_ok=False
72
+ ):
73
+ if mask is None:
74
+ return None
75
+ long_edge = (
76
+ round(size * max(mask.shape[1] / mask.shape[0], mask.shape[0] / mask.shape[1]))
77
+ if size == 224
78
+ else size
79
+ )
80
+ mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)
81
+
82
+ h, w = mask.shape[:2]
83
+ cx, cy = w // 2, h // 2
84
+ if size == 224:
85
+ half = min(cx, cy)
86
+ target_w = 2 * half
87
+ target_h = 2 * half
88
+ if crop:
89
+ mask = mask[cy - half : cy + half, cx - half : cx + half]
90
+ else:
91
+ mask = cv2.resize(
92
+ mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
93
+ )
94
+ else:
95
+ halfw = ((2 * cx) // patch_size) * (patch_size // 2)
96
+ halfh = ((2 * cy) // patch_size) * (patch_size // 2)
97
+ if not square_ok and w == h:
98
+ halfh = int(3 * halfw / 4)
99
+ target_w = 2 * halfw
100
+ target_h = 2 * halfh
101
+ if crop:
102
+ mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
103
+ else:
104
+ mask = cv2.resize(
105
+ mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
106
+ )
107
+
108
+ if mask.shape[:2] != tuple(target_shape):
109
+ mask = cv2.resize(
110
+ mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
111
+ )
112
+ return mask
113
+
114
+
115
+ def _save_full_pointcloud(path, point_chunks, color_chunks, max_points=None, seed=0):
116
+ if not point_chunks:
117
+ return
118
+ points = np.concatenate(point_chunks, axis=0)
119
+ colors = None
120
+ if color_chunks and len(color_chunks) == len(point_chunks):
121
+ colors = np.concatenate(color_chunks, axis=0)
122
+ if max_points is not None and len(points) > max_points:
123
+ rng = np.random.default_rng(seed)
124
+ keep = rng.choice(len(points), size=max_points, replace=False)
125
+ points = points[keep]
126
+ if colors is not None:
127
+ colors = colors[keep]
128
+ np.save(os.path.splitext(path)[0] + ".npy", points.astype(np.float32, copy=False))
129
+ save_pointcloud(path, points, colors=colors, max_points=None, seed=seed)
130
+
131
+
132
+ def run_inference_cfg(cfg: dict):
133
+ device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
134
+ device_type = torch.device(device).type
135
+ model_cfg = cfg.get("model", {})
136
+ data_cfg = cfg.get("data", {})
137
+ infer_cfg = cfg.get("inference", {})
138
+ output_cfg = cfg.get("output", {})
139
+
140
+ print(f"[longstream] device={device}", flush=True)
141
+ model = LongStreamModel(model_cfg).to(device)
142
+ model.eval()
143
+ print("[longstream] model ready", flush=True)
144
+
145
+ loader = LongStreamDataLoader(data_cfg)
146
+
147
+ keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
148
+ keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
149
+ refresh = int(
150
+ infer_cfg.get("refresh", int(infer_cfg.get("keyframes_per_batch", 3)) + 1)
151
+ )
152
+ if refresh < 2:
153
+ raise ValueError(
154
+ "refresh must be >= 2 because it counts both keyframe endpoints"
155
+ )
156
+ mode = infer_cfg.get("mode", "streaming_refresh")
157
+ if mode == "streaming":
158
+ mode = "streaming_refresh"
159
+ streaming_mode = infer_cfg.get("streaming_mode", "causal")
160
+ window_size = int(infer_cfg.get("window_size", 5))
161
+
162
+ selector = KeyframeSelector(
163
+ min_interval=keyframe_stride,
164
+ max_interval=keyframe_stride,
165
+ force_first=True,
166
+ mode="random" if keyframe_mode == "random" else "fixed",
167
+ )
168
+
169
+ out_root = output_cfg.get("root", "outputs")
170
+ _ensure_dir(out_root)
171
+ save_videos = bool(output_cfg.get("save_videos", True))
172
+ save_points = bool(output_cfg.get("save_points", True))
173
+ save_frame_points = bool(output_cfg.get("save_frame_points", True))
174
+ save_depth = bool(output_cfg.get("save_depth", True))
175
+ save_images = bool(output_cfg.get("save_images", True))
176
+ mask_sky = bool(output_cfg.get("mask_sky", True))
177
+ max_full_pointcloud_points = output_cfg.get("max_full_pointcloud_points", None)
178
+ if max_full_pointcloud_points is not None:
179
+ max_full_pointcloud_points = int(max_full_pointcloud_points)
180
+ max_frame_pointcloud_points = output_cfg.get("max_frame_pointcloud_points", None)
181
+ if max_frame_pointcloud_points is not None:
182
+ max_frame_pointcloud_points = int(max_frame_pointcloud_points)
183
+ skyseg_path = output_cfg.get(
184
+ "skyseg_path",
185
+ os.path.join(os.path.dirname(__file__), "..", "..", "skyseg.onnx"),
186
+ )
187
+
188
+ with torch.no_grad():
189
+ for seq in loader:
190
+ images = seq.images
191
+ B, S, C, H, W = images.shape
192
+ print(
193
+ f"[longstream] sequence {seq.name}: inference start ({S} frames)",
194
+ flush=True,
195
+ )
196
+
197
+ is_keyframe, keyframe_indices = selector.select_keyframes(
198
+ S, B, images.device
199
+ )
200
+
201
+ rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})
202
+
203
+ if mode == "batch_refresh":
204
+ outputs = run_batch_refresh(
205
+ model,
206
+ images,
207
+ is_keyframe,
208
+ keyframe_indices,
209
+ streaming_mode,
210
+ keyframe_stride,
211
+ refresh,
212
+ rel_pose_cfg,
213
+ )
214
+ elif mode == "streaming_refresh":
215
+ outputs = run_streaming_refresh(
216
+ model,
217
+ images,
218
+ is_keyframe,
219
+ keyframe_indices,
220
+ streaming_mode,
221
+ window_size,
222
+ refresh,
223
+ rel_pose_cfg,
224
+ )
225
+ else:
226
+ raise ValueError(f"Unsupported inference mode: {mode}")
227
+ print(f"[longstream] sequence {seq.name}: inference done", flush=True)
228
+ if device_type == "cuda":
229
+ torch.cuda.empty_cache()
230
+
231
+ seq_dir = os.path.join(out_root, seq.name)
232
+ _ensure_dir(seq_dir)
233
+
234
+ frame_ids = list(range(S))
235
+ rgb = _to_uint8_rgb(images[0].permute(0, 2, 3, 1))
236
+
237
+ if "rel_pose_enc" in outputs:
238
+ rel_pose_enc = outputs["rel_pose_enc"][0]
239
+ abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
240
+ extri, intri = pose_encoding_to_extri_intri(
241
+ abs_pose_enc[None], image_size_hw=(H, W)
242
+ )
243
+ extri_np = extri[0].detach().cpu().numpy()
244
+ intri_np = intri[0].detach().cpu().numpy()
245
+
246
+ pose_dir = os.path.join(seq_dir, "poses")
247
+ _ensure_dir(pose_dir)
248
+ save_w2c_txt(
249
+ os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
250
+ )
251
+ save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)
252
+ save_rel_pose_txt(
253
+ os.path.join(pose_dir, "rel_pose.txt"), rel_pose_enc, frame_ids
254
+ )
255
+ elif "pose_enc" in outputs:
256
+ pose_enc = outputs["pose_enc"][0]
257
+ extri, intri = pose_encoding_to_extri_intri(
258
+ pose_enc[None], image_size_hw=(H, W)
259
+ )
260
+ extri_np = extri[0].detach().cpu().numpy()
261
+ intri_np = intri[0].detach().cpu().numpy()
262
+
263
+ pose_dir = os.path.join(seq_dir, "poses")
264
+ _ensure_dir(pose_dir)
265
+ save_w2c_txt(
266
+ os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
267
+ )
268
+ save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)
269
+
270
+ if save_images:
271
+ print(f"[longstream] sequence {seq.name}: saving rgb", flush=True)
272
+ rgb_dir = os.path.join(seq_dir, "images", "rgb")
273
+ save_image_sequence(rgb_dir, list(rgb))
274
+ if save_videos:
275
+ save_video(
276
+ os.path.join(seq_dir, "images", "rgb.mp4"),
277
+ os.path.join(rgb_dir, "frame_*.png"),
278
+ )
279
+
280
+ sky_masks = None
281
+ if mask_sky:
282
+ raw_sky_masks = compute_sky_mask(
283
+ seq.image_paths, skyseg_path, os.path.join(seq_dir, "sky_masks")
284
+ )
285
+ if raw_sky_masks is not None:
286
+ sky_masks = [
287
+ _prepare_mask_for_model(
288
+ mask,
289
+ size=int(data_cfg.get("size", 518)),
290
+ crop=bool(data_cfg.get("crop", False)),
291
+ patch_size=int(data_cfg.get("patch_size", 14)),
292
+ target_shape=(H, W),
293
+ )
294
+ for mask in raw_sky_masks
295
+ ]
296
+
297
+ if save_depth and "depth" in outputs:
298
+ print(f"[longstream] sequence {seq.name}: saving depth", flush=True)
299
+ depth = outputs["depth"][0, :, :, :, 0].detach().cpu().numpy()
300
+ depth_dir = os.path.join(seq_dir, "depth", "dpt")
301
+ _ensure_dir(depth_dir)
302
+ color_dir = os.path.join(seq_dir, "depth", "dpt_plasma")
303
+ _ensure_dir(color_dir)
304
+
305
+ color_frames = []
306
+ for i in range(S):
307
+ d = depth[i]
308
+ if sky_masks is not None and sky_masks[i] is not None:
309
+ d = _apply_sky_mask(d, sky_masks[i])
310
+ np.save(os.path.join(depth_dir, f"frame_{i:06d}.npy"), d)
311
+ colored = colorize_depth(d, cmap="plasma")
312
+ Image.fromarray(colored).save(
313
+ os.path.join(color_dir, f"frame_{i:06d}.png")
314
+ )
315
+ color_frames.append(colored)
316
+ if save_videos:
317
+ save_video(
318
+ os.path.join(seq_dir, "depth", "dpt_plasma.mp4"),
319
+ os.path.join(color_dir, "frame_*.png"),
320
+ )
321
+
322
+ if save_points:
323
+ print(
324
+ f"[longstream] sequence {seq.name}: saving point clouds", flush=True
325
+ )
326
+ if "world_points" in outputs:
327
+ if "rel_pose_enc" in outputs:
328
+ abs_pose_enc = compose_abs_from_rel(
329
+ outputs["rel_pose_enc"][0], keyframe_indices[0]
330
+ )
331
+ extri, intri = pose_encoding_to_extri_intri(
332
+ abs_pose_enc[None], image_size_hw=(H, W)
333
+ )
334
+ else:
335
+ extri, intri = pose_encoding_to_extri_intri(
336
+ outputs["pose_enc"][0][None], image_size_hw=(H, W)
337
+ )
338
+ extri = extri[0]
339
+ intri = intri[0]
340
+
341
+ pts_dir = os.path.join(seq_dir, "points", "point_head")
342
+ _ensure_dir(pts_dir)
343
+ pts = outputs["world_points"][0].detach().cpu().numpy()
344
+ full_pts = []
345
+ full_cols = []
346
+ for i in range(S):
347
+ pts_world = _camera_points_to_world(
348
+ pts[i], extri[i].detach().cpu().numpy()
349
+ )
350
+ pts_world = pts_world.reshape(pts[i].shape)
351
+ pts_i, cols_i = _mask_points_and_colors(
352
+ pts_world,
353
+ rgb[i],
354
+ None if sky_masks is None else sky_masks[i],
355
+ )
356
+ if save_frame_points:
357
+ save_pointcloud(
358
+ os.path.join(pts_dir, f"frame_{i:06d}.ply"),
359
+ pts_i,
360
+ colors=cols_i,
361
+ max_points=max_frame_pointcloud_points,
362
+ seed=i,
363
+ )
364
+ if len(pts_i):
365
+ full_pts.append(pts_i)
366
+ full_cols.append(cols_i)
367
+ _save_full_pointcloud(
368
+ os.path.join(seq_dir, "points", "point_head_full.ply"),
369
+ full_pts,
370
+ full_cols,
371
+ max_points=max_full_pointcloud_points,
372
+ seed=0,
373
+ )
374
+
375
+ if "depth" in outputs and (
376
+ "rel_pose_enc" in outputs or "pose_enc" in outputs
377
+ ):
378
+ depth = outputs["depth"][0, :, :, :, 0]
379
+ if "rel_pose_enc" in outputs:
380
+ abs_pose_enc = compose_abs_from_rel(
381
+ outputs["rel_pose_enc"][0], keyframe_indices[0]
382
+ )
383
+ extri, intri = pose_encoding_to_extri_intri(
384
+ abs_pose_enc[None], image_size_hw=(H, W)
385
+ )
386
+ else:
387
+ extri, intri = pose_encoding_to_extri_intri(
388
+ outputs["pose_enc"][0][None], image_size_hw=(H, W)
389
+ )
390
+
391
+ extri = extri[0]
392
+ intri = intri[0]
393
+ dpt_pts_dir = os.path.join(seq_dir, "points", "dpt_unproj")
394
+ _ensure_dir(dpt_pts_dir)
395
+ full_pts = []
396
+ full_cols = []
397
+
398
+ for i in range(S):
399
+ d = depth[i]
400
+ pts_cam = unproject_depth_to_points(d[None], intri[i : i + 1])[
401
+ 0
402
+ ]
403
+ R = extri[i, :3, :3]
404
+ t = extri[i, :3, 3]
405
+ pts_world = (
406
+ R.t() @ (pts_cam.reshape(-1, 3).t() - t[:, None])
407
+ ).t()
408
+ pts_world = pts_world.cpu().numpy().reshape(-1, 3)
409
+ pts_i, cols_i = _mask_points_and_colors(
410
+ pts_world,
411
+ rgb[i],
412
+ None if sky_masks is None else sky_masks[i],
413
+ )
414
+ if save_frame_points:
415
+ save_pointcloud(
416
+ os.path.join(dpt_pts_dir, f"frame_{i:06d}.ply"),
417
+ pts_i,
418
+ colors=cols_i,
419
+ max_points=max_frame_pointcloud_points,
420
+ seed=i,
421
+ )
422
+ if len(pts_i):
423
+ full_pts.append(pts_i)
424
+ full_cols.append(cols_i)
425
+ _save_full_pointcloud(
426
+ os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
427
+ full_pts,
428
+ full_cols,
429
+ max_points=max_full_pointcloud_points,
430
+ seed=1,
431
+ )
432
+ del outputs
433
+ if device_type == "cuda":
434
+ torch.cuda.empty_cache()
435
+
436
+
437
def run_inference(config_path: str):
    """Load a YAML inference configuration from disk and execute it."""
    with open(config_path, "r") as handle:
        parsed_cfg = yaml.safe_load(handle)
    run_inference_cfg(parsed_cfg)
441
+
442
+
443
def main():
    """CLI entry point: parse the --config argument and launch inference."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", required=True)
    parsed = arg_parser.parse_args()
    run_inference(parsed.config)
448
+
449
+
450
# Script entry point: delegate to the CLI main() when executed directly.
if __name__ == "__main__":
    main()
longstream/core/model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import Dict, Any
4
+
5
+ from longstream.models.longstream import LongStream
6
+ from longstream.utils.hub import resolve_checkpoint_path
7
+
8
+
9
class LongStreamModel(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around :class:`LongStream`.

    Responsibilities: translate a config dict into LongStream constructor
    kwargs, optionally load pretrained weights (local path or Hugging Face,
    resolved by ``resolve_checkpoint_path``), and re-export the sub-heads
    used elsewhere in the pipeline.
    """

    def __init__(self, cfg: Dict[str, Any] | None):
        """Build the wrapped model.

        Recognized ``cfg`` keys: ``checkpoint`` / ``hf`` (weight source),
        ``longstream_cfg`` (constructor kwargs), ``rel_pose_head_cfg``,
        ``use_rel_pose_head``, and ``strict_load``.
        """
        super().__init__()
        cfg = cfg or {}

        ckpt_path = resolve_checkpoint_path(
            cfg.get("checkpoint", None), cfg.get("hf", None)
        )

        # Copy so the pops below do not mutate the caller's config dict.
        stream_cfg = dict(cfg.get("longstream_cfg", {}) or {})
        # The rel-pose head config may live inside longstream_cfg or at top level.
        rel_pose_cfg = stream_cfg.pop(
            "rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None)
        )
        use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False))
        # Only forward the head config when the head is actually enabled.
        if use_rel_pose_head and rel_pose_cfg is not None:
            stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg
        self.longstream = LongStream(**stream_cfg)

        if ckpt_path:
            self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True)))

    def load_checkpoint(self, ckpt_path: str, strict: bool = True):
        """Load weights from ``ckpt_path``.

        Accepts raw state dicts as well as training checkpoints wrapping the
        weights under ``"model"`` or ``"state_dict"``. When ``strict`` is True
        a key mismatch raises ``RuntimeError``; otherwise it is only printed.
        Raises ``FileNotFoundError`` / ``TypeError`` for bad inputs.
        """
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(ckpt_path)
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict):
            if "model" in ckpt and isinstance(ckpt["model"], dict):
                state = ckpt["model"]
            elif "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
                state = ckpt["state_dict"]
            else:
                state = ckpt
        else:
            raise TypeError("Unsupported checkpoint format")

        if state:
            # Training runs may prefix keys with "sampler."; strip it so they
            # line up with this wrapper's "longstream." attribute path.
            # NOTE(review): only the first key is inspected — assumes the
            # prefix is uniform across the whole state dict.
            first_key = next(iter(state.keys()))
            if first_key.startswith("sampler.longstream."):
                state = {k.replace("sampler.", "", 1): v for k, v in state.items()}

        # Always load non-strictly, then enforce strictness manually so the
        # error message can report both missing and unexpected counts.
        missing, unexpected = self.load_state_dict(state, strict=False)
        if missing or unexpected:
            msg = f"checkpoint mismatch: missing={len(missing)} unexpected={len(unexpected)}"
            if strict:
                raise RuntimeError(msg)
            print(msg)

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped LongStream model."""
        return self.longstream(*args, **kwargs)

    @property
    def aggregator(self):
        # Required attribute of LongStream; no getattr fallback.
        return self.longstream.aggregator

    @property
    def camera_head(self):
        # Optional head: None when the wrapped model was built without it.
        return getattr(self.longstream, "camera_head", None)

    @property
    def rel_pose_head(self):
        # Optional head: None when the wrapped model was built without it.
        return getattr(self.longstream, "rel_pose_head", None)
longstream/data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .dataloader import LongStreamDataLoader, LongStreamSequence, LongStreamSequenceInfo
2
+
3
+ __all__ = ["LongStreamDataLoader", "LongStreamSequence", "LongStreamSequenceInfo"]
longstream/data/dataloader.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from dataclasses import dataclass
4
+ from typing import List, Dict, Any, Iterator, Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval
9
+
10
+ dataset_metadata: Dict[str, Dict[str, Any]] = {
11
+ "davis": {
12
+ "img_path": "data/davis/DAVIS/JPEGImages/480p",
13
+ "mask_path": "data/davis/DAVIS/masked_images/480p",
14
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
15
+ "gt_traj_func": lambda img_path, anno_path, seq: None,
16
+ "traj_format": None,
17
+ "seq_list": None,
18
+ "full_seq": True,
19
+ "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
20
+ "skip_condition": None,
21
+ "process_func": None,
22
+ },
23
+ "kitti": {
24
+ "img_path": "data/kitti/sequences",
25
+ "anno_path": "data/kitti/poses",
26
+ "mask_path": None,
27
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "image_2"),
28
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
29
+ anno_path, f"{seq}.txt"
30
+ )
31
+ if os.path.exists(os.path.join(anno_path, f"{seq}.txt"))
32
+ else None,
33
+ "traj_format": "kitti",
34
+ "seq_list": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"],
35
+ "full_seq": True,
36
+ "mask_path_seq_func": lambda mask_path, seq: None,
37
+ "skip_condition": None,
38
+ "process_func": None,
39
+ },
40
+ "bonn": {
41
+ "img_path": "data/bonn/rgbd_bonn_dataset",
42
+ "mask_path": None,
43
+ "dir_path_func": lambda img_path, seq: os.path.join(
44
+ img_path, f"rgbd_bonn_{seq}", "rgb_110"
45
+ ),
46
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
47
+ img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
48
+ ),
49
+ "traj_format": "tum",
50
+ "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
51
+ "full_seq": False,
52
+ "mask_path_seq_func": lambda mask_path, seq: None,
53
+ "skip_condition": None,
54
+ "process_func": None,
55
+ },
56
+ "nyu": {
57
+ "img_path": "data/nyu-v2/val/nyu_images",
58
+ "mask_path": None,
59
+ "process_func": None,
60
+ },
61
+ "scannet": {
62
+ "img_path": "data/scannetv2",
63
+ "mask_path": None,
64
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
65
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
66
+ img_path, seq, "pose_90.txt"
67
+ ),
68
+ "traj_format": "replica",
69
+ "seq_list": None,
70
+ "full_seq": True,
71
+ "mask_path_seq_func": lambda mask_path, seq: None,
72
+ "skip_condition": None,
73
+ "process_func": None,
74
+ },
75
+ "tum": {
76
+ "img_path": "data/tum",
77
+ "mask_path": None,
78
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
79
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
80
+ img_path, seq, "groundtruth_90.txt"
81
+ ),
82
+ "traj_format": "tum",
83
+ "seq_list": None,
84
+ "full_seq": True,
85
+ "mask_path_seq_func": lambda mask_path, seq: None,
86
+ "skip_condition": None,
87
+ "process_func": None,
88
+ },
89
+ "sintel": {
90
+ "img_path": "data/sintel/training/final",
91
+ "anno_path": "data/sintel/training/camdata_left",
92
+ "mask_path": None,
93
+ "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
94
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
95
+ "traj_format": None,
96
+ "seq_list": [
97
+ "alley_2",
98
+ "ambush_4",
99
+ "ambush_5",
100
+ "ambush_6",
101
+ "cave_2",
102
+ "cave_4",
103
+ "market_2",
104
+ "market_5",
105
+ "market_6",
106
+ "shaman_3",
107
+ "sleeping_1",
108
+ "sleeping_2",
109
+ "temple_2",
110
+ "temple_3",
111
+ ],
112
+ "full_seq": False,
113
+ "mask_path_seq_func": lambda mask_path, seq: None,
114
+ "skip_condition": None,
115
+ "process_func": None,
116
+ },
117
+ "waymo": {
118
+ "img_path": "/horizon-bucket/saturn_v_4dlabel/004_vision/01_users/tao02.xie/datasets/scatt3r_evaluation/waymo_open_dataset_v1_4_3",
119
+ "anno_path": None,
120
+ "mask_path": None,
121
+ "dir_path_func": lambda img_path, seq: os.path.join(
122
+ img_path,
123
+ seq.split("_cam")[0] if "_cam" in seq else seq,
124
+ "images",
125
+ seq.split("_cam")[1] if "_cam" in seq else "00",
126
+ ),
127
+ "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
128
+ img_path,
129
+ seq.split("_cam")[0] if "_cam" in seq else seq,
130
+ "cameras",
131
+ seq.split("_cam")[1] if "_cam" in seq else "00",
132
+ "extri.yml",
133
+ ),
134
+ "traj_format": "waymo",
135
+ "seq_list": None,
136
+ "full_seq": True,
137
+ "mask_path_seq_func": lambda mask_path, seq: None,
138
+ "skip_condition": None,
139
+ "process_func": None,
140
+ },
141
+ }
142
+
143
+
144
+ @dataclass
145
+ class LongStreamSequenceInfo:
146
+ name: str
147
+ scene_root: str
148
+ image_dir: str
149
+ image_paths: List[str]
150
+ camera: Optional[str]
151
+
152
+
153
+ class LongStreamSequence:
154
+ def __init__(
155
+ self,
156
+ name: str,
157
+ images: torch.Tensor,
158
+ image_paths: List[str],
159
+ scene_root: Optional[str] = None,
160
+ image_dir: Optional[str] = None,
161
+ camera: Optional[str] = None,
162
+ ):
163
+ self.name = name
164
+ self.images = images
165
+ self.image_paths = image_paths
166
+ self.scene_root = scene_root
167
+ self.image_dir = image_dir
168
+ self.camera = camera
169
+
170
+
171
+ def _read_list_file(path: str) -> List[str]:
172
+ with open(path, "r") as f:
173
+ lines = []
174
+ for line in f.readlines():
175
+ line = line.strip()
176
+ if not line:
177
+ continue
178
+ if line.startswith("#"):
179
+ continue
180
+ lines.append(line)
181
+ return lines
182
+
183
+
184
+ def _is_generalizable_scene_root(path: str) -> bool:
185
+ return os.path.isdir(os.path.join(path, "images"))
186
+
187
+
188
+ def _direct_image_files(dir_path: str) -> List[str]:
189
+ filelist = sorted(glob.glob(os.path.join(dir_path, "*.png")))
190
+ if not filelist:
191
+ filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpg")))
192
+ if not filelist:
193
+ filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpeg")))
194
+ return filelist
195
+
196
+
197
+ class LongStreamDataLoader:
198
+ def __init__(self, cfg: Dict[str, Any]):
199
+ self.cfg = cfg
200
+ self.dataset = cfg.get("dataset", None)
201
+ meta = dataset_metadata.get(self.dataset, {})
202
+ self.img_path = cfg.get("img_path", meta.get("img_path"))
203
+ self.mask_path = cfg.get("mask_path", meta.get("mask_path"))
204
+ self.dir_path_func = meta.get("dir_path_func", lambda p, s: os.path.join(p, s))
205
+ self.mask_path_seq_func = meta.get("mask_path_seq_func", lambda p, s: None)
206
+ self.full_seq = bool(cfg.get("full_seq", meta.get("full_seq", True)))
207
+ self.seq_list = cfg.get("seq_list", None)
208
+ self.stride = int(cfg.get("stride", 1))
209
+ self.max_frames = cfg.get("max_frames", None)
210
+ self.size = int(cfg.get("size", 518))
211
+ self.crop = bool(cfg.get("crop", False))
212
+ self.patch_size = int(cfg.get("patch_size", 14))
213
+ self.format = cfg.get("format", "auto")
214
+ self.data_roots_file = cfg.get("data_roots_file", None)
215
+ self.split = cfg.get("split", None)
216
+ self.camera = cfg.get("camera", None)
217
+
218
+ def _infer_format(self) -> str:
219
+ if self.format in ["relpose", "generalizable"]:
220
+ return self.format
221
+ if self.img_path is None:
222
+ return "relpose"
223
+ if _is_generalizable_scene_root(self.img_path):
224
+ return "generalizable"
225
+ default_list = self.data_roots_file or "data_roots.txt"
226
+ if os.path.exists(os.path.join(self.img_path, default_list)):
227
+ return "generalizable"
228
+ return "relpose"
229
+
230
+ def _resolve_seq_list_generalizable(self) -> List[str]:
231
+ if self.seq_list is not None:
232
+ return list(self.seq_list)
233
+ if self.img_path is None or not os.path.isdir(self.img_path):
234
+ return []
235
+
236
+ if _is_generalizable_scene_root(self.img_path):
237
+ return [self.img_path]
238
+
239
+ candidates = []
240
+ if isinstance(self.data_roots_file, str) and self.data_roots_file:
241
+ candidates.append(self.data_roots_file)
242
+ if isinstance(self.split, str) and self.split:
243
+ split_name = self.split.lower()
244
+ if split_name in ["val", "valid", "validate"]:
245
+ split_name = "validate"
246
+ candidates.append(f"{split_name}_data_roots.txt")
247
+ candidates.append("data_roots.txt")
248
+ candidates.append("train_data_roots.txt")
249
+ candidates.append("validate_data_roots.txt")
250
+
251
+ for fname in candidates:
252
+ path = os.path.join(self.img_path, fname)
253
+ if os.path.exists(path):
254
+ return _read_list_file(path)
255
+
256
+ img_dirs = sorted(
257
+ glob.glob(os.path.join(self.img_path, "**", "images"), recursive=True)
258
+ )
259
+ scene_roots = [os.path.dirname(p) for p in img_dirs]
260
+
261
+ rels = []
262
+ for p in scene_roots:
263
+ try:
264
+ rels.append(os.path.relpath(p, self.img_path))
265
+ except ValueError:
266
+ rels.append(p)
267
+ return sorted(set(rels))
268
+
269
+ def _resolve_seq_list_relpose(self) -> List[str]:
270
+ if self.seq_list is not None:
271
+ return list(self.seq_list)
272
+ meta = dataset_metadata.get(self.dataset, {})
273
+ if self.full_seq:
274
+ if self.img_path is None or not os.path.isdir(self.img_path):
275
+ return []
276
+ seqs = [
277
+ s
278
+ for s in os.listdir(self.img_path)
279
+ if os.path.isdir(os.path.join(self.img_path, s))
280
+ ]
281
+ return sorted(seqs)
282
+ seqs = meta.get("seq_list", []) or []
283
+ return list(seqs)
284
+
285
+ def _resolve_seq_list(self) -> List[str]:
286
+ fmt = self._infer_format()
287
+ if fmt == "generalizable":
288
+ return self._resolve_seq_list_generalizable()
289
+ return self._resolve_seq_list_relpose()
290
+
291
+ def _resolve_scene_root(self, seq_entry: str) -> Tuple[str, str]:
292
+ if os.path.isabs(seq_entry) or os.path.sep in seq_entry:
293
+ scene_root = seq_entry
294
+ name = os.path.basename(os.path.normpath(seq_entry))
295
+ else:
296
+ scene_root = os.path.join(self.img_path, seq_entry)
297
+ name = seq_entry
298
+ return name, scene_root
299
+
300
+ def _resolve_image_dir_generalizable(self, scene_root: str) -> Optional[str]:
301
+ images_root = os.path.join(scene_root, "images")
302
+ if not os.path.isdir(images_root):
303
+ return None
304
+
305
+ if isinstance(self.camera, str) and self.camera:
306
+ cam_dir = os.path.join(images_root, self.camera)
307
+ if os.path.isdir(cam_dir):
308
+ return cam_dir
309
+
310
+ if _direct_image_files(images_root):
311
+ return images_root
312
+
313
+ cams = [
314
+ d
315
+ for d in os.listdir(images_root)
316
+ if os.path.isdir(os.path.join(images_root, d))
317
+ ]
318
+ if not cams:
319
+ return None
320
+ cams = sorted(cams)
321
+
322
+ frame_dirs = []
323
+ for name in cams:
324
+ child_dir = os.path.join(images_root, name)
325
+ child_images = _direct_image_files(child_dir)
326
+ if child_images:
327
+ frame_dirs.append((name, len(child_images)))
328
+
329
+ if (
330
+ len(cams) > 10
331
+ and len(frame_dirs) == len(cams)
332
+ and max(count for _, count in frame_dirs) == 1
333
+ ):
334
+ return images_root
335
+
336
+ return os.path.join(images_root, cams[0])
337
+
338
+ def _camera_from_image_dir(self, image_dir: str) -> Optional[str]:
339
+ parent = os.path.basename(os.path.dirname(image_dir))
340
+ if parent != "images":
341
+ return None
342
+ return os.path.basename(image_dir)
343
+
344
+ def _collect_filelist(self, dir_path: str) -> List[str]:
345
+ filelist = _direct_image_files(dir_path)
346
+ if not filelist:
347
+ nested = []
348
+ child_dirs = sorted(
349
+ d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d)
350
+ )
351
+ for child_dir in child_dirs:
352
+ child_images = _direct_image_files(child_dir)
353
+ if child_images:
354
+ nested.append(child_images[0])
355
+ filelist = nested
356
+ if self.stride > 1:
357
+ filelist = filelist[:: self.stride]
358
+ if self.max_frames is not None:
359
+ filelist = filelist[: self.max_frames]
360
+ return filelist
361
+
362
+ def _load_images(self, filelist: List[str]) -> torch.Tensor:
363
+ views = load_images_for_eval(
364
+ filelist,
365
+ size=self.size,
366
+ verbose=False,
367
+ crop=self.crop,
368
+ patch_size=self.patch_size,
369
+ )
370
+ imgs = torch.cat([view["img"] for view in views], dim=0)
371
+ images = imgs.unsqueeze(0)
372
+ images = (images + 1.0) / 2.0
373
+ return images
374
+
375
+ def iter_sequence_infos(self) -> Iterator[LongStreamSequenceInfo]:
376
+ fmt = self._infer_format()
377
+ seqs = self._resolve_seq_list()
378
+ for seq_entry in seqs:
379
+ if fmt == "generalizable":
380
+ seq, scene_root = self._resolve_scene_root(seq_entry)
381
+ dir_path = self._resolve_image_dir_generalizable(scene_root)
382
+ if dir_path is None or not os.path.isdir(dir_path):
383
+ continue
384
+ camera = self._camera_from_image_dir(dir_path)
385
+ else:
386
+ seq = seq_entry
387
+ scene_root = os.path.join(self.img_path, seq)
388
+ dir_path = self.dir_path_func(self.img_path, seq)
389
+ if not os.path.isdir(dir_path):
390
+ continue
391
+ camera = None
392
+
393
+ filelist = self._collect_filelist(dir_path)
394
+ if not filelist:
395
+ continue
396
+ yield LongStreamSequenceInfo(
397
+ name=seq,
398
+ scene_root=scene_root,
399
+ image_dir=dir_path,
400
+ image_paths=filelist,
401
+ camera=camera,
402
+ )
403
+
404
+ def __iter__(self) -> Iterator[LongStreamSequence]:
405
+ for info in self.iter_sequence_infos():
406
+ print(
407
+ f"[longstream] loading sequence {info.name}: {len(info.image_paths)} frames",
408
+ flush=True,
409
+ )
410
+ images = self._load_images(info.image_paths)
411
+ print(
412
+ f"[longstream] loaded sequence {info.name}: {tuple(images.shape)}",
413
+ flush=True,
414
+ )
415
+ yield LongStreamSequence(
416
+ info.name,
417
+ images,
418
+ info.image_paths,
419
+ scene_root=info.scene_root,
420
+ image_dir=info.image_dir,
421
+ camera=info.camera,
422
+ )
longstream/demo/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .backend import create_demo_session, load_frame_previews
2
+ from .common import BRANCH_OPTIONS, DISPLAY_MODE_OPTIONS, branch_key, load_metadata
3
+
4
+ __all__ = [
5
+ "BRANCH_OPTIONS",
6
+ "DISPLAY_MODE_OPTIONS",
7
+ "branch_key",
8
+ "create_demo_session",
9
+ "load_frame_previews",
10
+ "load_metadata",
11
+ ]
longstream/demo/backend.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ import shutil
5
+ import tempfile
6
+ from datetime import datetime
7
+ from typing import Iterable, List, Optional, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ import yaml
13
+
14
+ from longstream.core.cli import default_config_path
15
+ from longstream.core.model import LongStreamModel
16
+ from longstream.streaming.keyframe_selector import KeyframeSelector
17
+ from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh
18
+ from longstream.utils.camera import compose_abs_from_rel
19
+ from longstream.utils.depth import colorize_depth
20
+ from longstream.utils.hub import resolve_checkpoint_path
21
+ from longstream.utils.sky_mask import compute_sky_mask
22
+ from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval
23
+ from longstream.utils.vendor.models.components.utils.pose_enc import (
24
+ pose_encoding_to_extri_intri,
25
+ )
26
+
27
+ from .common import load_metadata, session_file
28
+
29
+ _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
30
+ _MODEL_CACHE = {}
31
+
32
+
33
+ def _resolve_file_path(item) -> str:
34
+ if item is None:
35
+ return ""
36
+ if isinstance(item, str):
37
+ return item
38
+ if isinstance(item, dict) and "name" in item:
39
+ return item["name"]
40
+ if hasattr(item, "name"):
41
+ return item.name
42
+ return str(item)
43
+
44
+
45
+ def _natural_sort_key(path: str):
46
+ name = os.path.basename(path)
47
+ stem, _ = os.path.splitext(name)
48
+ parts = re.split(r"(\d+)", stem)
49
+ key = []
50
+ for part in parts:
51
+ if not part:
52
+ continue
53
+ if part.isdigit():
54
+ key.append((0, int(part)))
55
+ else:
56
+ key.append((1, part.lower()))
57
+ return key, name.lower()
58
+
59
+
60
+ def _sorted_image_paths(image_dir: str) -> List[str]:
61
+ files = []
62
+ for name in os.listdir(image_dir):
63
+ if name.lower().endswith(_IMAGE_EXTS):
64
+ files.append(os.path.join(image_dir, name))
65
+ return sorted(files, key=_natural_sort_key)
66
+
67
+
68
+ def _session_root() -> str:
69
+ root = os.path.join(tempfile.gettempdir(), "longstream_demo_sessions")
70
+ os.makedirs(root, exist_ok=True)
71
+ return root
72
+
73
+
74
+ def _new_session_dir() -> str:
75
+ stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S_%f")
76
+ return tempfile.mkdtemp(prefix=f"longstream_{stamp}_", dir=_session_root())
77
+
78
+
79
+ def _copy_uploaded_images(uploaded_files: Iterable, session_dir: str) -> List[str]:
80
+ input_dir = os.path.join(session_dir, "input_images")
81
+ os.makedirs(input_dir, exist_ok=True)
82
+ copied = []
83
+ sources = sorted(
84
+ (_resolve_file_path(x) for x in uploaded_files if x),
85
+ key=_natural_sort_key,
86
+ )
87
+ for src in sources:
88
+ if not src or not os.path.isfile(src):
89
+ continue
90
+ dst = os.path.join(input_dir, os.path.basename(src))
91
+ shutil.copy2(src, dst)
92
+ copied.append(dst)
93
+ return copied
94
+
95
+
96
+ def _extract_uploaded_video(uploaded_video, session_dir: str) -> List[str]:
97
+ src = _resolve_file_path(uploaded_video)
98
+ if not src:
99
+ return []
100
+ if not os.path.isfile(src):
101
+ raise FileNotFoundError(src)
102
+
103
+ input_dir = os.path.join(session_dir, "input_images")
104
+ os.makedirs(input_dir, exist_ok=True)
105
+ cap = cv2.VideoCapture(src)
106
+ if not cap.isOpened():
107
+ raise ValueError(f"unable to open video: {src}")
108
+
109
+ image_paths = []
110
+ frame_id = 0
111
+ while True:
112
+ ok, frame = cap.read()
113
+ if not ok:
114
+ break
115
+ dst = os.path.join(input_dir, f"{frame_id:06d}.png")
116
+ if not cv2.imwrite(dst, frame):
117
+ cap.release()
118
+ raise ValueError(f"failed to write extracted frame: {dst}")
119
+ image_paths.append(dst)
120
+ frame_id += 1
121
+ cap.release()
122
+
123
+ if not image_paths:
124
+ raise ValueError(f"no frames extracted from video: {src}")
125
+ return image_paths
126
+
127
+
128
+ def _resize_long_edge(arr, long_edge_size, interpolation):
129
+ h, w = arr.shape[:2]
130
+ scale = float(long_edge_size) / float(max(h, w))
131
+ new_w = int(round(w * scale))
132
+ new_h = int(round(h * scale))
133
+ return cv2.resize(arr, (new_w, new_h), interpolation=interpolation)
134
+
135
+
136
+ def _prepare_mask_for_model(
137
+ mask, size, crop, patch_size, target_shape, square_ok=False
138
+ ):
139
+ if mask is None:
140
+ return None
141
+ h0, w0 = mask.shape[:2]
142
+ long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size
143
+ mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)
144
+
145
+ h, w = mask.shape[:2]
146
+ cx, cy = w // 2, h // 2
147
+ if size == 224:
148
+ half = min(cx, cy)
149
+ if crop:
150
+ mask = mask[cy - half : cy + half, cx - half : cx + half]
151
+ else:
152
+ mask = cv2.resize(
153
+ mask, (2 * half, 2 * half), interpolation=cv2.INTER_NEAREST
154
+ )
155
+ else:
156
+ halfw = ((2 * cx) // patch_size) * (patch_size // 2)
157
+ halfh = ((2 * cy) // patch_size) * (patch_size // 2)
158
+ if not square_ok and w == h:
159
+ halfh = int(3 * halfw / 4)
160
+ if crop:
161
+ mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
162
+ else:
163
+ mask = cv2.resize(
164
+ mask, (2 * halfw, 2 * halfh), interpolation=cv2.INTER_NEAREST
165
+ )
166
+
167
+ if mask.shape[:2] != tuple(target_shape):
168
+ mask = cv2.resize(
169
+ mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
170
+ )
171
+ return mask.astype(np.uint8, copy=False)
172
+
173
+
174
def _load_base_config(config_path: Optional[str] = None) -> dict:
    """Parse the base YAML config, falling back to the packaged default path.

    Returns an empty dict when the file is empty.
    """
    resolved = config_path or default_config_path()
    with open(resolved, "r") as f:
        loaded = yaml.safe_load(f)
    return loaded or {}
178
+
179
+
180
def _resolve_demo_checkpoint(checkpoint: str) -> str:
    """Resolve a model checkpoint path for the demo.

    Resolution order: the explicit ``checkpoint`` argument, the
    ``LONGSTREAM_CHECKPOINT`` environment variable, then a Hugging Face
    download configured via the ``LONGSTREAM_HF_*`` environment variables.

    Returns:
        Absolute path to an existing checkpoint file.

    Raises:
        FileNotFoundError: when no candidate resolves to an existing path.
    """
    # Collect local path candidates: argument first, then env override.
    local_candidates = []
    for candidate in [checkpoint, os.getenv("LONGSTREAM_CHECKPOINT", "")]:
        if isinstance(candidate, str) and candidate:
            local_candidates.append(candidate)

    for candidate in local_candidates:
        if os.path.exists(candidate):
            return os.path.abspath(candidate)

    # Fall back to Hugging Face Hub resolution driven by the environment.
    hf_cfg = {
        "repo_id": os.getenv("LONGSTREAM_HF_REPO"),
        "filename": os.getenv("LONGSTREAM_HF_FILE"),
        "revision": os.getenv("LONGSTREAM_HF_REVISION"),
        "local_dir": os.getenv("LONGSTREAM_HF_LOCAL_DIR", "checkpoints"),
    }
    resolved = resolve_checkpoint_path(None, hf_cfg)
    if resolved and os.path.exists(resolved):
        return os.path.abspath(resolved)

    # HF settings were supplied but resolution still failed -> specific error.
    if hf_cfg["repo_id"] and hf_cfg["filename"]:
        raise FileNotFoundError(
            "checkpoint not found locally and Hugging Face resolution failed: "
            f"repo_id={hf_cfg['repo_id']} filename={hf_cfg['filename']}"
        )

    searched = ", ".join(local_candidates) if local_candidates else "<none>"
    raise FileNotFoundError(
        "checkpoint not found. "
        f"searched local paths: {searched}. "
        "You can also set LONGSTREAM_HF_REPO and LONGSTREAM_HF_FILE."
    )
212
+
213
+
214
+ def _model_device(device: str) -> str:
215
+ if device == "cuda" and not torch.cuda.is_available():
216
+ return "cpu"
217
+ return device
218
+
219
+
220
+ def _cache_key(checkpoint: str, device: str, model_cfg: dict) -> Tuple[str, str, str]:
221
+ rel_cfg = json.dumps(model_cfg.get("longstream_cfg", {}), sort_keys=True)
222
+ return checkpoint, device, rel_cfg
223
+
224
+
225
def get_or_load_model(checkpoint: str, device: str, model_cfg: dict) -> LongStreamModel:
    """Return a cached LongStreamModel, loading it when the key is new.

    The cache holds a single entry: requesting a model with a different
    checkpoint/device/config evicts the previously loaded one.
    """
    device = _model_device(device)
    # JSON round-trip is a cheap deep copy of the (YAML-derived) config dict.
    cfg = json.loads(json.dumps(model_cfg))
    cfg["checkpoint"] = checkpoint
    key = _cache_key(checkpoint, device, cfg)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = LongStreamModel(cfg).to(device)
        model.eval()
        # Single-slot cache: drop any previously loaded model before storing.
        _MODEL_CACHE.clear()
        _MODEL_CACHE[key] = model
    return model
237
+
238
+
239
def _load_images(
    image_paths: List[str], size: int, crop: bool, patch_size: int
) -> torch.Tensor:
    """Load and preprocess frames into a single batched tensor.

    Concatenates the per-view "img" tensors from ``load_images_for_eval`` and
    remaps them via ``(x + 1) / 2`` — presumably from [-1, 1] to [0, 1];
    confirm against the loader's output range.

    Returns:
        Tensor of shape (1, N, C, H, W).
    """
    views = load_images_for_eval(
        image_paths, size=size, verbose=False, crop=crop, patch_size=patch_size
    )
    imgs = torch.cat([view["img"] for view in views], dim=0)
    images = (imgs.unsqueeze(0) + 1.0) / 2.0
    return images
248
+
249
+
250
def _select_keyframes(images: torch.Tensor, keyframe_stride: int, keyframe_mode: str):
    """Select keyframes over the sequence at a fixed (or randomized) stride.

    Args:
        images: batched frames; only ``shape[0]``/``shape[1]`` and ``device``
            are read here (assumed (batch, num_frames, ...) — TODO confirm).
        keyframe_stride: used for both min and max interval of the selector.
        keyframe_mode: "random" enables randomized selection; any other value
            falls back to fixed-interval selection.

    Returns:
        The result of ``KeyframeSelector.select_keyframes`` (per callers in
        this module: a keyframe mask and the keyframe indices).
    """
    selector = KeyframeSelector(
        min_interval=keyframe_stride,
        max_interval=keyframe_stride,
        force_first=True,
        mode="random" if keyframe_mode == "random" else "fixed",
    )
    return selector.select_keyframes(images.shape[1], images.shape[0], images.device)
258
+
259
+
260
+ def _run_model(images: torch.Tensor, model: LongStreamModel, infer_cfg: dict):
261
+ keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
262
+ keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
263
+ refresh = int(infer_cfg.get("refresh", 4))
264
+ mode = infer_cfg.get("mode", "streaming_refresh")
265
+ streaming_mode = infer_cfg.get("streaming_mode", "causal")
266
+ window_size = int(infer_cfg.get("window_size", 48))
267
+ rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})
268
+
269
+ is_keyframe, keyframe_indices = _select_keyframes(
270
+ images, keyframe_stride, keyframe_mode
271
+ )
272
+ if mode == "batch_refresh":
273
+ outputs = run_batch_refresh(
274
+ model,
275
+ images,
276
+ is_keyframe,
277
+ keyframe_indices,
278
+ streaming_mode,
279
+ keyframe_stride,
280
+ refresh,
281
+ rel_pose_cfg,
282
+ )
283
+ elif mode == "streaming_refresh":
284
+ outputs = run_streaming_refresh(
285
+ model,
286
+ images,
287
+ is_keyframe,
288
+ keyframe_indices,
289
+ streaming_mode,
290
+ window_size,
291
+ refresh,
292
+ rel_pose_cfg,
293
+ )
294
+ else:
295
+ raise ValueError(f"Unsupported demo inference mode: {mode}")
296
+ return outputs, keyframe_indices
297
+
298
+
299
def _compute_pose_outputs(
    outputs: dict, keyframe_indices: torch.Tensor, image_hw: Tuple[int, int]
):
    """Convert model pose encodings into extrinsic/intrinsic matrices.

    Prefers relative pose encodings ("rel_pose_enc"), composing absolute
    poses from them with the keyframe indices; otherwise falls back to
    absolute encodings ("pose_enc").

    Returns:
        Tuple ``(rel_pose_enc, extrinsics, intrinsics)`` as numpy arrays;
        ``rel_pose_enc`` is None when only absolute encodings are present.

    Raises:
        RuntimeError: if *outputs* contains neither encoding key.
    """
    if "rel_pose_enc" in outputs:
        rel_pose_enc = outputs["rel_pose_enc"][0]
        abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
        extri, intri = pose_encoding_to_extri_intri(
            abs_pose_enc[None], image_size_hw=image_hw
        )
        return (
            rel_pose_enc.detach().cpu().numpy(),
            extri[0].detach().cpu().numpy(),
            intri[0].detach().cpu().numpy(),
        )
    if "pose_enc" in outputs:
        pose_enc = outputs["pose_enc"][0]
        extri, intri = pose_encoding_to_extri_intri(
            pose_enc[None], image_size_hw=image_hw
        )
        return None, extri[0].detach().cpu().numpy(), intri[0].detach().cpu().numpy()
    raise RuntimeError("Model outputs contain neither rel_pose_enc nor pose_enc")
320
+
321
+
322
+ def _compute_sky_masks(
323
+ image_paths: List[str],
324
+ target_shape: Tuple[int, int],
325
+ data_cfg: dict,
326
+ skyseg_path: str,
327
+ session_dir: str,
328
+ ):
329
+ raw_masks = compute_sky_mask(
330
+ image_paths, skyseg_path, os.path.join(session_dir, "sky_masks_raw")
331
+ )
332
+ if raw_masks is None:
333
+ return None
334
+ masks = []
335
+ for mask in raw_masks:
336
+ masks.append(
337
+ _prepare_mask_for_model(
338
+ mask,
339
+ size=int(data_cfg.get("size", 518)),
340
+ crop=bool(data_cfg.get("crop", False)),
341
+ patch_size=int(data_cfg.get("patch_size", 14)),
342
+ target_shape=target_shape,
343
+ )
344
+ )
345
+ return np.stack(masks, axis=0)
346
+
347
+
348
+ def create_demo_session(
349
+ image_dir: str,
350
+ uploaded_files,
351
+ uploaded_video,
352
+ checkpoint: str,
353
+ device: str,
354
+ mode: str,
355
+ streaming_mode: str,
356
+ keyframe_stride: int,
357
+ refresh: int,
358
+ window_size: int,
359
+ compute_sky: bool,
360
+ config_path: Optional[str] = None,
361
+ ) -> str:
362
+ checkpoint = _resolve_demo_checkpoint(checkpoint)
363
+
364
+ session_dir = _new_session_dir()
365
+ base_cfg = _load_base_config(config_path)
366
+ data_cfg = dict(base_cfg.get("data", {}))
367
+ model_cfg = dict(base_cfg.get("model", {}))
368
+ infer_cfg = dict(base_cfg.get("inference", {}))
369
+
370
+ if image_dir:
371
+ image_dir = os.path.abspath(image_dir)
372
+ if not os.path.isdir(image_dir):
373
+ raise FileNotFoundError(f"image_dir not found: {image_dir}")
374
+ image_paths = _sorted_image_paths(image_dir)
375
+ input_root = image_dir
376
+ elif uploaded_video:
377
+ image_paths = _extract_uploaded_video(uploaded_video, session_dir)
378
+ input_root = _resolve_file_path(uploaded_video)
379
+ else:
380
+ image_paths = _copy_uploaded_images(uploaded_files or [], session_dir)
381
+ input_root = os.path.dirname(image_paths[0]) if image_paths else ""
382
+
383
+ if not image_paths:
384
+ raise ValueError("No input images found")
385
+
386
+ data_cfg["size"] = int(data_cfg.get("size", 518))
387
+ data_cfg["crop"] = bool(data_cfg.get("crop", False))
388
+ data_cfg["patch_size"] = int(data_cfg.get("patch_size", 14))
389
+
390
+ device = _model_device(device)
391
+ model = get_or_load_model(checkpoint, device, model_cfg)
392
+
393
+ images = _load_images(
394
+ image_paths, data_cfg["size"], data_cfg["crop"], data_cfg["patch_size"]
395
+ )
396
+ infer_cfg.update(
397
+ {
398
+ "mode": mode,
399
+ "streaming_mode": streaming_mode,
400
+ "keyframe_stride": int(keyframe_stride),
401
+ "refresh": int(refresh),
402
+ "window_size": int(window_size),
403
+ }
404
+ )
405
+
406
+ with torch.no_grad():
407
+ outputs, keyframe_indices = _run_model(images, model, infer_cfg)
408
+ h, w = images.shape[-2:]
409
+ rel_pose_enc, extri, intri = _compute_pose_outputs(
410
+ outputs, keyframe_indices, (h, w)
411
+ )
412
+ point_head = (
413
+ outputs["world_points"][0]
414
+ .detach()
415
+ .cpu()
416
+ .numpy()
417
+ .astype(np.float32, copy=False)
418
+ )
419
+ depth = (
420
+ outputs["depth"][0, :, :, :, 0]
421
+ .detach()
422
+ .cpu()
423
+ .numpy()
424
+ .astype(np.float32, copy=False)
425
+ )
426
+
427
+ if device == "cuda":
428
+ torch.cuda.empty_cache()
429
+
430
+ images_uint8 = np.clip(
431
+ images[0].permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255
432
+ ).astype(np.uint8)
433
+ sky_masks = None
434
+ if compute_sky:
435
+ skyseg_path = os.path.join(
436
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "skyseg.onnx"
437
+ )
438
+ sky_masks = _compute_sky_masks(
439
+ image_paths, (h, w), data_cfg, skyseg_path, session_dir
440
+ )
441
+
442
+ np.save(session_file(session_dir, "images.npy"), images_uint8)
443
+ np.save(session_file(session_dir, "depth.npy"), depth)
444
+ np.save(session_file(session_dir, "point_head.npy"), point_head)
445
+ np.save(session_file(session_dir, "w2c.npy"), extri)
446
+ np.save(session_file(session_dir, "intri.npy"), intri)
447
+ if rel_pose_enc is not None:
448
+ np.save(
449
+ session_file(session_dir, "rel_pose_enc.npy"),
450
+ rel_pose_enc.astype(np.float32, copy=False),
451
+ )
452
+ if sky_masks is not None:
453
+ np.save(
454
+ session_file(session_dir, "sky_masks.npy"),
455
+ sky_masks.astype(np.uint8, copy=False),
456
+ )
457
+
458
+ sky_removed_ratio = None
459
+ if sky_masks is not None:
460
+ sky_removed_ratio = float(1.0 - (sky_masks > 0).mean())
461
+
462
+ metadata = {
463
+ "session_dir": session_dir,
464
+ "created_at": datetime.utcnow().isoformat() + "Z",
465
+ "checkpoint": os.path.abspath(checkpoint),
466
+ "device": device,
467
+ "mode": mode,
468
+ "streaming_mode": streaming_mode,
469
+ "keyframe_stride": int(keyframe_stride),
470
+ "refresh": int(refresh),
471
+ "window_size": int(window_size),
472
+ "num_frames": int(images_uint8.shape[0]),
473
+ "height": int(images_uint8.shape[1]),
474
+ "width": int(images_uint8.shape[2]),
475
+ "input_root": input_root,
476
+ "image_paths": image_paths,
477
+ "has_sky_masks": bool(sky_masks is not None),
478
+ "sky_removed_ratio": sky_removed_ratio,
479
+ }
480
+ with open(session_file(session_dir, "metadata.json"), "w") as f:
481
+ json.dump(metadata, f, indent=2)
482
+
483
+ del outputs
484
+ return session_dir
485
+
486
+
487
+ def load_frame_previews(session_dir: str, frame_index: int):
488
+ meta = load_metadata(session_dir)
489
+ frame_index = int(np.clip(frame_index, 0, meta["num_frames"] - 1))
490
+ images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r")
491
+ depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r")
492
+ rgb = np.array(images[frame_index])
493
+ depth_color = colorize_depth(np.array(depth[frame_index]), cmap="plasma")
494
+ label = f"Frame {frame_index + 1}/{meta['num_frames']}"
495
+ return rgb, depth_color, label
longstream/demo/common.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import List
4
+
5
+ import numpy as np
6
+
7
# UI labels for the two reconstruction branches shown in the demo.
BRANCH_OPTIONS = [
    "Point Head + Pose",
    "Depth Projection + Pose",
]

# Maps a UI branch label to its internal branch identifier.
BRANCH_TO_KEY = {
    "Point Head + Pose": "point_head",
    "Depth Projection + Pose": "depth_projection",
}

# UI labels controlling how many frames the viewer displays at once.
DISPLAY_MODE_OPTIONS = [
    "Current Frame",
    "Accumulate to Frame",
    "All Frames",
]


def branch_key(label: str) -> str:
    """Translate a UI branch label into its internal key ("point_head" default)."""
    return BRANCH_TO_KEY.get(label, "point_head")
24
+
25
+
26
def session_file(session_dir: str, name: str) -> str:
    """Join artifact *name* onto the session directory."""
    return os.path.join(session_dir, name)


def load_metadata(session_dir: str) -> dict:
    """Read and parse the session's ``metadata.json``."""
    meta_path = session_file(session_dir, "metadata.json")
    with open(meta_path, "r") as f:
        return json.load(f)
33
+
34
+
35
def selected_frame_indices(
    num_frames: int, frame_index: int, display_mode: str
) -> List[int]:
    """Frame ids covered by *display_mode* at the (clamped) *frame_index*.

    "Current Frame" -> just that frame; "Accumulate to Frame" -> everything
    up to and including it; anything else -> all frames.
    """
    if num_frames <= 0:
        return []
    clamped = int(np.clip(frame_index, 0, num_frames - 1))
    if display_mode == "Current Frame":
        return [clamped]
    if display_mode == "Accumulate to Frame":
        return list(range(clamped + 1))
    return list(range(num_frames))
46
+
47
+
48
def as_4x4(w2c):
    """Promote a 3x4 world-to-camera matrix to homogeneous 4x4 (no-op for 4x4)."""
    mat = np.asarray(w2c, dtype=np.float64)
    if mat.shape == (4, 4):
        return mat
    homo = np.eye(4, dtype=np.float64)
    homo[:3, :4] = mat
    return homo
55
+
56
+
57
# Rotation taking world coordinates into the viewer convention:
# x stays right, world z maps to the second axis, world -y to the third.
_VIEW_ROT = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, -1.0, 0.0],
    ],
    dtype=np.float64,
)


def world_to_view(points):
    """Rotate an (N, 3) array of world-space points into view space."""
    pts = np.asarray(points, dtype=np.float64)
    return pts @ _VIEW_ROT.T
70
+
71
+
72
def camera_center_from_w2c(w2c):
    """Camera center in world coordinates (translation of the inverted pose)."""
    cam_to_world = np.linalg.inv(as_4x4(w2c))
    return cam_to_world[:3, 3]
75
+
76
+
77
def c2w_in_view_space(w2c, origin_shift=None):
    """Camera-to-world pose expressed in view space, optionally re-centered.

    The rotation and translation are both mapped through ``_VIEW_ROT``; when
    given, *origin_shift* (a view-space offset) is subtracted from the
    translation so the first camera can sit at the origin.
    """
    cam_to_world = np.linalg.inv(as_4x4(w2c))
    pose = np.eye(4, dtype=np.float64)
    pose[:3, :3] = _VIEW_ROT @ cam_to_world[:3, :3]
    pose[:3, 3] = world_to_view(cam_to_world[:3, 3][None])[0]
    if origin_shift is not None:
        pose[:3, 3] -= np.asarray(origin_shift, dtype=np.float64)
    return pose
longstream/demo/export.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+ from .geometry import camera_geometry, collect_points
6
+
7
+ _CAMERA_COLORS = np.array(
8
+ [
9
+ [239, 68, 68, 255],
10
+ [14, 165, 233, 255],
11
+ [34, 197, 94, 255],
12
+ [245, 158, 11, 255],
13
+ ],
14
+ dtype=np.uint8,
15
+ )
16
+
17
+
18
def _camera_mesh(center, corners, color):
    """Build a solid frustum mesh (apex + four corners) with one face color."""
    import trimesh

    verts = np.vstack([center[None], corners]).astype(np.float32)
    # Four side triangles fanning out from the apex, plus two closing the base.
    faces = np.array(
        [[0, 1, 2], [0, 2, 3], [0, 3, 4], [0, 4, 1], [1, 2, 3], [1, 3, 4]],
        dtype=np.int64,
    )
    mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
    mesh.visual.face_colors = np.tile(color[None], (faces.shape[0], 1))
    return mesh
36
+
37
+
38
+ def export_glb(
39
+ session_dir: str,
40
+ branch: str,
41
+ display_mode: str,
42
+ frame_index: int,
43
+ mask_sky: bool,
44
+ show_cameras: bool,
45
+ camera_scale: float,
46
+ max_points: int,
47
+ ) -> str:
48
+ import trimesh
49
+
50
+ points, colors, _ = collect_points(
51
+ session_dir=session_dir,
52
+ branch=branch,
53
+ display_mode=display_mode,
54
+ frame_index=frame_index,
55
+ mask_sky=mask_sky,
56
+ max_points=max_points,
57
+ seed=13,
58
+ )
59
+ if len(points) == 0:
60
+ raise ValueError("No valid points to export")
61
+
62
+ scene = trimesh.Scene()
63
+ scene.add_geometry(trimesh.PointCloud(vertices=points, colors=colors))
64
+
65
+ if show_cameras:
66
+ _, frustums, _ = camera_geometry(
67
+ session_dir=session_dir,
68
+ display_mode=display_mode,
69
+ frame_index=frame_index,
70
+ camera_scale_ratio=camera_scale,
71
+ points_hint=points,
72
+ )
73
+ for idx, (center, corners) in enumerate(frustums):
74
+ scene.add_geometry(
75
+ _camera_mesh(center, corners, _CAMERA_COLORS[idx % len(_CAMERA_COLORS)])
76
+ )
77
+
78
+ export_dir = os.path.join(session_dir, "exports")
79
+ os.makedirs(export_dir, exist_ok=True)
80
+ branch_slug = branch.lower().replace(" + ", "_").replace(" ", "_")
81
+ mode_slug = display_mode.replace(" ", "_").lower()
82
+ filename = f"{branch_slug}_{mode_slug}_{frame_index:04d}_sky{int(mask_sky)}_cam{int(show_cameras)}.glb"
83
+ path = os.path.join(export_dir, filename)
84
+ scene.export(path)
85
+ return path
longstream/demo/geometry.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional, Tuple
3
+
4
+ import numpy as np
5
+
6
+ from .common import (
7
+ branch_key,
8
+ c2w_in_view_space,
9
+ load_metadata,
10
+ selected_frame_indices,
11
+ session_file,
12
+ world_to_view,
13
+ )
14
+
15
+
16
+ def _origin_shift(w2c_all) -> np.ndarray:
17
+ first = c2w_in_view_space(w2c_all[0])
18
+ return first[:3, 3].copy()
19
+
20
+
21
+ def _sample_flat_indices(
22
+ valid_indices: np.ndarray, budget: Optional[int], rng: np.random.Generator
23
+ ) -> np.ndarray:
24
+ if budget is None or budget <= 0 or valid_indices.size <= budget:
25
+ return valid_indices
26
+ keep = rng.choice(valid_indices.size, size=int(budget), replace=False)
27
+ return valid_indices[keep]
28
+
29
+
30
+ def _depth_points_from_flat(depth, intri, w2c, flat_indices):
31
+ h, w = depth.shape
32
+ ys = flat_indices // w
33
+ xs = flat_indices % w
34
+ z = depth.reshape(-1)[flat_indices].astype(np.float64)
35
+ fx = float(intri[0, 0])
36
+ fy = float(intri[1, 1])
37
+ cx = float(intri[0, 2])
38
+ cy = float(intri[1, 2])
39
+ x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12)
40
+ y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12)
41
+ pts_cam = np.stack([x, y, z], axis=1)
42
+ R = w2c[:3, :3].astype(np.float64)
43
+ t = w2c[:3, 3].astype(np.float64)
44
+ return (R.T @ (pts_cam.T - t[:, None])).T.astype(np.float32, copy=False)
45
+
46
+
47
+ def _camera_points_to_world(points, w2c):
48
+ pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
49
+ R = w2c[:3, :3].astype(np.float64)
50
+ t = w2c[:3, 3].astype(np.float64)
51
+ return (R.T @ (pts.T - t[:, None])).T.astype(np.float32, copy=False)
52
+
53
+
54
+ def collect_points(
55
+ session_dir: str,
56
+ branch: str,
57
+ display_mode: str,
58
+ frame_index: int,
59
+ mask_sky: bool,
60
+ max_points: Optional[int],
61
+ seed: int = 0,
62
+ ):
63
+ branch = branch_key(branch)
64
+ meta = load_metadata(session_dir)
65
+ frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
66
+ if not frame_ids:
67
+ return (
68
+ np.empty((0, 3), dtype=np.float32),
69
+ np.empty((0, 3), dtype=np.uint8),
70
+ np.zeros(3, dtype=np.float64),
71
+ )
72
+
73
+ images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r")
74
+ w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
75
+ origin_shift = _origin_shift(w2c)
76
+ sky = None
77
+ if mask_sky and os.path.exists(session_file(session_dir, "sky_masks.npy")):
78
+ sky = np.load(session_file(session_dir, "sky_masks.npy"), mmap_mode="r")
79
+
80
+ if branch == "point_head":
81
+ point_head = np.load(session_file(session_dir, "point_head.npy"), mmap_mode="r")
82
+ source = point_head
83
+ depth = None
84
+ intri = None
85
+ else:
86
+ source = None
87
+ depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r")
88
+ intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")
89
+
90
+ per_frame_budget = None
91
+ if max_points is not None and max_points > 0:
92
+ per_frame_budget = max(int(max_points) // max(len(frame_ids), 1), 1)
93
+
94
+ rng = np.random.default_rng(seed)
95
+ points = []
96
+ colors = []
97
+ for idx in frame_ids:
98
+ rgb_flat = images[idx].reshape(-1, 3)
99
+ if branch == "point_head":
100
+ pts_map = source[idx]
101
+ valid = np.isfinite(pts_map).all(axis=-1).reshape(-1)
102
+ if sky is not None:
103
+ valid &= sky[idx].reshape(-1) > 0
104
+ flat = np.flatnonzero(valid)
105
+ if flat.size == 0:
106
+ continue
107
+ flat = _sample_flat_indices(flat, per_frame_budget, rng)
108
+ pts_cam = pts_map.reshape(-1, 3)[flat]
109
+ pts_world = _camera_points_to_world(pts_cam, w2c[idx])
110
+ else:
111
+ depth_i = depth[idx]
112
+ valid = (np.isfinite(depth_i) & (depth_i > 0)).reshape(-1)
113
+ if sky is not None:
114
+ valid &= sky[idx].reshape(-1) > 0
115
+ flat = np.flatnonzero(valid)
116
+ if flat.size == 0:
117
+ continue
118
+ flat = _sample_flat_indices(flat, per_frame_budget, rng)
119
+ pts_world = _depth_points_from_flat(depth_i, intri[idx], w2c[idx], flat)
120
+
121
+ pts_view = world_to_view(pts_world) - origin_shift[None]
122
+ points.append(pts_view.astype(np.float32, copy=False))
123
+ colors.append(rgb_flat[flat].astype(np.uint8, copy=False))
124
+
125
+ if not points:
126
+ return (
127
+ np.empty((0, 3), dtype=np.float32),
128
+ np.empty((0, 3), dtype=np.uint8),
129
+ origin_shift,
130
+ )
131
+ return np.concatenate(points, axis=0), np.concatenate(colors, axis=0), origin_shift
132
+
133
+
134
+ def _frustum_corners_camera(intri, image_hw, depth_scale):
135
+ h, w = image_hw
136
+ fx = float(intri[0, 0])
137
+ fy = float(intri[1, 1])
138
+ cx = float(intri[0, 2])
139
+ cy = float(intri[1, 2])
140
+ corners = np.array(
141
+ [
142
+ [
143
+ (0.0 - cx) * depth_scale / max(fx, 1e-12),
144
+ (0.0 - cy) * depth_scale / max(fy, 1e-12),
145
+ depth_scale,
146
+ ],
147
+ [
148
+ ((w - 1.0) - cx) * depth_scale / max(fx, 1e-12),
149
+ (0.0 - cy) * depth_scale / max(fy, 1e-12),
150
+ depth_scale,
151
+ ],
152
+ [
153
+ ((w - 1.0) - cx) * depth_scale / max(fx, 1e-12),
154
+ ((h - 1.0) - cy) * depth_scale / max(fy, 1e-12),
155
+ depth_scale,
156
+ ],
157
+ [
158
+ (0.0 - cx) * depth_scale / max(fx, 1e-12),
159
+ ((h - 1.0) - cy) * depth_scale / max(fy, 1e-12),
160
+ depth_scale,
161
+ ],
162
+ ],
163
+ dtype=np.float64,
164
+ )
165
+ return corners
166
+
167
+
168
+ def camera_geometry(
169
+ session_dir: str,
170
+ display_mode: str,
171
+ frame_index: int,
172
+ camera_scale_ratio: float,
173
+ points_hint=None,
174
+ ):
175
+ meta = load_metadata(session_dir)
176
+ frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
177
+ w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
178
+ intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")
179
+ origin_shift = _origin_shift(w2c)
180
+
181
+ center_points = np.array(
182
+ [c2w_in_view_space(w2c[idx], origin_shift)[:3, 3] for idx in frame_ids],
183
+ dtype=np.float64,
184
+ )
185
+ center_extent = 1.0
186
+ if len(center_points) > 1:
187
+ center_extent = float(
188
+ np.linalg.norm(center_points.max(axis=0) - center_points.min(axis=0))
189
+ )
190
+
191
+ point_extent = 0.0
192
+ if points_hint is not None and len(points_hint) > 0:
193
+ lo = np.percentile(points_hint, 5, axis=0)
194
+ hi = np.percentile(points_hint, 95, axis=0)
195
+ point_extent = float(np.linalg.norm(hi - lo))
196
+
197
+ extent = max(center_extent, point_extent, 1.0)
198
+ depth_scale = extent * float(camera_scale_ratio)
199
+
200
+ centers = []
201
+ frustums = []
202
+ for idx in frame_ids:
203
+ c2w_view = c2w_in_view_space(w2c[idx], origin_shift)
204
+ center = c2w_view[:3, 3]
205
+ corners_cam = _frustum_corners_camera(
206
+ intri[idx], (meta["height"], meta["width"]), depth_scale
207
+ )
208
+ corners_world = (c2w_view[:3, :3] @ corners_cam.T).T + center[None]
209
+ centers.append(center)
210
+ frustums.append((center, corners_world))
211
+ return np.asarray(centers, dtype=np.float64), frustums, origin_shift
longstream/demo/viewer.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+
4
+ from longstream.demo.backend import load_frame_previews
5
+
6
+ from .common import load_metadata
7
+ from .geometry import camera_geometry, collect_points
8
+
9
+
10
+ def _empty_figure(message: str):
11
+ fig = go.Figure()
12
+ fig.add_annotation(
13
+ text=message, x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False
14
+ )
15
+ fig.update_layout(
16
+ template="plotly_white",
17
+ margin=dict(l=0, r=0, t=40, b=0),
18
+ scene=dict(aspectmode="data"),
19
+ )
20
+ return fig
21
+
22
+
23
+ def _camera_lines(frustums):
24
+ xs, ys, zs = [], [], []
25
+ for center, corners in frustums:
26
+ order = [(0, 1), (1, 2), (2, 3), (3, 0)]
27
+ for a, b in order:
28
+ xs.extend([corners[a, 0], corners[b, 0], None])
29
+ ys.extend([corners[a, 1], corners[b, 1], None])
30
+ zs.extend([corners[a, 2], corners[b, 2], None])
31
+ for corner in corners:
32
+ xs.extend([center[0], corner[0], None])
33
+ ys.extend([center[1], corner[1], None])
34
+ zs.extend([center[2], corner[2], None])
35
+ return xs, ys, zs
36
+
37
+
38
+ def build_interactive_figure(
39
+ session_dir: str,
40
+ branch: str,
41
+ display_mode: str,
42
+ frame_index: int,
43
+ point_size: float,
44
+ opacity: float,
45
+ preview_max_points: int,
46
+ show_cameras: bool,
47
+ camera_scale: float,
48
+ mask_sky: bool,
49
+ ):
50
+ meta = load_metadata(session_dir)
51
+ points, colors, _ = collect_points(
52
+ session_dir=session_dir,
53
+ branch=branch,
54
+ display_mode=display_mode,
55
+ frame_index=frame_index,
56
+ mask_sky=mask_sky,
57
+ max_points=preview_max_points,
58
+ seed=frame_index,
59
+ )
60
+ if len(points) == 0:
61
+ return _empty_figure("No valid points for the current selection")
62
+
63
+ fig = go.Figure()
64
+ fig.add_trace(
65
+ go.Scatter3d(
66
+ x=points[:, 0],
67
+ y=points[:, 1],
68
+ z=points[:, 2],
69
+ mode="markers",
70
+ marker=dict(
71
+ size=float(point_size),
72
+ color=[f"rgb({r},{g},{b})" for r, g, b in colors],
73
+ opacity=float(opacity),
74
+ ),
75
+ hoverinfo="skip",
76
+ name="points",
77
+ )
78
+ )
79
+
80
+ if show_cameras:
81
+ centers, frustums, _ = camera_geometry(
82
+ session_dir=session_dir,
83
+ display_mode=display_mode,
84
+ frame_index=frame_index,
85
+ camera_scale_ratio=camera_scale,
86
+ points_hint=points,
87
+ )
88
+ if len(centers) > 0:
89
+ fig.add_trace(
90
+ go.Scatter3d(
91
+ x=centers[:, 0],
92
+ y=centers[:, 1],
93
+ z=centers[:, 2],
94
+ mode="lines",
95
+ line=dict(color="#16a34a", width=2),
96
+ name="trajectory",
97
+ hoverinfo="skip",
98
+ )
99
+ )
100
+ xs, ys, zs = _camera_lines(frustums)
101
+ fig.add_trace(
102
+ go.Scatter3d(
103
+ x=xs,
104
+ y=ys,
105
+ z=zs,
106
+ mode="lines",
107
+ line=dict(color="#22c55e", width=1.5),
108
+ name="cameras",
109
+ hoverinfo="skip",
110
+ )
111
+ )
112
+
113
+ fig.update_layout(
114
+ template="plotly_white",
115
+ margin=dict(l=0, r=0, t=40, b=0),
116
+ scene=dict(
117
+ aspectmode="data",
118
+ xaxis_title="x_right",
119
+ yaxis_title="z_forward",
120
+ zaxis_title="y_up",
121
+ bgcolor="#f8fafc",
122
+ camera=dict(
123
+ up=dict(x=0.0, y=0.0, z=1.0),
124
+ eye=dict(x=-1.0, y=-1.8, z=0.9),
125
+ ),
126
+ ),
127
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0.0),
128
+ )
129
+ return fig
130
+
131
+
132
+ def build_frame_outputs(session_dir: str, frame_index: int):
133
+ rgb, depth, label = load_frame_previews(session_dir, frame_index)
134
+ return rgb, depth, label
longstream/eval/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .evaluate import evaluate_predictions_cfg
2
+
3
+ __all__ = ["evaluate_predictions_cfg"]
longstream/eval/evaluate.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import matplotlib
7
+
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pyplot as plt
10
+
11
+ from longstream.data import LongStreamDataLoader
12
+ from longstream.eval.io import (
13
+ frame_stems,
14
+ read_depth,
15
+ read_opencv_camera_yml,
16
+ read_pointcloud_xyz,
17
+ read_pred_w2c_txt,
18
+ )
19
+ from longstream.eval.metrics import ate_rmse, chamfer_and_f1, transform_points
20
+ from longstream.utils.sky_mask import sky_mask_filename
21
+
22
+
23
+ def _ensure_dir(path):
24
+ os.makedirs(path, exist_ok=True)
25
+
26
+
27
+ def _sequence_output_dir(output_root, seq_name):
28
+ return os.path.join(output_root, seq_name)
29
+
30
+
31
+ def _sequence_metrics_path(output_root, seq_name):
32
+ return os.path.join(output_root, "metrics", f"{seq_name}.json")
33
+
34
+
35
+ def _sequence_plot_path(output_root, seq_name):
36
+ return os.path.join(output_root, "plots", f"{seq_name}_traj_3d.png")
37
+
38
+
39
+ def _world_xyz_to_plot_xyz(xyz):
40
+ xyz = np.asarray(xyz, dtype=np.float64)
41
+ return np.stack([xyz[:, 0], xyz[:, 2], -xyz[:, 1]], axis=-1)
42
+
43
+
44
+ def _set_equal_3d_axes(ax, xyz):
45
+ mins = xyz.min(axis=0)
46
+ maxs = xyz.max(axis=0)
47
+ center = 0.5 * (mins + maxs)
48
+ radius = 0.5 * np.max(np.maximum(maxs - mins, 1e-6))
49
+ ax.set_xlim(center[0] - radius, center[0] + radius)
50
+ ax.set_ylim(center[1] - radius, center[1] + radius)
51
+ ax.set_zlim(center[2] - radius, center[2] + radius)
52
+
53
+
54
+ def _load_gt_pose_data(seq_info):
55
+ if seq_info.camera is not None:
56
+ cam_dir = os.path.join(seq_info.scene_root, "cameras", seq_info.camera)
57
+ extri_path = os.path.join(cam_dir, "extri.yml")
58
+ intri_path = os.path.join(cam_dir, "intri.yml")
59
+ if os.path.exists(extri_path):
60
+ extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path)
61
+ return extri, intri, image_sizes
62
+
63
+ extri_path = os.path.join(seq_info.scene_root, "extri.yml")
64
+ intri_path = os.path.join(seq_info.scene_root, "intri.yml")
65
+ if not os.path.exists(extri_path):
66
+ return None, None, None
67
+ extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path)
68
+ return extri, intri, image_sizes
69
+
70
+
71
+ def _resolve_gt_depth_root(seq_info):
72
+ if seq_info.camera is not None:
73
+ camera_depth_root = os.path.join(seq_info.scene_root, "depths", seq_info.camera)
74
+ if os.path.isdir(camera_depth_root):
75
+ return camera_depth_root
76
+ depth_root = os.path.join(seq_info.scene_root, "depths")
77
+ if os.path.isdir(depth_root):
78
+ return depth_root
79
+ return None
80
+
81
+
82
def _resolve_gt_depth_path(seq_info, depth_root, image_path, stem):
    """Locate the ground-truth depth ``.exr`` matching *image_path*.

    Probes several on-disk layouts under *depth_root*: a flat ``<stem>.exr``,
    the image's relative path mirrored under the depth root, and a per-stem
    subdirectory keyed by the image's file stem.

    Returns:
        The first existing candidate path, or None when none exist.
    """
    rel_path = os.path.relpath(image_path, seq_info.image_dir)
    rel_stem = os.path.splitext(rel_path)[0]
    file_stem = os.path.splitext(os.path.basename(image_path))[0]
    candidates = [
        os.path.join(depth_root, f"{stem}.exr"),
        os.path.join(depth_root, rel_stem + ".exr"),
        os.path.join(depth_root, stem, f"{file_stem}.exr"),
    ]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None
95
+
96
+
97
def _resize_long_edge(arr, long_edge_size, interpolation):
    """Resize *arr* so its longer side equals *long_edge_size*, keeping aspect."""
    height, width = arr.shape[:2]
    ratio = float(long_edge_size) / float(max(height, width))
    # cv2.resize expects (width, height) ordering.
    out_size = (int(round(width * ratio)), int(round(height * ratio)))
    return cv2.resize(arr, out_size, interpolation=interpolation)
103
+
104
+
105
def _prepare_map_for_eval(
    arr, size, crop, patch_size, target_shape, interpolation, square_ok=False
):
    """Resize/crop a GT map so it spatially matches a prediction.

    Applies the same long-edge resize and center crop/resize geometry that
    produced the prediction (presumably mirroring the dataloader's image
    preprocessing — TODO confirm against the loader), then force-resizes to
    *target_shape* (H, W) if anything is still off. Returns the adjusted map.
    """
    h0, w0 = arr.shape[:2]
    # In 224 mode the SHORT edge is brought to `size` (long edge scaled up by
    # the aspect ratio); otherwise the long edge is set to `size` directly.
    long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size
    arr = _resize_long_edge(arr, long_edge, interpolation)

    h, w = arr.shape[:2]
    cx, cy = w // 2, h // 2

    if size == 224:
        # Square center region of side 2*min(cx, cy).
        half = min(cx, cy)
        target_w = 2 * half
        target_h = 2 * half
        if crop:
            arr = arr[cy - half : cy + half, cx - half : cx + half]
        else:
            arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation)
    else:
        # Largest patch-aligned extents around the center.
        halfw = ((2 * cx) // patch_size) * (patch_size // 2)
        halfh = ((2 * cy) // patch_size) * (patch_size // 2)
        if not square_ok and w == h:
            # Square inputs are squeezed to 4:3 unless explicitly allowed.
            halfh = int(3 * halfw / 4)
        target_w = 2 * halfw
        target_h = 2 * halfh
        if crop:
            arr = arr[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
        else:
            arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation)

    # Final safety net: guarantee exact agreement with the prediction shape.
    if arr.shape[:2] != tuple(target_shape):
        arr = cv2.resize(
            arr, (target_shape[1], target_shape[0]), interpolation=interpolation
        )
    return arr
140
+
141
+
142
def _sky_mask_path(seq_dir, image_path):
    """Path of the cached sky mask for *image_path* under *seq_dir*."""
    mask_name = sky_mask_filename(image_path)
    return os.path.join(seq_dir, "sky_masks", mask_name)
144
+
145
+
146
+ def _sample_frame_points(points, max_points, rng):
147
+ if max_points is None or len(points) <= max_points:
148
+ return points
149
+ keep = rng.choice(len(points), size=max_points, replace=False)
150
+ return points[keep]
151
+
152
+
153
+ def _depth_to_world_points(depth, intri, extri, valid_mask):
154
+ ys, xs = np.nonzero(valid_mask)
155
+ if ys.size == 0:
156
+ return np.empty((0, 3), dtype=np.float32)
157
+
158
+ z = depth[ys, xs].astype(np.float64)
159
+ fx = float(intri[0, 0])
160
+ fy = float(intri[1, 1])
161
+ cx = float(intri[0, 2])
162
+ cy = float(intri[1, 2])
163
+
164
+ x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12)
165
+ y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12)
166
+ pts_cam = np.stack([x, y, z], axis=1)
167
+
168
+ R = extri[:3, :3]
169
+ t = extri[:3, 3]
170
+ pts_world = (R.T @ (pts_cam.T - t[:, None])).T
171
+ return pts_world.astype(np.float32, copy=False)
172
+
173
+
174
def _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg):
    """Build a GT world-space point cloud by unprojecting GT depth frames.

    Frames without depth/calibration are skipped; sky pixels are removed when
    a cached sky mask exists. Each frame contributes at most a per-frame
    budget of points (overall cap times an oversample factor, split evenly
    across frames) so the later evaluation-time subsampling has material to
    draw from. Returns an (N, 3) float32 array, or None when no GT is usable.
    """
    if not gt_extri or not gt_intri:
        return None

    gt_dir = _resolve_gt_depth_root(seq_info)
    if gt_dir is None:
        return None

    eval_max_points = int(eval_cfg.get("point_eval_max_points", 100000))
    oversample_factor = int(eval_cfg.get("point_eval_oversample_factor", 4))
    per_frame_budget = max(
        (eval_max_points * oversample_factor) // max(len(seq_info.image_paths), 1), 1
    )
    # Fixed seed keeps the GT cloud reproducible across runs.
    rng = np.random.default_rng(0)
    chunks = []

    for image_path, stem in zip(
        seq_info.image_paths, frame_stems(seq_info.image_paths)
    ):
        depth_path = _resolve_gt_depth_path(seq_info, gt_dir, image_path, stem)
        if depth_path is None or stem not in gt_extri or stem not in gt_intri:
            continue

        depth = read_depth(depth_path)
        valid = np.isfinite(depth) & (depth > 0)
        if not np.any(valid):
            continue

        sky_path = _sky_mask_path(seq_dir, image_path)
        if os.path.exists(sky_path):
            sky_mask = cv2.imread(sky_path, cv2.IMREAD_GRAYSCALE)
            if sky_mask is not None:
                if sky_mask.shape[:2] != depth.shape[:2]:
                    sky_mask = cv2.resize(
                        sky_mask,
                        (depth.shape[1], depth.shape[0]),
                        interpolation=cv2.INTER_NEAREST,
                    )
                # NOTE(review): assumes mask > 0 marks pixels to KEEP
                # (i.e. non-sky) — confirm against sky mask generation.
                valid &= sky_mask > 0
        if not np.any(valid):
            continue

        pts_world = _depth_to_world_points(depth, gt_intri[stem], gt_extri[stem], valid)
        if len(pts_world) == 0:
            continue
        chunks.append(_sample_frame_points(pts_world, per_frame_budget, rng))

    if not chunks:
        return None
    return np.concatenate(chunks, axis=0)
224
+
225
+
226
def _evaluate_pointclouds(seq_info, seq_dir, eval_cfg, pose_align, gt_cloud):
    """Score exported point clouds against the GT cloud (chamfer + F1).

    Requires the trajectory's Sim(3) alignment ``(scale, R, t)`` so predicted
    points can be mapped into the GT world frame. The "point_head" and
    "dpt_unproj" branches are evaluated independently; returns a dict of
    per-branch metric dicts, or None when nothing could be evaluated.
    """
    if pose_align is None or gt_cloud is None:
        return None

    scale, R, t = pose_align
    # Each branch may have been exported in any of these formats.
    point_paths = {
        "point_head": [
            os.path.join(seq_dir, "points", "point_head_full.npy"),
            os.path.join(seq_dir, "points", "point_head_full.npz"),
            os.path.join(seq_dir, "points", "point_head_full.ply"),
        ],
        "dpt_unproj": [
            os.path.join(seq_dir, "points", "dpt_unproj_full.npy"),
            os.path.join(seq_dir, "points", "dpt_unproj_full.npz"),
            os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
        ],
    }
    threshold = float(eval_cfg.get("point_f1_threshold", 0.25))
    max_points = int(eval_cfg.get("point_eval_max_points", 100000))
    voxel_size = eval_cfg.get("point_eval_voxel_size", None)
    # Treat None / "" / "null" (YAML leftovers) as "no voxel downsampling".
    voxel_size = None if voxel_size in (None, "", "null") else float(voxel_size)

    metrics_by_branch = {}
    for branch, candidates in point_paths.items():
        path = next(
            (candidate for candidate in candidates if os.path.exists(candidate)), None
        )
        if path is None:
            continue
        pred_cloud = read_pointcloud_xyz(path)
        # Map predictions into the GT frame before measuring distances.
        pred_cloud = transform_points(pred_cloud, scale, R, t)
        metrics = chamfer_and_f1(
            pred_cloud,
            gt_cloud,
            threshold=threshold,
            max_points=max_points,
            voxel_size=voxel_size,
            # Distinct seeds so the two branches subsample independently.
            seed=0 if branch == "point_head" else 1,
        )
        if metrics is not None:
            metrics_by_branch[branch] = metrics
    return metrics_by_branch or None
268
+
269
+
270
def _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg):
    """Compute AbsRel and delta-threshold depth metrics for the DPT branch.

    Predicted per-frame depths (``depth/dpt/frame_XXXXXX.npy``) are compared
    against GT depths after the GT is pushed through the same resize/crop
    geometry as the network input. Sky pixels (when a sky mask exists) and
    non-finite / non-positive values are excluded. Returns pixel-weighted
    aggregate metrics, or None when no valid pixel was found.
    """
    pred_dir = os.path.join(seq_dir, "depth", "dpt")
    gt_dir = _resolve_gt_depth_root(seq_info)
    if not os.path.isdir(pred_dir) or gt_dir is None:
        return None

    size = int(data_cfg.get("size", 518))
    crop = bool(data_cfg.get("crop", False))
    patch_size = int(data_cfg.get("patch_size", 14))
    rel_delta_threshold = float(eval_cfg.get("depth_rel_delta_threshold", 1.25))

    # Accumulated over all valid pixels of all evaluated frames.
    abs_rel_sum = 0.0
    rel_delta_hits = 0
    valid_pixels = 0
    evaluated_frames = 0

    stems = frame_stems(seq_info.image_paths)
    for frame_id, stem in enumerate(stems):
        pred_path = os.path.join(pred_dir, f"frame_{frame_id:06d}.npy")
        gt_path = _resolve_gt_depth_path(
            seq_info, gt_dir, seq_info.image_paths[frame_id], stem
        )
        if not os.path.exists(pred_path) or gt_path is None:
            continue

        pred = np.load(pred_path).astype(np.float32)
        gt = read_depth(gt_path)
        # Align GT geometry with the prediction's resolution/cropping.
        gt = _prepare_map_for_eval(
            gt,
            size=size,
            crop=crop,
            patch_size=patch_size,
            target_shape=pred.shape,
            interpolation=cv2.INTER_NEAREST,
        )

        valid = np.isfinite(gt) & (gt > 0)
        if not np.any(valid):
            continue

        sky_mask_path = _sky_mask_path(seq_dir, seq_info.image_paths[frame_id])
        if os.path.exists(sky_mask_path):
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_GRAYSCALE)
            if sky_mask is not None:
                sky_mask = _prepare_map_for_eval(
                    sky_mask,
                    size=size,
                    crop=crop,
                    patch_size=patch_size,
                    target_shape=pred.shape,
                    interpolation=cv2.INTER_NEAREST,
                )
                # NOTE(review): assumes mask > 0 marks pixels to KEEP
                # (non-sky) — confirm against sky mask generation.
                valid &= sky_mask > 0

        valid &= np.isfinite(pred)
        if not np.any(valid):
            continue

        pred_valid = pred[valid].astype(np.float64)
        gt_valid = gt[valid].astype(np.float64)
        # Clamp denominators so ratios stay finite near zero depth.
        pred_safe = np.clip(pred_valid, 1e-6, None)
        gt_safe = np.clip(gt_valid, 1e-6, None)

        abs_rel_sum += np.sum(np.abs(pred_valid - gt_valid) / gt_safe)
        # Standard delta metric: max(gt/pred, pred/gt) < threshold.
        rel_ratio = np.maximum(gt_safe / pred_safe, pred_safe / gt_safe)
        rel_delta_hits += int(np.sum(rel_ratio < rel_delta_threshold))
        valid_pixels += int(gt_valid.size)
        evaluated_frames += 1

    if valid_pixels == 0:
        return None

    return {
        "abs_rel": float(abs_rel_sum / valid_pixels),
        "rel_delta": float(rel_delta_hits / valid_pixels),
        "rel_delta_threshold": rel_delta_threshold,
        "num_valid_pixels": int(valid_pixels),
        "num_frames": int(evaluated_frames),
    }
349
+
350
+
351
def _extract_pose_pairs(seq_info, pred_pose_path, gt_extri):
    """Collect matched predicted/GT camera centers for trajectory evaluation.

    Reads predicted w2c poses from *pred_pose_path*, pairs each frame with
    its GT extrinsic by frame stem, and returns ``(pred_xyz, gt_xyz)`` as
    (N, 3) float64 arrays of camera positions. Returns None when fewer than
    3 pairs match (too few for a stable similarity alignment).
    """
    frame_ids, pred_w2c = read_pred_w2c_txt(pred_pose_path)
    if not pred_w2c:
        return None

    stems = frame_stems(seq_info.image_paths)
    pred_xyz = []
    gt_xyz = []

    for frame_id, pred_mat in zip(frame_ids, pred_w2c):
        if frame_id < 0 or frame_id >= len(stems):
            continue
        stem = stems[frame_id]
        if stem not in gt_extri:
            continue
        # The camera center is the translation of the camera-to-world matrix.
        pred_c2w = np.linalg.inv(pred_mat)
        gt_c2w = np.linalg.inv(gt_extri[stem])
        pred_xyz.append(pred_c2w[:3, 3])
        gt_xyz.append(gt_c2w[:3, 3])

    if len(pred_xyz) < 3:
        return None
    return np.asarray(pred_xyz, dtype=np.float64), np.asarray(gt_xyz, dtype=np.float64)
374
+
375
+
376
def _save_traj_plot_3d(path, pred_xyz, gt_xyz):
    """Render GT vs. predicted camera trajectories as a 3D PNG at *path*.

    Both trajectories are converted to plot axes, shifted so the first GT
    position is the origin, and drawn with equal axis scale so the shapes
    are not distorted. *pred_xyz* is expected to be already Sim(3)-aligned.
    """
    _ensure_dir(os.path.dirname(path))
    pred_plot = _world_xyz_to_plot_xyz(pred_xyz)
    gt_plot = _world_xyz_to_plot_xyz(gt_xyz)
    # Re-center on the first GT pose for a readable origin.
    origin = gt_plot[:1]
    pred_plot = pred_plot - origin
    gt_plot = gt_plot - origin
    all_plot = np.concatenate([pred_plot, gt_plot], axis=0)

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(
        gt_plot[:, 0],
        gt_plot[:, 1],
        gt_plot[:, 2],
        label="gt",
        linewidth=2.0,
        color="#1f77b4",
    )
    ax.plot(
        pred_plot[:, 0],
        pred_plot[:, 1],
        pred_plot[:, 2],
        label="pred",
        linewidth=2.0,
        color="#d62728",
    )
    _set_equal_3d_axes(ax, all_plot)
    ax.view_init(elev=24, azim=-118)
    # Axis labels match _world_xyz_to_plot_xyz's remapping of world axes.
    ax.set_xlabel("x_right")
    ax.set_ylabel("z_forward")
    ax.set_zlabel("y_up")
    ax.legend(loc="best")
    ax.set_title("Trajectory 3D (Sim3-aligned view)")
    fig.tight_layout()
    fig.savefig(path, dpi=180)
    plt.close(fig)
413
+
414
+
415
def evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg):
    """Evaluate one sequence's saved predictions against whatever GT exists.

    Runs, when the corresponding GT is available: Sim(3)-aligned trajectory
    metrics plus a 3D plot, video DPT depth metrics, and point-cloud
    chamfer/F1 for both exported branches. Returns a JSON-serializable dict;
    the ``has_gt*`` flags record which GT was found and ``skipped`` is set
    when none was.
    """
    seq_dir = _sequence_output_dir(output_root, seq_info.name)
    result = {
        "sequence": seq_info.name,
        "output_dir": seq_dir,
        "has_gt": False,
        "has_gt_pose": False,
        "has_gt_depth": False,
    }

    gt_extri, gt_intri, _ = _load_gt_pose_data(seq_info)
    pose_align = None
    if gt_extri:
        result["has_gt"] = True
        result["has_gt_pose"] = True

        pred_pose_path = os.path.join(seq_dir, "poses", "abs_pose.txt")
        pairs = _extract_pose_pairs(seq_info, pred_pose_path, gt_extri)
        if pairs is not None:
            pred_xyz, gt_xyz = pairs
            pose_metrics = ate_rmse(
                pred_xyz, gt_xyz, align_scale=bool(eval_cfg.get("align_scale", True))
            )
            sim3_scale = float(pose_metrics.get("sim3_scale", 1.0))
            pred_xyz_aligned = transform_points(
                pred_xyz,
                sim3_scale,
                np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64),
                np.asarray(pose_metrics["sim3_translation"], dtype=np.float64),
            )
            # Keep the alignment so point-cloud evaluation can reuse it.
            pose_align = (
                sim3_scale,
                np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64),
                np.asarray(pose_metrics["sim3_translation"], dtype=np.float64),
            )
            plot_path = _sequence_plot_path(output_root, seq_info.name)
            _save_traj_plot_3d(plot_path, pred_xyz_aligned, gt_xyz)
            # The raw scale was only needed internally; don't serialize it.
            pose_metrics.pop("sim3_scale", None)
            pose_metrics["traj_3d_plot"] = plot_path
            result["pose"] = pose_metrics

    video_dpt_metrics = _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg)
    if video_dpt_metrics is not None:
        result["has_gt"] = True
        result["has_gt_depth"] = True
        result["video_dpt"] = video_dpt_metrics

    gt_cloud = _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg)
    pointcloud_metrics = _evaluate_pointclouds(
        seq_info, seq_dir, eval_cfg, pose_align, gt_cloud
    )
    if pointcloud_metrics is not None:
        result["has_gt"] = True
        result["has_gt_depth"] = True
        result["pointcloud"] = pointcloud_metrics

    if not result["has_gt"]:
        result["skipped"] = "missing_gt"

    return result
475
+
476
+
477
+ def _mean_metric(sequence_results, group_name, metric_name):
478
+ values = []
479
+ for item in sequence_results:
480
+ group = item
481
+ for key in group_name.split("."):
482
+ if not isinstance(group, dict):
483
+ group = None
484
+ break
485
+ group = group.get(key)
486
+ if not isinstance(group, dict):
487
+ continue
488
+ if metric_name in group:
489
+ values.append(float(group[metric_name]))
490
+ if not values:
491
+ return None
492
+ return float(np.mean(values))
493
+
494
+
495
def evaluate_predictions_cfg(cfg):
    """Evaluate every sequence described by *cfg* and write JSON metrics.

    Iterates the dataloader's sequences, evaluates each one, writes a
    per-sequence metrics file plus an overall ``summary.json`` (with mean
    pose / depth / point-cloud metrics) under the configured output root,
    and returns the summary dict.
    """
    data_cfg = dict(cfg.get("data", {}))
    # Evaluation always reads sequences in the "generalizable" layout.
    data_cfg["format"] = "generalizable"
    output_cfg = cfg.get("output", {})
    eval_cfg = cfg.get("evaluation", {})
    output_root = output_cfg.get("root", "outputs")
    _ensure_dir(output_root)

    loader = LongStreamDataLoader(data_cfg)
    sequence_results = []
    for seq_info in loader.iter_sequence_infos():
        print(f"[longstream] eval {seq_info.name}: start", flush=True)
        metrics = evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg)
        sequence_results.append(metrics)
        metrics_path = _sequence_metrics_path(output_root, seq_info.name)
        _ensure_dir(os.path.dirname(metrics_path))
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)
        print(f"[longstream] eval {seq_info.name}: wrote {metrics_path}", flush=True)

    summary = {
        "num_sequences": len(sequence_results),
        "num_sequences_with_gt": sum(1 for x in sequence_results if x.get("has_gt")),
        "num_sequences_with_pose_gt": sum(
            1 for x in sequence_results if x.get("has_gt_pose")
        ),
        "num_sequences_with_depth_gt": sum(
            1 for x in sequence_results if x.get("has_gt_depth")
        ),
        # Means are None when no sequence produced the metric.
        "ate_mean": _mean_metric(sequence_results, "pose", "ate_mean"),
        "ate_rmse_mean": _mean_metric(sequence_results, "pose", "ate_rmse"),
        "video_dpt_abs_rel_mean": _mean_metric(
            sequence_results, "video_dpt", "abs_rel"
        ),
        "video_dpt_rel_delta_mean": _mean_metric(
            sequence_results, "video_dpt", "rel_delta"
        ),
        "point_head_cd_mean": _mean_metric(
            sequence_results, "pointcloud.point_head", "cd"
        ),
        "point_head_f1_mean": _mean_metric(
            sequence_results, "pointcloud.point_head", "f1"
        ),
        "dpt_unproj_cd_mean": _mean_metric(
            sequence_results, "pointcloud.dpt_unproj", "cd"
        ),
        "dpt_unproj_f1_mean": _mean_metric(
            sequence_results, "pointcloud.dpt_unproj", "f1"
        ),
        "sequences": sequence_results,
    }

    summary_path = os.path.join(output_root, "summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[longstream] eval: wrote {summary_path}", flush=True)
    return summary
longstream/eval/io.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
def frame_stems(image_paths):
    """Unique per-frame keys: file stems, else parent dir names, else stems."""
    stems = [os.path.splitext(os.path.basename(p))[0] for p in image_paths]
    if len(stems) == len(set(stems)):
        return stems
    # Stems collide (e.g. every frame named img.png); try parent folders.
    parents = [os.path.basename(os.path.dirname(p)) for p in image_paths]
    return parents if len(parents) == len(set(parents)) else stems
17
+
18
+
19
def read_pred_w2c_txt(path):
    """Read predicted world-to-camera poses from a text file.

    Each data line holds: frame id, 9 row-major rotation entries, 3
    translation entries. Returns (frame_ids, 4x4 float64 matrices); both
    lists are empty when the file is missing. Blank lines, ``#`` comments,
    and rows of the wrong length are skipped.
    """
    frames, poses = [], []
    if not os.path.exists(path):
        return frames, poses
    with open(path, "r") as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw or raw.startswith("#"):
                continue
            values = [float(tok) for tok in raw.split()]
            if len(values) != 13:
                continue
            pose = np.eye(4, dtype=np.float64)
            pose[:3, :3] = np.reshape(values[1:10], (3, 3))
            pose[:3, 3] = values[10:13]
            frames.append(int(values[0]))
            poses.append(pose)
    return frames, poses
39
+
40
+
41
def read_opencv_camera_yml(extri_path, intri_path=None):
    """Parse OpenCV FileStorage YAMLs holding per-camera calibration.

    The extrinsics file must contain a ``names`` string list plus
    ``Rot_<name>`` / ``T_<name>`` matrices; the optional intrinsics file
    provides ``K_<name>`` and, optionally, ``H_<name>`` / ``W_<name>`` image
    sizes. Returns three dicts keyed by name: {name: 4x4 w2c matrix},
    {name: 3x3 K}, {name: (H, W)}. A missing extrinsics file yields three
    empty dicts.
    """
    if not os.path.exists(extri_path):
        return {}, {}, {}

    fs_extri = cv2.FileStorage(extri_path, cv2.FILE_STORAGE_READ)
    names_node = fs_extri.getNode("names")
    names = []
    for i in range(names_node.size()):
        names.append(names_node.at(i).string())

    extri = {}
    for name in names:
        rot = fs_extri.getNode(f"Rot_{name}").mat()
        t = fs_extri.getNode(f"T_{name}").mat()
        if rot is None or t is None:
            # Skip entries with incomplete extrinsics rather than failing.
            continue
        mat = np.eye(4, dtype=np.float64)
        mat[:3, :3] = np.asarray(rot, dtype=np.float64)
        mat[:3, 3] = np.asarray(t, dtype=np.float64).reshape(3)
        extri[name] = mat
    fs_extri.release()

    intri = {}
    image_sizes = {}
    if intri_path is not None and os.path.exists(intri_path):
        fs_intri = cv2.FileStorage(intri_path, cv2.FILE_STORAGE_READ)
        for name in names:
            K = fs_intri.getNode(f"K_{name}").mat()
            if K is None:
                continue
            intri[name] = np.asarray(K, dtype=np.float64)
            h_node = fs_intri.getNode(f"H_{name}")
            w_node = fs_intri.getNode(f"W_{name}")
            if not h_node.empty() and not w_node.empty():
                image_sizes[name] = (int(h_node.real()), int(w_node.real()))
        fs_intri.release()

    return extri, intri, image_sizes
79
+
80
+
81
def read_depth(path):
    """Load a depth map (any bit depth) as float32; raise if unreadable."""
    raw = cv2.imread(path, cv2.IMREAD_ANYDEPTH)
    if raw is None:
        raise FileNotFoundError(path)
    return raw.astype(np.float32)
86
+
87
+
88
def read_ply_xyz(path):
    """Read the x/y/z columns of a binary little-endian PLY file.

    Only float32 and uint8 vertex properties are supported. Returns an
    (N, 3) float32 array; raises FileNotFoundError / ValueError on bad
    input.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    type_codes = {"float": "<f4", "float32": "<f4", "uchar": "u1", "uint8": "u1"}

    with open(path, "rb") as handle:
        # Consume the ASCII header line by line up to "end_header".
        header_lines = []
        while True:
            raw = handle.readline()
            if not raw:
                raise ValueError(f"Invalid PLY header: {path}")
            decoded = raw.decode("ascii").strip()
            header_lines.append(decoded)
            if decoded == "end_header":
                break

        if "format binary_little_endian 1.0" not in header_lines:
            raise ValueError(f"Unsupported PLY format: {path}")

        num_vertices = None
        fields = []
        scanning_vertices = False
        for decoded in header_lines:
            if decoded.startswith("element vertex "):
                num_vertices = int(decoded.split()[-1])
                scanning_vertices = True
                continue
            if decoded.startswith("element ") and not decoded.startswith(
                "element vertex "
            ):
                # Properties of other elements do not belong to vertices.
                scanning_vertices = False
            if scanning_vertices and decoded.startswith("property "):
                _, type_name, field_name = decoded.split()
                if type_name not in type_codes:
                    raise ValueError(
                        f"Unsupported PLY property type {type_name} in {path}"
                    )
                fields.append((field_name, type_codes[type_name]))

        if num_vertices is None:
            raise ValueError(f"Missing vertex count in PLY: {path}")

        # The file handle now points at the binary payload.
        records = np.fromfile(handle, dtype=np.dtype(fields), count=num_vertices)

    xyz = np.stack([records["x"], records["y"], records["z"]], axis=1)
    return xyz.astype(np.float32, copy=False)
141
+
142
+
143
def read_pointcloud_xyz(path):
    """Load an (N, 3) float32 point array from .npy, .npz, or binary PLY."""
    suffix = os.path.splitext(path)[1].lower()
    if suffix == ".npy":
        return np.asarray(np.load(path), dtype=np.float32).reshape(-1, 3)
    if suffix == ".npz":
        archive = np.load(path)
        # Prefer the conventional "points" key, else take the first array.
        key = "points" if "points" in archive else next(iter(archive.files))
        return np.asarray(archive[key], dtype=np.float32).reshape(-1, 3)
    return read_ply_xyz(path)
longstream/eval/metrics.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial import cKDTree
3
+
4
+
5
def similarity_align(src, dst, with_scale=True):
    """Closed-form (Umeyama) similarity transform mapping *src* onto *dst*.

    Returns ``(scale, R, t)`` such that ``dst ~= scale * R @ src + t``.
    Inputs must be matching Nx3 arrays; fewer than 3 points yields the
    identity transform.
    """
    src = np.asarray(src, dtype=np.float64)
    dst = np.asarray(dst, dtype=np.float64)
    if src.shape != dst.shape or src.ndim != 2 or src.shape[1] != 3:
        raise ValueError("Expected Nx3 source and target point sets")
    if len(src) < 3:
        return 1.0, np.eye(3), np.zeros(3)

    mu_src = src.mean(axis=0)
    mu_dst = dst.mean(axis=0)
    src0 = src - mu_src
    dst0 = dst - mu_dst

    cov = (dst0.T @ src0) / len(src)
    U, singular, Vt = np.linalg.svd(cov)
    # Reflection guard: force a proper rotation (det(R) = +1).
    sign_fix = np.eye(3)
    if np.linalg.det(U @ Vt) < 0:
        sign_fix[-1, -1] = -1.0
    R = U @ sign_fix @ Vt

    if with_scale:
        src_var = np.mean(np.sum(src0 ** 2, axis=1))
        scale = float(np.trace(np.diag(singular) @ sign_fix) / max(src_var, 1e-12))
    else:
        scale = 1.0
    t = mu_dst - scale * (R @ mu_src)
    return scale, R, t
32
+
33
+
34
def transform_points(points, scale, R, t):
    """Apply a similarity transform row-wise: ``scale * R @ p + t``."""
    rotated = points @ R.T
    return scale * rotated + t[None]
36
+
37
+
38
def ate_rmse(pred_xyz, gt_xyz, align_scale=True):
    """ATE statistics after similarity alignment of predicted onto GT positions.

    Aligns *pred_xyz* to *gt_xyz* with a Sim(3) (or SE(3) when
    ``align_scale`` is False) transform and reports per-point residual
    statistics plus the alignment parameters.
    """
    scale, R, t = similarity_align(pred_xyz, gt_xyz, with_scale=align_scale)
    aligned = transform_points(pred_xyz, scale, R, t)
    residuals = np.linalg.norm(aligned - gt_xyz, axis=1)
    metrics = {
        "ate_rmse": float(np.sqrt(np.mean(residuals ** 2))),
        "ate_mean": float(np.mean(residuals)),
        "ate_median": float(np.median(residuals)),
        "num_pose_pairs": int(len(residuals)),
        "align_scale": bool(align_scale),
        "sim3_scale": float(scale),
        "sim3_rotation": R.tolist(),
        "sim3_translation": t.tolist(),
    }
    return metrics
52
+
53
+
54
+ def _voxel_downsample(points, voxel_size):
55
+ if voxel_size is None:
56
+ return points
57
+ voxel_size = float(voxel_size)
58
+ if voxel_size <= 0 or len(points) == 0:
59
+ return points
60
+ coords = np.floor(points / voxel_size).astype(np.int64)
61
+ _, keep = np.unique(coords, axis=0, return_index=True)
62
+ keep.sort()
63
+ return points[keep]
64
+
65
+
66
+ def _sample_points(points, max_points, seed):
67
+ if max_points is None or len(points) <= int(max_points):
68
+ return points
69
+ rng = np.random.default_rng(seed)
70
+ keep = rng.choice(len(points), size=int(max_points), replace=False)
71
+ return points[keep]
72
+
73
+
74
def prepare_pointcloud(points, max_points=None, voxel_size=None, seed=0):
    """Flatten to Nx3 float64, drop non-finite rows, then voxel + random thin."""
    pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    if len(pts) == 0:
        return pts
    pts = pts[np.isfinite(pts).all(axis=1)]
    pts = _voxel_downsample(pts, voxel_size)
    return _sample_points(pts, max_points, seed)
83
+
84
+
85
def chamfer_and_f1(
    pred_points, gt_points, threshold=0.25, max_points=None, voxel_size=None, seed=0
):
    """Symmetric chamfer distance and F1@threshold between two point clouds.

    Both clouds are cleaned/thinned via ``prepare_pointcloud`` first.
    Returns a metrics dict, or None when either side ends up empty.
    """
    pred = prepare_pointcloud(
        pred_points, max_points=max_points, voxel_size=voxel_size, seed=seed
    )
    gt = prepare_pointcloud(
        gt_points, max_points=max_points, voxel_size=voxel_size, seed=seed + 1
    )
    if len(pred) == 0 or len(gt) == 0:
        return None

    # Nearest-neighbour distances in both directions.
    dist_to_gt, _ = cKDTree(gt).query(pred, k=1)
    dist_to_pred, _ = cKDTree(pred).query(gt, k=1)

    acc = float(np.mean(dist_to_gt))
    comp = float(np.mean(dist_to_pred))
    precision = float(np.mean(dist_to_gt < threshold))
    recall = float(np.mean(dist_to_pred < threshold))
    denom = precision + recall
    f1 = float(2.0 * precision * recall / denom) if denom > 0 else 0.0
    return {
        "cd": float(acc + comp),
        "acc": acc,
        "comp": comp,
        "f1": f1,
        "f1_threshold": float(threshold),
        "num_pred_points": int(len(pred)),
        "num_gt_points": int(len(gt)),
    }
longstream/io/__init__.py ADDED
File without changes
longstream/io/save_images.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+
9
def save_image_sequence(
    path, images: List[np.ndarray], prefix: str = "frame", ext: str = "png"
):
    """Write *images* into *path* as zero-padded, sequentially numbered files."""
    os.makedirs(path, exist_ok=True)
    for idx, frame in enumerate(images):
        target = os.path.join(path, f"{prefix}_{idx:06d}.{ext}")
        Image.fromarray(frame).save(target)
16
+
17
+
18
def save_video(output_path, pattern, fps=30):
    """Encode an image sequence matching *pattern* into an H.264 MP4 via ffmpeg.

    Parameters
    ----------
    output_path : destination video file; its parent directory is created.
    pattern : glob pattern passed to ffmpeg, e.g. "frames/frame_*.png".
    fps : input framerate.

    Raises
    ------
    subprocess.CalledProcessError if ffmpeg exits non-zero.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when output_path actually has one (bare filenames are valid).
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-framerate",
        str(fps),
        "-pattern_type",
        "glob",
        "-i",
        pattern,
        "-c:v",
        "libx264",
        "-pix_fmt",
        "yuv420p",
        output_path,
    ]
    # List form (shell=False) avoids shell-injection via pattern/output_path.
    subprocess.run(cmd, check=True)
longstream/io/save_points.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+
4
+
5
+ def _maybe_downsample(points, colors=None, max_points=None, seed=0):
6
+ pts = np.asarray(points).reshape(-1, 3)
7
+ cols = None if colors is None else np.asarray(colors).reshape(-1, 3)
8
+ if max_points is None or pts.shape[0] <= int(max_points):
9
+ return pts, cols
10
+ rng = np.random.default_rng(seed)
11
+ keep = rng.choice(pts.shape[0], size=int(max_points), replace=False)
12
+ pts = pts[keep]
13
+ if cols is not None:
14
+ cols = cols[keep]
15
+ return pts, cols
16
+
17
+
18
def save_pointcloud(path, points, colors=None, max_points=None, seed=0):
    """Write points (and optional RGB colors) as a binary little-endian PLY.

    Parameters
    ----------
    path : output .ply file; the parent directory is created if needed.
    points : array-like reshapeable to (N, 3); stored as float32.
    colors : optional array-like reshapeable to (N, 3). Float colors with
        max <= 1.0 are scaled to 0-255; anything else is cast to uint8.
    max_points : optional cap; points and colors are jointly subsampled.
    seed : RNG seed for the subsampling.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one (bare filenames are valid).
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    pts, cols = _maybe_downsample(
        points, colors=colors, max_points=max_points, seed=seed
    )
    pts = pts.astype(np.float32, copy=False)
    has_color = colors is not None
    if has_color:
        # Guard cols.size: .max() on an empty array raises ValueError.
        if cols.size and cols.max() <= 1.0:
            cols = (cols * 255.0).astype(np.uint8)
        else:
            cols = cols.astype(np.uint8)

    # Build the packed vertex records once; the header mirrors the fields.
    fields = [("x", "<f4"), ("y", "<f4"), ("z", "<f4")]
    if has_color:
        fields += [("red", "u1"), ("green", "u1"), ("blue", "u1")]
    vertex_data = np.empty(pts.shape[0], dtype=np.dtype(fields))
    vertex_data["x"] = pts[:, 0]
    vertex_data["y"] = pts[:, 1]
    vertex_data["z"] = pts[:, 2]
    if has_color:
        vertex_data["red"] = cols[:, 0]
        vertex_data["green"] = cols[:, 1]
        vertex_data["blue"] = cols[:, 2]

    with open(path, "wb") as f:
        f.write(b"ply\n")
        f.write(b"format binary_little_endian 1.0\n")
        f.write(f"element vertex {pts.shape[0]}\n".encode("ascii"))
        f.write(b"property float x\n")
        f.write(b"property float y\n")
        f.write(b"property float z\n")
        if has_color:
            f.write(b"property uchar red\n")
            f.write(b"property uchar green\n")
            f.write(b"property uchar blue\n")
        f.write(b"end_header\n")
        vertex_data.tofile(f)
longstream/io/save_poses_txt.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+
4
+
5
+ def _ensure_dir(path):
6
+ os.makedirs(os.path.dirname(path), exist_ok=True)
7
+
8
+
9
def save_w2c_txt(path, extri, frames):
    """Dump world-to-camera poses: one line per frame, R row-major then t."""
    _ensure_dir(path)
    with open(path, "w") as f:
        f.write("# w2c\n")
        for idx, frame in enumerate(frames):
            pose = extri[idx]
            rot_flat = pose[:3, :3].reshape(-1).tolist()
            trans = pose[:3, 3].reshape(-1).tolist()
            row = [frame, *rot_flat, *trans]
            f.write(" ".join(str(v) for v in row) + "\n")
19
+
20
+
21
def save_intri_txt(path, intri, frames):
    """Dump per-frame pinhole intrinsics as "frame fx fy cx cy" lines."""
    _ensure_dir(path)
    with open(path, "w") as f:
        f.write("# fx fy cx cy\n")
        for idx, frame in enumerate(frames):
            K = intri[idx]
            f.write(
                f"{frame} {float(K[0, 0])} {float(K[1, 1])}"
                f" {float(K[0, 2])} {float(K[1, 2])}\n"
            )
32
+
33
+
34
def save_rel_pose_txt(path, rel_pose_enc, frames):
    """Dump encoded relative poses (tx ty tz qx qy qz qw fov_h fov_w) per frame."""
    _ensure_dir(path)
    enc = rel_pose_enc
    if hasattr(enc, "detach"):
        # Accept torch tensors as well as numpy arrays.
        enc = enc.detach().cpu().numpy()
    with open(path, "w") as f:
        f.write("# tx ty tz qx qy qz qw fov_h fov_w\n")
        for idx, frame in enumerate(frames):
            row = [frame, *enc[idx].tolist()]
            f.write(" ".join(str(v) for v in row) + "\n")
longstream/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from longstream.models.longstream import LongStream
2
+
3
+ __all__ = ["LongStream"]
longstream/models/longstream.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Optional, Dict
2
+ import torch
3
+ import torch.nn as nn
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ from longstream.utils.vendor.dust3r.utils.misc import freeze_all_params
7
+ from longstream.utils.vendor.models.components.aggregator.streamaggregator import (
8
+ STreamAggregator,
9
+ )
10
+ from longstream.utils.vendor.models.components.heads.camera_head import (
11
+ CameraHead,
12
+ RelPoseHead,
13
+ )
14
+ from longstream.utils.vendor.models.components.heads.dpt_head import DPTHead
15
+
16
+
17
class LongStream(nn.Module, PyTorchModelHubMixin):
    """Streaming multi-frame 3D reconstruction model.

    Couples a ``STreamAggregator`` token backbone with DPT point/depth heads,
    an optional absolute-pose ``CameraHead``, an optional keyframe-relative
    ``RelPoseHead``, and an optional learned scale token whose predicted
    factor rescales translations, world points and depths at the output.
    ``PyTorchModelHubMixin`` adds ``from_pretrained``/``push_to_hub`` support.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        freeze="none",
        rel_pose_head_cfg=None,
        use_role_embedding=True,
        enable_scale_token=False,
        scale_token_config=None,
        disable_keyframe_distinction=False,
        enable_camera_head=True,
        use_segment_mask=False,
        use_3d_rope=False,
        rope_freq=100,
        window_size=5000,
    ):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.enable_scale_token = enable_scale_token
        self.enable_camera_head = enable_camera_head
        self.window_size = window_size

        self.aggregator = STreamAggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            use_role_embedding=use_role_embedding,
            disable_keyframe_distinction=disable_keyframe_distinction,
            use_segment_mask=use_segment_mask,
            use_3d_rope=use_3d_rope,
            rope_freq=rope_freq,
            window_size=window_size,
        )

        # All heads consume 2 * embed_dim features; presumably frame-local and
        # global tokens concatenated by the aggregator — TODO confirm upstream.
        if self.enable_camera_head:
            self.camera_head = CameraHead(dim_in=2 * embed_dim, window_size=window_size)
        else:
            self.camera_head = None
        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,  # 3 world-point channels + 1 confidence channel
            activation="inv_log",
            conf_activation="expp1",
        )
        self.depth_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=2,  # 1 depth channel + 1 confidence channel
            activation="exp",
            conf_activation="expp1",
        )

        self.rel_pose_head = None
        self.reinit_camera_head_when_rel_enabled = False

        if rel_pose_head_cfg is not None:
            enable = rel_pose_head_cfg.get("enabled", True)
            if enable:
                # Materialize the rel-pose head config with explicit defaults so
                # a partial cfg dict is enough to enable the head.
                head_cfg = {
                    "dim_in": 2 * embed_dim,
                    "trunk_depth": rel_pose_head_cfg.get("trunk_depth", 4),
                    "pose_mode": rel_pose_head_cfg.get("pose_mode", "SE3"),
                    "num_heads": rel_pose_head_cfg.get("num_heads", 16),
                    "mlp_ratio": rel_pose_head_cfg.get("mlp_ratio", 4),
                    "init_values": rel_pose_head_cfg.get("init_values", 0.01),
                    "trans_act": rel_pose_head_cfg.get("trans_act", "linear"),
                    "quat_act": rel_pose_head_cfg.get("quat_act", "linear"),
                    "fl_act": rel_pose_head_cfg.get("fl_act", "relu"),
                    "use_global_scale": rel_pose_head_cfg.get(
                        "use_global_scale", False
                    ),
                    "use_pair_cross_attn": rel_pose_head_cfg.get(
                        "use_pair_cross_attn", False
                    ),
                    "detach_reference": rel_pose_head_cfg.get(
                        "detach_reference", False
                    ),
                    "xattn_temperature": rel_pose_head_cfg.get(
                        "xattn_temperature", 1.0
                    ),
                    "use_precat": rel_pose_head_cfg.get("use_precat", False),
                    "use_kf_role_embed": rel_pose_head_cfg.get(
                        "use_kf_role_embed", True
                    ),
                    "kf_role_embed_init_std": rel_pose_head_cfg.get(
                        "kf_role_embed_init_std", 0.02
                    ),
                    "window_size": window_size,
                }
                self.rel_pose_head = RelPoseHead(**head_cfg)

                self.reinit_camera_head_when_rel_enabled = rel_pose_head_cfg.get(
                    "reinit_camera_head", False
                )

                if self.reinit_camera_head_when_rel_enabled:
                    # Intentionally a no-op here: the actual reinit must happen
                    # AFTER checkpoint loading via reinitialize_camera_head().
                    pass

        if self.enable_scale_token:
            self._init_scale_components(scale_token_config or {})

        self.set_freeze(freeze)

    def reinitialize_camera_head(self):
        """
        Reinitialize camera_head with fresh weights.

        This is useful when:
        1. Loading a pretrained checkpoint that has camera_head weights
        2. But we want to train camera_head from scratch with new settings (e.g., quaternion normalization)

        This method should be called AFTER checkpoint loading.

        NOTE(review): assumes ``self.camera_head`` is not None (raises
        AttributeError otherwise), and the fresh head is built without the
        original ``window_size`` argument — confirm that default is intended.
        """

        # Recover the input width from the old head's token LayerNorm.
        old_camera_head = self.camera_head
        dim_in = old_camera_head.token_norm.normalized_shape[0]

        self.camera_head = CameraHead(dim_in=dim_in)

        # Keep the replacement on the same device as the head it replaces.
        device = next(old_camera_head.parameters()).device
        self.camera_head = self.camera_head.to(device)

    def _init_scale_components(self, config):
        """Create the learned scale token and the MLP that decodes it.

        ``config`` is currently unused — kept for forward compatibility.
        """
        self.scale_token = nn.Parameter(torch.zeros(self.embed_dim))
        torch.nn.init.trunc_normal_(self.scale_token, std=0.02)

        # Small MLP mapping the aggregated scale-token features to one logit.
        self.scale_head = nn.Sequential(
            nn.Linear(2 * self.embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        for m in self.scale_head.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        import math

        # forward() applies exp() to the logit, so this bias makes the
        # initial predicted scale factor ~= 30.
        nn.init.constant_(self.scale_head[-1].bias, math.log(30.0))

    def set_freeze(self, freeze):
        """Freeze parameter groups by name ("none" or "encoder").

        Raises KeyError for any other value.
        """
        self.freeze = freeze

        to_be_frozen = {
            "none": [],
            "encoder": [self.aggregator.patch_embed],
        }
        freeze_all_params(to_be_frozen[freeze])

    def forward(
        self,
        images: torch.Tensor,
        mode: str = "causal",
        aggregator_kv_cache_list: Optional[List[List[torch.Tensor]]] = None,
        camera_head_kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
        rel_pose_inputs: Optional[Dict] = None,
        is_keyframe: Optional[torch.Tensor] = None,
    ):
        """Run the backbone and all enabled heads on a frame sequence.

        images: [B, S, C, H, W]; a 4-D [S, C, H, W] input gets a batch dim.
        Returns a dict of predictions; KV-cache lists are passed through when
        supplied (streaming inference) so callers can carry state across calls.
        """

        if len(images.shape) == 4:
            images = images.unsqueeze(0)

        batch_size = images.shape[0]

        # Optional learned scale token appended to each frame's token set.
        additional_tokens = None
        if self.enable_scale_token:

            scale_token_base = self.scale_token.unsqueeze(0).repeat(batch_size, 1)
            additional_tokens = scale_token_base.unsqueeze(-1)

        keyframe_indices = None
        if rel_pose_inputs is not None and "keyframe_indices" in rel_pose_inputs:
            keyframe_indices = rel_pose_inputs["keyframe_indices"]

        # The aggregator returns an extra KV-cache element only when a cache
        # list is passed in (streaming mode).
        if aggregator_kv_cache_list is not None:
            (
                aggregated_tokens_list,
                patch_start_idx,
                aggregator_kv_cache_list,
                _,
            ) = self.aggregator(
                images,
                mode=mode,
                kv_cache_list=aggregator_kv_cache_list,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )
        else:
            aggregated_tokens_list, patch_start_idx, _ = self.aggregator(
                images,
                mode=mode,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )

        predictions = {}

        # Decode the scale token (last special token before the patch tokens)
        # into a single positive per-sequence scale factor.
        predicted_scale_factor = None
        if self.enable_scale_token and additional_tokens is not None:

            if len(aggregated_tokens_list) > 0:
                last_layer_features = aggregated_tokens_list[-1]

                scale_token_idx = patch_start_idx - 1
                scale_token_output_features = last_layer_features[
                    :, :, scale_token_idx, :
                ]

                # Average over the frame axis to get one feature per sequence.
                scale_token_output_features = scale_token_output_features.mean(dim=1)

                scale_logits = self.scale_head(scale_token_output_features).squeeze(-1)

                # exp() keeps the factor strictly positive.
                predicted_scale_factor = torch.exp(scale_logits)

                predictions["predicted_scale_factor"] = predicted_scale_factor
                predictions["scale_token_features"] = scale_token_output_features

        if self.enable_camera_head and self.camera_head is not None:
            if camera_head_kv_cache_list is not None:
                pose_enc_list, camera_head_kv_cache_list = self.camera_head(
                    aggregated_tokens_list,
                    mode=mode,
                    kv_cache_list=camera_head_kv_cache_list,
                )
            else:
                pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode)

            # Only the last refinement iteration is exposed at inference; the
            # translation part (first 3 channels) is metric-rescaled when a
            # scale factor was predicted.
            final_pose_enc = pose_enc_list[-1]
            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1)

                scaled_t = final_pose_enc[..., :3] * scale
                scaled_pose_enc = torch.cat([scaled_t, final_pose_enc[..., 3:]], dim=-1)
                predictions["pose_enc"] = scaled_pose_enc
            else:
                predictions["pose_enc"] = final_pose_enc

            if self.training:
                # Training keeps every iteration's output for deep supervision.
                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)
                    scaled_pose_enc_list = []
                    for pose_enc in pose_enc_list:

                        scaled_t = pose_enc[..., :3] * scale
                        scaled_pose_enc = torch.cat(
                            [scaled_t, pose_enc[..., 3:]], dim=-1
                        )
                        scaled_pose_enc_list.append(scaled_pose_enc)
                    predictions["pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["pose_enc_list"] = pose_enc_list

        if self.rel_pose_head is not None and rel_pose_inputs is not None:

            rel_kwargs = dict(
                aggregated_tokens_list=aggregated_tokens_list,
                keyframe_indices=rel_pose_inputs.get("keyframe_indices"),
                is_keyframe=rel_pose_inputs.get("is_keyframe", is_keyframe),
                num_iterations=rel_pose_inputs.get("num_iterations", 4),
                mode=mode,
                kv_cache_list=rel_pose_inputs.get("kv_cache_list"),
            )

            # Drop unset entries so the head's own defaults apply.
            rel_kwargs = {k: v for k, v in rel_kwargs.items() if v is not None}

            rel_result = self.rel_pose_head(**rel_kwargs)

            if isinstance(rel_result, dict):

                pose_enc = rel_result["pose_enc"]
                if pose_enc.dtype != torch.float32:
                    pose_enc = pose_enc.float()

                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)

                    scaled_t = pose_enc[..., :3] * scale
                    scaled_rel_pose_enc = torch.cat(
                        [scaled_t, pose_enc[..., 3:]], dim=-1
                    )
                    predictions["rel_pose_enc"] = scaled_rel_pose_enc

                    if "pose_enc_list" in rel_result:
                        scaled_pose_enc_list = []
                        for iter_pose in rel_result["pose_enc_list"]:
                            scaled_t = iter_pose[..., :3] * scale
                            scaled_iter_pose = torch.cat(
                                [scaled_t, iter_pose[..., 3:]], dim=-1
                            )
                            scaled_pose_enc_list.append(scaled_iter_pose)
                        predictions["rel_pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["rel_pose_enc"] = pose_enc

                    if "pose_enc_list" in rel_result:
                        predictions["rel_pose_enc_list"] = rel_result["pose_enc_list"]

                predictions["is_keyframe"] = rel_result.get("is_keyframe")
                predictions["keyframe_indices"] = rel_result.get("keyframe_indices")

                if "global_scale" in rel_result:
                    predictions["global_scale"] = rel_result["global_scale"]

                if "kv_cache_list" in rel_result:
                    predictions["rel_pose_kv_cache_list"] = rel_result["kv_cache_list"]

        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )

            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["world_points"] = pts3d * scale
            else:
                predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )

            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["depth"] = depth * scale
            else:
                predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        # Hand the (possibly updated) caches back to streaming callers.
        if aggregator_kv_cache_list is not None:
            predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list

        if camera_head_kv_cache_list is not None:
            predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list

        if not self.training:
            predictions["images"] = images

        return predictions
longstream/streaming/__init__.py ADDED
File without changes
longstream/streaming/keyframe_selector.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
class KeyframeSelector:
    """Decide which frames of a sequence act as keyframes.

    Modes:
      - ``"fixed"``: promote a frame whenever ``max_interval`` frames have
        passed since the last keyframe.
      - ``"random"``: the gap to the next keyframe is drawn uniformly from
        ``[min_interval, max_interval]``.
    Independently, once at least ``min_interval`` frames have passed, a frame
    may be promoted early when its translation relative to the last keyframe
    exceeds ``motion_threshold`` (only if ``poses`` are supplied).
    """

    def __init__(
        self,
        min_interval: int = 8,
        max_interval: int = 8,
        force_first: bool = True,
        motion_threshold: Optional[float] = None,
        mode: str = "fixed",
    ):
        self.min_interval = int(min_interval)
        self.max_interval = int(max_interval)
        self.force_first = bool(force_first)
        self.motion_threshold = motion_threshold
        self.mode = mode

    def select_keyframes(
        self,
        sequence_length: int,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        poses: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(is_keyframe, keyframe_indices)``, both shaped [B, S].

        ``keyframe_indices[b, s]`` holds the index of the last keyframe
        strictly before frame ``s`` (frame ``s`` itself may then be promoted).
        """
        if device is None:
            device = torch.device("cpu")
        flags = torch.zeros(
            batch_size, sequence_length, dtype=torch.bool, device=device
        )
        refs = torch.zeros(
            batch_size, sequence_length, dtype=torch.long, device=device
        )

        for batch in range(batch_size):
            anchor = 0
            target = None

            if self.force_first or sequence_length == 1:
                flags[batch, 0] = True
                refs[batch, 0] = 0
            if self.mode == "random":
                # First target measured from frame 0.
                target = random.randint(self.min_interval, self.max_interval)

            for step in range(1, sequence_length):
                refs[batch, step] = anchor
                gap = step - anchor

                if self.mode == "random" and target is not None:
                    if step >= target:
                        flags[batch, step] = True
                        anchor = step
                        target = step + random.randint(
                            self.min_interval, self.max_interval
                        )
                elif gap >= self.max_interval:
                    flags[batch, step] = True
                    anchor = step
                    if self.mode == "random":
                        target = step + random.randint(
                            self.min_interval, self.max_interval
                        )
                elif (
                    gap >= self.min_interval
                    and poses is not None
                    and self.motion_threshold is not None
                ):
                    # Early promotion when the camera moved far enough.
                    shift = torch.norm(
                        poses[batch, step, :3] - poses[batch, anchor, :3]
                    ).item()
                    if shift > self.motion_threshold:
                        flags[batch, step] = True
                        anchor = step
                        if self.mode == "random":
                            target = step + random.randint(
                                self.min_interval, self.max_interval
                            )

        return flags, refs
longstream/streaming/refresh.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Dict, Any, List
3
+
4
+ from longstream.streaming.stream_session import StreamSession
5
+
6
# Prediction keys whose tensors carry a per-frame sequence axis (dim 1);
# these are sliced per refresh batch and concatenated when stitching.
_SEQUENCE_OUTPUT_KEYS = {
    "pose_enc",
    "rel_pose_enc",
    "world_points",
    "world_points_conf",
    "depth",
    "depth_conf",
}
# Prediction keys holding per-sequence scalars; the value from the most
# recent batch overwrites earlier ones.
_SCALAR_OUTPUT_KEYS = {
    "predicted_scale_factor",
    "global_scale",
}
18
+
19
+
20
def _refresh_intervals(refresh: int) -> int:
    """Return the number of keyframe intervals covered between refreshes.

    Raises:
        ValueError: if ``refresh`` is below 2 (at least one interval needed).
    """
    count = int(refresh)
    if count < 2:
        raise ValueError("refresh must be >= 2")
    return count - 1
25
+
26
+
27
def _model_device(model) -> torch.device:
    """Device of the model's first parameter (assumes the model has any)."""
    first_param = next(model.parameters())
    return first_param.device
29
+
30
+
31
def _move_scalar_to_cpu(value: Any) -> Any:
    """Detach tensors and move them to CPU; pass anything else through."""
    if not isinstance(value, torch.Tensor):
        return value
    return value.detach().cpu()
35
+
36
+
37
def _append_batch_output(
    stitched_tensors: Dict[str, List[torch.Tensor]],
    stitched_scalars: Dict[str, Any],
    output: Dict[str, Any],
    actual_frames: int,
    slice_start: int,
) -> None:
    """Accumulate one refresh batch's model output into the stitch buffers.

    Sequence tensors are sliced from ``slice_start`` along the frame axis
    (dropping the overlap frame on non-first batches), detached to CPU, and
    appended per key; scalar keys simply overwrite the previous value.
    Entries that are missing, non-tensor, or whose frame axis does not match
    ``actual_frames`` are skipped.
    """
    for key in _SEQUENCE_OUTPUT_KEYS:
        tensor = output.get(key)
        if not isinstance(tensor, torch.Tensor):
            continue
        if tensor.ndim < 2 or tensor.shape[1] != actual_frames:
            continue
        chunk = tensor[:, slice_start:].detach().cpu()
        stitched_tensors.setdefault(key, []).append(chunk)

    for key in _SCALAR_OUTPUT_KEYS:
        if key in output:
            stitched_scalars[key] = _move_scalar_to_cpu(output[key])
57
+
58
+
59
def _finalize_stitched_batches(
    stitched_tensors: Dict[str, List[torch.Tensor]],
    stitched_scalars: Dict[str, Any],
) -> Dict[str, Any]:
    """Concatenate per-batch chunks along the frame axis and merge scalars.

    Scalar entries are added last, so they win on any key collision.
    """
    merged: Dict[str, Any] = {}
    for key, chunks in stitched_tensors.items():
        if not chunks:
            continue
        merged[key] = torch.cat(chunks, dim=1) if len(chunks) > 1 else chunks[0]
    merged.update(stitched_scalars)
    return merged
72
+
73
+
74
def run_batch_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    keyframe_stride: int,
    refresh: int,
    rel_pose_cfg,
):
    """Run inference in overlapping refresh batches and stitch the outputs.

    The sequence is split into batches of ``refresh_intervals * keyframe_stride
    + 1`` frames that overlap by one frame; each non-first batch reuses its
    first frame as a fresh keyframe anchor, and the duplicated overlap frame
    is dropped when stitching (``slice_start=1``).

    images: [B, S, ...] frame tensor (only B and S are read here; ``B`` is
    currently unused). Returns a dict of stitched CPU predictions.
    """
    B, S = images.shape[:2]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    frames_per_batch = refresh_intervals * keyframe_stride + 1
    step_frames = refresh_intervals * keyframe_stride

    stitched_tensors: Dict[str, List[torch.Tensor]] = {}
    stitched_scalars: Dict[str, Any] = {}
    # Ceil division: every frame lands in exactly one stitched slice.
    num_batches = (S + step_frames - 1) // step_frames
    for batch_idx in range(num_batches):
        start_frame = batch_idx * step_frames
        end_frame = min(start_frame + frames_per_batch, S)
        batch_images = images[:, start_frame:end_frame].to(device, non_blocking=True)
        # Clone so per-batch keyframe rewrites don't leak into the inputs.
        batch_is_keyframe = (
            is_keyframe[:, start_frame:end_frame].clone()
            if is_keyframe is not None
            else None
        )
        batch_keyframe_indices = (
            keyframe_indices[:, start_frame:end_frame].clone()
            if keyframe_indices is not None
            else None
        )

        # Each new batch restarts from its first frame as a keyframe anchor.
        if batch_idx > 0 and batch_is_keyframe is not None:
            batch_is_keyframe[:, 0] = True
            if batch_keyframe_indices is not None:
                batch_keyframe_indices[:, 0] = start_frame

        # Rebase keyframe indices to batch-local coordinates and clamp any
        # references that point before the batch start.
        if batch_keyframe_indices is not None:
            batch_keyframe_indices = batch_keyframe_indices - start_frame
            batch_keyframe_indices = torch.clamp(
                batch_keyframe_indices, 0, end_frame - start_frame - 1
            )

        batch_rel_pose_inputs = None
        if rel_pose_cfg is not None and batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
            if batch_keyframe_indices is not None:
                batch_keyframe_indices = batch_keyframe_indices.to(
                    device, non_blocking=True
                )
            batch_rel_pose_inputs = {
                "is_keyframe": batch_is_keyframe,
                "keyframe_indices": batch_keyframe_indices,
                "num_iterations": rel_pose_cfg.get("num_iterations", 4),
            }
        elif batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)

        batch_output = model(
            images=batch_images,
            mode=mode,
            rel_pose_inputs=batch_rel_pose_inputs,
            is_keyframe=batch_is_keyframe,
        )

        _append_batch_output(
            stitched_tensors,
            stitched_scalars,
            batch_output,
            actual_frames=end_frame - start_frame,
            slice_start=0 if batch_idx == 0 else 1,
        )
        # Free GPU memory eagerly before the next batch.
        del batch_output
        del batch_images
        del batch_is_keyframe
        del batch_keyframe_indices

    return _finalize_stitched_batches(stitched_tensors, stitched_scalars)
154
+
155
+
156
def run_streaming_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    window_size: int,
    refresh: int,
    rel_pose_cfg,
):
    """Stream frames one by one, periodically clearing the session cache.

    Every ``refresh_intervals`` keyframes the KV caches are dropped and the
    current keyframe is re-fed (with ``record=False``) so it becomes the
    self-referencing anchor of a fresh cache segment; ``segment_start``
    rebases keyframe indices into that segment.

    NOTE(review): ``is_keyframe_s.item()`` assumes batch size 1, and
    ``rel_pose_cfg`` / ``B`` are currently unused here — confirm intended.
    Returns the session's accumulated predictions.
    """
    B, S = images.shape[:2]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    session = StreamSession(model, mode=mode, window_size=window_size)
    keyframe_count = 0
    segment_start = 0
    for s in range(S):
        frame_images = images[:, s : s + 1].to(device, non_blocking=True)
        is_keyframe_s = (
            is_keyframe[:, s : s + 1].to(device, non_blocking=True)
            if is_keyframe is not None
            else None
        )
        if keyframe_indices is not None:
            # Rebase references to the current cache segment.
            keyframe_indices_s = keyframe_indices[:, s : s + 1].clone() - segment_start
            keyframe_indices_s = torch.clamp(keyframe_indices_s, min=0)
            keyframe_indices_s = keyframe_indices_s.to(device, non_blocking=True)
        else:
            keyframe_indices_s = None
        session.forward_stream(
            frame_images,
            is_keyframe=is_keyframe_s,
            keyframe_indices=keyframe_indices_s,
            record=True,
        )
        # Non-keyframes (and frame 0) never trigger a refresh.
        if is_keyframe_s is None or not bool(is_keyframe_s.item()) or s <= 0:
            del frame_images
            if is_keyframe_s is not None:
                del is_keyframe_s
            if keyframe_indices_s is not None:
                del keyframe_indices_s
            continue
        keyframe_count += 1
        if keyframe_count % refresh_intervals == 0:
            # Drop caches, then replay this keyframe (unrecorded) as the
            # anchor of the new segment; it references itself (index 0).
            session.clear_cache_only()
            segment_start = s
            if keyframe_indices_s is not None:
                keyframe_indices_self = torch.zeros_like(keyframe_indices_s)
            else:
                keyframe_indices_self = None
            session.forward_stream(
                frame_images,
                is_keyframe=is_keyframe_s,
                keyframe_indices=keyframe_indices_self,
                record=False,
            )
        del frame_images
        if is_keyframe_s is not None:
            del is_keyframe_s
        if keyframe_indices_s is not None:
            del keyframe_indices_s
    return session.get_all_predictions()
longstream/streaming/stream_session.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class StreamSession:
    """Frame-by-frame inference session with KV-cache management.

    Supports two attention modes:
      - "causal": caches grow without bound and are passed straight through.
      - "window": caches are trimmed to the last ``window_size`` frames,
        optionally keeping the first frame's entries as a fixed anchor.
    Per-frame predictions are accumulated on CPU and stitched on demand.
    """

    def __init__(
        self,
        model,
        mode: str,
        window_size: int = 5,
        keep_first_frame_anchor: bool = True,
    ):
        self.model = model
        # Unwrap a wrapper module exposing the core model as ``.longstream``.
        self.core_model = getattr(model, "longstream", model)
        self.mode = mode
        self.window_size = window_size
        self.keep_first_frame_anchor = keep_first_frame_anchor

        if self.mode not in ["causal", "window"]:
            raise ValueError(f"Unsupported attention mode: {self.mode}")

        self.aggregator_kv_cache_depth = self.core_model.aggregator.depth
        self.use_camera_head = self.core_model.camera_head is not None
        if self.use_camera_head:
            self.camera_head_kv_cache_depth = self.core_model.camera_head.trunk_depth
            # Number of iterative refinement passes the camera head runs.
            self.camera_head_iterations = 4
        else:
            self.camera_head_kv_cache_depth = 0
            self.camera_head_iterations = 0

        self.use_rel_pose_head = (
            hasattr(self.core_model, "rel_pose_head")
            and self.core_model.rel_pose_head is not None
        )
        if self.use_rel_pose_head:
            self.rel_pose_head_trunk_depth = self.core_model.rel_pose_head.trunk_depth
            self.rel_pose_head_iterations = 4

        self.clear()

    def _clear_predictions(self):
        # Per-key lists of per-frame CPU tensors / latest scalar values.
        self.sequence_predictions = {}
        self.scalar_predictions = {}

    def _update_predictions(self, predictions):
        """Append the current step's outputs (detached, on CPU)."""
        sequence_keys = [
            "pose_enc",
            "rel_pose_enc",
            "world_points",
            "world_points_conf",
            "depth",
            "depth_conf",
        ]
        scalar_keys = ["predicted_scale_factor", "global_scale"]

        for k in sequence_keys:
            if k in predictions:
                self.sequence_predictions.setdefault(k, []).append(
                    predictions[k].detach().cpu()
                )

        for k in scalar_keys:
            if k in predictions:
                value = predictions[k]
                self.scalar_predictions[k] = (
                    value.detach().cpu() if isinstance(value, torch.Tensor) else value
                )

    def _clear_cache(self):
        """Reset all KV caches to empty (None) per layer/iteration/key-value."""
        self.aggregator_kv_cache_list = [
            [None, None] for _ in range(self.aggregator_kv_cache_depth)
        ]
        if self.use_camera_head:
            self.camera_head_kv_cache_list = [
                [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
                for _ in range(self.camera_head_iterations)
            ]
        else:
            self.camera_head_kv_cache_list = None
        if self.use_rel_pose_head:
            self.rel_pose_kv_cache_list = [
                [[None, None] for _ in range(self.rel_pose_head_trunk_depth)]
                for _ in range(self.rel_pose_head_iterations)
            ]
        else:
            self.rel_pose_kv_cache_list = None

    def _update_cache(
        self, aggregator_kv_cache_list, camera_head_kv_cache_list, frame_hw
    ):
        """Store the model's updated caches; in window mode, trim them.

        ``frame_hw`` is the (H, W) of the current frame's prediction map,
        used to compute P, the number of cached tokens per frame.
        """
        if self.mode == "causal":
            # Unbounded: keep everything.
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            if self.use_camera_head:
                self.camera_head_kv_cache_list = camera_head_kv_cache_list
            return

        if self.mode == "window":
            h, w = frame_hw
            # Tokens per frame: patch grid plus leading special tokens.
            P = (
                h
                * w
                // self.core_model.aggregator.patch_size
                // self.core_model.aggregator.patch_size
                + self.core_model.aggregator.patch_start_idx
            )

            # Aggregator caches are token-granular: window_size * P entries.
            for k in range(2):  # k=0 keys, k=1 values
                for i in range(self.aggregator_kv_cache_depth):
                    cache_size = aggregator_kv_cache_list[i][k].size(2)
                    if self.keep_first_frame_anchor:
                        if cache_size <= P:
                            # Still within the first frame: keep as-is.
                            self.aggregator_kv_cache_list[i][
                                k
                            ] = aggregator_kv_cache_list[i][k].contiguous()
                        elif cache_size <= self.window_size * P:
                            # Within the window: keep as-is.
                            self.aggregator_kv_cache_list[i][
                                k
                            ] = aggregator_kv_cache_list[i][k].contiguous()
                        else:
                            # Anchor (frame 0 tokens) + the most recent
                            # window_size - 1 frames' tokens.
                            anchor = aggregator_kv_cache_list[i][k][:, :, :P]
                            recent_start = cache_size - (self.window_size - 1) * P
                            recent = aggregator_kv_cache_list[i][k][:, :, recent_start:]
                            self.aggregator_kv_cache_list[i][k] = torch.cat(
                                [anchor, recent], dim=2
                            ).contiguous()
                    else:
                        # Plain sliding window, no anchor.
                        start_idx = max(0, cache_size - self.window_size * P)
                        self.aggregator_kv_cache_list[i][k] = aggregator_kv_cache_list[
                            i
                        ][k][:, :, start_idx:].contiguous()

            # Camera-head caches are frame-granular (one token per frame).
            if camera_head_kv_cache_list is not None:
                for k in range(2):
                    for i in range(self.camera_head_iterations):
                        for j in range(self.camera_head_kv_cache_depth):
                            cache_size = camera_head_kv_cache_list[i][j][k].size(2)
                            if self.keep_first_frame_anchor:
                                if cache_size <= 1:
                                    self.camera_head_kv_cache_list[i][j][
                                        k
                                    ] = camera_head_kv_cache_list[i][j][k].contiguous()
                                elif cache_size <= self.window_size:
                                    self.camera_head_kv_cache_list[i][j][
                                        k
                                    ] = camera_head_kv_cache_list[i][j][k].contiguous()
                                else:
                                    anchor = camera_head_kv_cache_list[i][j][k][
                                        :, :, :1
                                    ]
                                    recent_start = cache_size - (self.window_size - 1)
                                    recent = camera_head_kv_cache_list[i][j][k][
                                        :, :, recent_start:
                                    ]
                                    self.camera_head_kv_cache_list[i][j][k] = torch.cat(
                                        [anchor, recent], dim=2
                                    ).contiguous()
                            else:
                                start_idx = max(0, cache_size - self.window_size)
                                self.camera_head_kv_cache_list[i][j][
                                    k
                                ] = camera_head_kv_cache_list[i][j][k][
                                    :, :, start_idx:
                                ].contiguous()
            return

        raise ValueError(f"Unsupported attention mode: {self.mode}")

    def _get_cache(self):
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Stitch every recorded step along the frame axis (dim 1)."""
        predictions = {}
        for key, chunks in self.sequence_predictions.items():
            if not chunks:
                continue
            predictions[key] = (
                chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1)
            )
        predictions.update(self.scalar_predictions)
        return predictions

    def get_last_prediction(self):
        """Return only the most recent frame's outputs (frame axis kept)."""
        last_predictions = {}
        keys_to_extract = [
            "pose_enc",
            "rel_pose_enc",
            "world_points",
            "world_points_conf",
            "depth",
            "depth_conf",
            "predicted_scale_factor",
        ]
        for k in keys_to_extract:
            if k in self.sequence_predictions and self.sequence_predictions[k]:
                last_predictions[k] = self.sequence_predictions[k][-1][:, -1:]
            elif k in self.scalar_predictions:
                last_predictions[k] = self.scalar_predictions[k]
        return last_predictions

    def clear(self):
        """Reset predictions, caches, and the rel-pose head's internal state."""
        self._clear_predictions()
        self._clear_cache()
        if self.use_rel_pose_head:
            # The rel-pose head keeps private streaming state; reset whichever
            # of these attributes the vendor implementation exposes.
            if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"):
                self.core_model.rel_pose_head._keyframe_tokens_cache = {}
            if hasattr(self.core_model.rel_pose_head, "_current_frame_id"):
                self.core_model.rel_pose_head._current_frame_id = 0
            if hasattr(self.core_model.rel_pose_head, "_frame_info"):
                self.core_model.rel_pose_head._frame_info = []

    def clear_cache_only(self):
        """Reset caches and rel-pose state but keep recorded predictions."""
        self._clear_cache()
        if self.use_rel_pose_head:
            if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"):
                self.core_model.rel_pose_head._keyframe_tokens_cache = {}
            if hasattr(self.core_model.rel_pose_head, "_current_frame_id"):
                self.core_model.rel_pose_head._current_frame_id = 0
            if hasattr(self.core_model.rel_pose_head, "_frame_info"):
                self.core_model.rel_pose_head._frame_info = []

    def forward_stream(
        self, images, is_keyframe=None, keyframe_indices=None, record: bool = True
    ):
        """Run one streaming step and fold the returned caches back in.

        Set ``record=False`` to run a step (e.g. a cache-refresh replay)
        without appending its outputs to the recorded predictions.
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()

        rel_pose_inputs = None
        if (
            self.use_rel_pose_head
            and is_keyframe is not None
            and keyframe_indices is not None
        ):
            rel_pose_inputs = {
                "is_keyframe": is_keyframe,
                "keyframe_indices": keyframe_indices,
                "kv_cache_list": self.rel_pose_kv_cache_list,
            }

        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
            rel_pose_inputs=rel_pose_inputs,
            is_keyframe=is_keyframe,
        )

        if record:
            self._update_predictions(outputs)

        camera_head_kv_cache_list = outputs.get("camera_head_kv_cache_list", None)
        # Prefer the depth map's H, W (model-native resolution) for the
        # per-frame token count; fall back to the input image size.
        depth_hw = (
            outputs["depth"].shape[2:4] if "depth" in outputs else images.shape[-2:]
        )
        self._update_cache(
            outputs["aggregator_kv_cache_list"], camera_head_kv_cache_list, depth_hw
        )

        # Rel-pose caches are frame-granular and windowed like the camera head.
        if self.use_rel_pose_head and "rel_pose_kv_cache_list" in outputs:
            rel_pose_kv_cache = outputs["rel_pose_kv_cache_list"]
            if self.mode == "causal":
                self.rel_pose_kv_cache_list = rel_pose_kv_cache
            elif self.mode == "window":
                for k in range(2):
                    for i in range(self.rel_pose_head_iterations):
                        for j in range(self.rel_pose_head_trunk_depth):
                            if rel_pose_kv_cache[i][j][k] is None:
                                continue
                            cache_len = rel_pose_kv_cache[i][j][k].size(2)
                            if self.keep_first_frame_anchor:
                                if cache_len <= 1:
                                    self.rel_pose_kv_cache_list[i][j][
                                        k
                                    ] = rel_pose_kv_cache[i][j][k].contiguous()
                                elif cache_len <= self.window_size:
                                    self.rel_pose_kv_cache_list[i][j][
                                        k
                                    ] = rel_pose_kv_cache[i][j][k].contiguous()
                                else:
                                    anchor = rel_pose_kv_cache[i][j][k][:, :, :1]
                                    recent_start = cache_len - (self.window_size - 1)
                                    recent = rel_pose_kv_cache[i][j][k][
                                        :, :, recent_start:
                                    ]
                                    self.rel_pose_kv_cache_list[i][j][k] = torch.cat(
                                        [anchor, recent], dim=2
                                    ).contiguous()
                            else:
                                start_idx = max(0, cache_len - self.window_size)
                                self.rel_pose_kv_cache_list[i][j][
                                    k
                                ] = rel_pose_kv_cache[i][j][k][
                                    :, :, start_idx:
                                ].contiguous()

        return outputs
longstream/utils/__init__.py ADDED
File without changes
longstream/utils/camera.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from longstream.utils.vendor.models.components.utils.rotation import (
3
+ quat_to_mat,
4
+ mat_to_quat,
5
+ )
6
+
7
+
8
def compose_abs_from_rel(
    rel_pose_enc: torch.Tensor, keyframe_indices: torch.Tensor
) -> torch.Tensor:
    """Chain per-frame relative pose encodings into absolute pose encodings.

    Each pose encoding is a [t(3), quat(4), focal(2)] vector. Frame 0 is
    taken as already absolute; every later frame s stores a pose relative to
    the reference frame ``keyframe_indices[b, s]``.

    Args:
        rel_pose_enc: [B, S, 9] or [S, 9] relative pose encodings.
        keyframe_indices: [B, S] or [S] index of each frame's reference.

    Returns:
        Absolute pose encodings with the same leading shape as the input.
    """
    had_no_batch = rel_pose_enc.ndim == 2
    if had_no_batch:
        rel_pose_enc = rel_pose_enc.unsqueeze(0)
    if keyframe_indices.ndim == 1:
        keyframe_indices = keyframe_indices.unsqueeze(0)
    if rel_pose_enc.ndim != 3 or keyframe_indices.ndim != 2:
        raise ValueError(
            f"Expected rel_pose_enc [B,S,D] or [S,D] and keyframe_indices [B,S] or [S], "
            f"got {tuple(rel_pose_enc.shape)} and {tuple(keyframe_indices.shape)}"
        )

    batch, seq, _ = rel_pose_enc.shape
    dev = rel_pose_enc.device
    dt = rel_pose_enc.dtype

    rel_trans = rel_pose_enc[..., :3]
    rel_quat = rel_pose_enc[..., 3:7]
    rel_focal = rel_pose_enc[..., 7:9]
    rel_rot = quat_to_mat(rel_quat.reshape(-1, 4)).reshape(batch, seq, 3, 3)

    abs_rot = torch.zeros(batch, seq, 3, 3, device=dev, dtype=dt)
    abs_trans = torch.zeros(batch, seq, 3, device=dev, dtype=dt)
    abs_focal = torch.zeros(batch, seq, 2, device=dev, dtype=dt)

    for b in range(batch):
        # Frame 0 anchors the trajectory: its "relative" pose is absolute.
        abs_rot[b, 0] = rel_rot[b, 0]
        abs_trans[b, 0] = rel_trans[b, 0]
        abs_focal[b, 0] = rel_focal[b, 0]
        for s in range(1, seq):
            ref = int(keyframe_indices[b, s].item())
            # Compose with the (already absolute) pose of the reference frame:
            # R_abs = R_rel @ R_ref ; t_abs = t_rel + R_rel @ t_ref.
            abs_rot[b, s] = rel_rot[b, s] @ abs_rot[b, ref]
            abs_trans[b, s] = rel_trans[b, s] + rel_rot[b, s] @ abs_trans[b, ref]
            abs_focal[b, s] = rel_focal[b, s]

    abs_quat = mat_to_quat(abs_rot.reshape(-1, 3, 3)).reshape(batch, seq, 4)
    composed = torch.cat([abs_trans, abs_quat, abs_focal], dim=-1)
    return composed[0] if had_no_batch else composed
longstream/utils/depth.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matplotlib.cm as cm
4
+
5
+
6
def colorize_depth(depth: torch.Tensor, cmap: str = "plasma") -> np.ndarray:
    """Render a depth map as an RGB uint8 image using a matplotlib colormap.

    Args:
        depth: Depth values as a torch tensor or numpy array; NaNs are
            ignored when computing the normalization range.
        cmap: Name of a matplotlib colormap.

    Returns:
        uint8 RGB array of shape ``depth.shape + (3,)``.
    """
    if torch.is_tensor(depth):
        depth_np = depth.detach().cpu().numpy()
    else:
        depth_np = depth
    # Normalize to [0, 1] over the finite range; guard against a flat map
    # to avoid division by zero.
    d_min = np.nanmin(depth_np)
    d_max = np.nanmax(depth_np)
    if d_max - d_min < 1e-6:
        d_max = d_min + 1e-6
    norm = (depth_np - d_min) / (d_max - d_min)
    norm = np.clip(norm, 0.0, 1.0)
    # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # prefer the colormap registry and fall back for old matplotlib.
    try:
        import matplotlib

        mapper = matplotlib.colormaps[cmap]
    except (ImportError, AttributeError, KeyError):
        mapper = cm.get_cmap(cmap)
    colored = mapper(norm)[..., :3]  # drop the alpha channel
    return (colored * 255.0).astype(np.uint8)
20
+
21
+
22
def unproject_depth_to_points(depth: torch.Tensor, intri: torch.Tensor) -> torch.Tensor:
    """Back-project a batch of depth maps into per-pixel 3D camera-frame points.

    Args:
        depth: [B, H, W] depth values.
        intri: [B, 3, 3] pinhole intrinsic matrices.

    Returns:
        [B, H, W, 3] tensor of (x, y, z) points in the camera frame.
    """
    batch, height, width = depth.shape

    # Per-batch pinhole parameters, shaped for broadcasting over the grid.
    focal_x = intri[:, 0, 0].view(batch, 1, 1)
    focal_y = intri[:, 1, 1].view(batch, 1, 1)
    center_x = intri[:, 0, 2].view(batch, 1, 1)
    center_y = intri[:, 1, 2].view(batch, 1, 1)

    # Pixel coordinate grids: rows vary along H, columns along W.
    rows = torch.arange(height, device=depth.device).view(1, height, 1).float()
    cols = torch.arange(width, device=depth.device).view(1, 1, width).float()

    # Standard pinhole back-projection: x = (u - cx) * z / fx, etc.
    px = (cols - center_x) * depth / focal_x
    py = (rows - center_y) * depth / focal_y
    return torch.stack([px, py, depth], dim=-1)
longstream/utils/hub.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+
6
@dataclass
class HFSpec:
    """Specification for downloading a checkpoint from the Hugging Face Hub."""

    # Hub repository identifier, e.g. "org/model".
    repo_id: str
    # Name of the checkpoint file within the repository.
    filename: str
    # Optional git revision (branch, tag, or commit hash); None uses the default.
    revision: Optional[str] = None
    # Local directory where the downloaded file is stored.
    local_dir: str = "checkpoints"
12
+
13
+
14
+ def _is_nonempty_str(x) -> bool:
15
+ return isinstance(x, str) and len(x) > 0
16
+
17
+
18
def resolve_checkpoint_path(
    checkpoint: Optional[str], hf: Optional[dict]
) -> Optional[str]:
    """Pick a checkpoint path: an explicit local path wins, else try HF Hub.

    Args:
        checkpoint: Explicit checkpoint path; returned as-is when it is a
            non-empty string.
        hf: Optional dict with "repo_id" and "filename" (plus optional
            "revision" and "local_dir") describing a Hub download.

    Returns:
        A filesystem path to the checkpoint, or None when neither source
        is usable.

    Raises:
        RuntimeError: when a download is needed but huggingface_hub is
            not installed.
    """
    # An explicit local path always takes precedence over auto-download.
    if _is_nonempty_str(checkpoint):
        return checkpoint
    if not isinstance(hf, dict):
        return None

    repo_id = hf.get("repo_id")
    filename = hf.get("filename")
    if not (_is_nonempty_str(repo_id) and _is_nonempty_str(filename)):
        return None

    # Import lazily so the dependency is only needed for auto-download.
    try:
        from huggingface_hub import hf_hub_download
    except Exception as e:
        raise RuntimeError("huggingface_hub is required for auto-download") from e

    local_dir = hf.get("local_dir", "checkpoints")
    os.makedirs(local_dir, exist_ok=True)
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=hf.get("revision", None),
        local_dir=local_dir,
    )
longstream/utils/sky_mask.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import cv2
4
+ import numpy as np
5
+ import shutil
6
+ import urllib.request
7
+
8
+ try:
9
+ import onnxruntime
10
+ except Exception:
11
+ onnxruntime = None
12
+
13
+ SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
14
+ SKYSEG_THRESHOLD = 0.5
15
+
16
+
17
+ def run_skyseg(session, input_size, image):
18
+ temp_image = copy.deepcopy(image)
19
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
20
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
21
+ x = np.array(x, dtype=np.float32)
22
+ mean = [0.485, 0.456, 0.406]
23
+ std = [0.229, 0.224, 0.225]
24
+ x = (x / 255 - mean) / std
25
+ x = x.transpose(2, 0, 1)
26
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
27
+ input_name = session.get_inputs()[0].name
28
+ result_map = session.run(None, {input_name: x})[0]
29
+ return result_map[0, 0]
30
+
31
+
32
+ def _normalize_skyseg_output(result_map):
33
+ result_map = np.asarray(result_map, dtype=np.float32)
34
+ if result_map.size == 0:
35
+ return result_map
36
+ finite = np.isfinite(result_map)
37
+ if not np.any(finite):
38
+ return np.zeros_like(result_map, dtype=np.float32)
39
+ result_map = np.nan_to_num(result_map, nan=0.0, posinf=1.0, neginf=0.0)
40
+ max_value = float(result_map.max())
41
+ min_value = float(result_map.min())
42
+ if min_value >= 0.0 and max_value > 1.5:
43
+ result_map = result_map / 255.0
44
+ return np.clip(result_map, 0.0, 1.0)
45
+
46
+
47
+ def sky_mask_filename(image_path):
48
+ parent = os.path.basename(os.path.dirname(image_path))
49
+ name = os.path.basename(image_path)
50
+ if parent:
51
+ return f"{parent}__{name}"
52
+ return name
53
+
54
+
55
+ def segment_sky(image_path, session, mask_filename=None):
56
+ image = cv2.imread(image_path)
57
+ if image is None:
58
+ return None
59
+ result_map = run_skyseg(session, [320, 320], image)
60
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
61
+ result_map_original = _normalize_skyseg_output(result_map_original)
62
+ output_mask = np.zeros(result_map_original.shape, dtype=np.uint8)
63
+ output_mask[result_map_original < SKYSEG_THRESHOLD] = 255
64
+ if mask_filename is not None:
65
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
66
+ cv2.imwrite(mask_filename, output_mask)
67
+ return output_mask
68
+
69
+
70
+ def compute_sky_mask(image_paths, model_path: str, target_dir: str = None):
71
+ if onnxruntime is None:
72
+ return None
73
+ if not os.path.exists(model_path):
74
+ os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
75
+ try:
76
+ print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
77
+ with urllib.request.urlopen(SKYSEG_URL) as src, open(
78
+ model_path, "wb"
79
+ ) as dst:
80
+ shutil.copyfileobj(src, dst)
81
+ except Exception as exc:
82
+ print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
83
+ return None
84
+ if not os.path.exists(model_path):
85
+ return None
86
+ session = onnxruntime.InferenceSession(model_path)
87
+ masks = []
88
+ for image_path in image_paths:
89
+ mask_filepath = None
90
+ if target_dir is not None:
91
+ name = sky_mask_filename(image_path)
92
+ mask_filepath = os.path.join(target_dir, name)
93
+ if os.path.exists(mask_filepath):
94
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
95
+ else:
96
+ sky_mask = segment_sky(image_path, session, mask_filepath)
97
+ else:
98
+ sky_mask = segment_sky(image_path, session, None)
99
+ masks.append(sky_mask)
100
+ return masks
longstream/utils/vendor/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
longstream/utils/vendor/croco/LICENSE ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8
+
9
+
10
+ SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11
+
12
+ ***************************
13
+
14
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15
+
16
+ This software is being redistributed in a modifiled form. The original form is available here:
17
+
18
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+
20
+ This software in this file incorporates parts of the following software available here:
21
+
22
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24
+
25
+ MoCo v3: https://github.com/facebookresearch/moco-v3
26
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27
+
28
+ DeiT: https://github.com/facebookresearch/deit
29
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30
+
31
+
32
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
33
+
34
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
35
+
36
+ Attribution-NonCommercial 4.0 International
37
+
38
+ ***************************
39
+
40
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41
+
42
+ This software is being redistributed in a modifiled form. The original form is available here:
43
+
44
+ https://github.com/rwightman/pytorch-image-models
45
+
46
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
47
+
48
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49
+
50
+ Apache License
51
+ Version 2.0, January 2004
52
+ http://www.apache.org/licenses/
longstream/utils/vendor/croco/NOTICE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CroCo
2
+ Copyright 2022-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ facebookresearch/mae
10
+ https://github.com/facebookresearch/mae
11
+
12
+ Attribution-NonCommercial 4.0 International
13
+
14
+ ====
15
+
16
+ rwightman/pytorch-image-models
17
+ https://github.com/rwightman/pytorch-image-models
18
+
19
+ Apache License
20
+ Version 2.0, January 2004
21
+ http://www.apache.org/licenses/
longstream/utils/vendor/croco/README.MD ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
2
+
3
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
4
+
5
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
6
+
7
+ ![image](assets/arch.jpg)
8
+
9
+ ```bibtex
10
+ @inproceedings{croco,
11
+ title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
12
+ author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
13
+ booktitle={{NeurIPS}},
14
+ year={2022}
15
+ }
16
+
17
+ @inproceedings{croco_v2,
18
+ title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
19
+ author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
20
+ booktitle={ICCV},
21
+ year={2023}
22
+ }
23
+ ```
24
+
25
+ ## License
26
+
27
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
28
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
29
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
30
+
31
+ ## Preparation
32
+
33
+ 1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.
34
+
35
+ ```bash
36
+ conda create -n croco python=3.7 cmake=3.14.0
37
+ conda activate croco
38
+ conda install habitat-sim headless -c conda-forge -c aihabitat
39
+ conda install pytorch torchvision -c pytorch
40
+ conda install notebook ipykernel matplotlib
41
+ conda install ipywidgets widgetsnbextension
42
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
43
+
44
+ ```
45
+
46
+ 2. Compile cuda kernels for RoPE
47
+
48
+ CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
49
+ ```bash
50
+ cd models/curope/
51
+ python setup.py build_ext --inplace
52
+ cd ../../
53
+ ```
54
+
55
+ This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
56
+ You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.
57
+
58
+ In case you cannot compile the cuda kernels, we also provide a slow pytorch version, which will be automatically loaded.
59
+
60
+ 3. Download pre-trained model
61
+
62
+ We provide several pre-trained models:
63
+
64
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
65
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
66
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
67
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
68
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
69
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
70
+
71
+ To download a specific model, i.e., the first one (`CroCo.pth`)
72
+ ```bash
73
+ mkdir -p pretrained_models/
74
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
75
+ ```
76
+
77
+ ## Reconstruction example
78
+
79
+ Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
80
+ ```bash
81
+ python demo.py
82
+ ```
83
+
84
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
85
+
86
+ First download the test scene from Habitat:
87
+ ```bash
88
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
89
+ ```
90
+
91
+ Then, run the Notebook demo `interactive_demo.ipynb`.
92
+
93
+ In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.
94
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
95
+
96
+ ## Pre-training
97
+
98
+ ### CroCo
99
+
100
+ To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
101
+ ```
102
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
103
+ ```
104
+
105
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
106
+ It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
107
+ Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment with whether it is valid in our case.
108
+ The first run can take a few minutes to start, to parse all available pre-training pairs.
109
+
110
+ ### CroCo v2
111
+
112
+ For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
113
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
114
+ ```
115
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
116
+ ```
117
+
118
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
119
+ The largest model should take around 12 days on A100.
120
+ Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experiment with whether it is valid in our case.
121
+
122
+ ## Stereo matching and Optical flow downstream tasks
123
+
124
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
longstream/utils/vendor/croco/assets/arch.jpg ADDED
longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bca0f41",
6
+ "metadata": {},
7
+ "source": []
8
+ },
9
+ {
10
+ "cell_type": "code",
11
+ "execution_count": null,
12
+ "id": "80653ef7",
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": []
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "id": "4f033862",
20
+ "metadata": {},
21
+ "source": [
22
+ "First download the model(s) of your choice by running\n",
23
+ "```\n",
24
+ "bash stereoflow/download_model.sh crocostereo.pth\n",
25
+ "bash stereoflow/download_model.sh crocoflow.pth\n",
26
+ "```"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "1fb2e392",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "import torch\n",
37
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
38
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
39
+ "import matplotlib.pylab as plt"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "e0e25d77",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "from stereoflow.test import _load_model_and_criterion\n",
50
+ "from stereoflow.engine import tiled_pred\n",
51
+ "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
52
+ "from stereoflow.datasets_flow import flowToColor\n",
53
+ "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "id": "86a921f5",
59
+ "metadata": {},
60
+ "source": []
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "64e483cb",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
70
+ "image2 = np.asarray(Image.open('<path_to_right_image>'))"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "f0d04303",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "47dc14b5",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
91
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
92
+ "with torch.inference_mode():\n",
93
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
94
+ "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "583b9f16",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "plt.imshow(vis_disparity(pred))\n",
105
+ "plt.axis('off')"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "id": "d2df5d70",
111
+ "metadata": {},
112
+ "source": []
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "9ee257a7",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
122
+ "image2 = np.asarray(Image.open('<path_to_second_image>'))"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "d5edccf0",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "b19692c3",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
143
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
144
+ "with torch.inference_mode():\n",
145
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
146
+ "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "26f79db3",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "plt.imshow(flowToColor(pred))\n",
157
+ "plt.axis('off')"
158
+ ]
159
+ }
160
+ ],
161
+ "metadata": {
162
+ "kernelspec": {
163
+ "display_name": "Python 3 (ipykernel)",
164
+ "language": "python",
165
+ "name": "python3"
166
+ },
167
+ "language_info": {
168
+ "codemirror_mode": {
169
+ "name": "ipython",
170
+ "version": 3
171
+ },
172
+ "file_extension": ".py",
173
+ "mimetype": "text/x-python",
174
+ "name": "python",
175
+ "nbconvert_exporter": "python",
176
+ "pygments_lexer": "ipython3",
177
+ "version": "3.9.7"
178
+ }
179
+ },
180
+ "nbformat": 4,
181
+ "nbformat_minor": 5
182
+ }
longstream/utils/vendor/croco/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
longstream/utils/vendor/croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Generation of crops from the real datasets
2
+
3
+ The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4
+
5
+ ### Download the metadata of the crops to generate
6
+
7
+ First, download the metadata and put them in `./data/`:
8
+ ```
9
+ mkdir -p data
10
+ cd data/
11
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12
+ unzip crop_metadata.zip
13
+ rm crop_metadata.zip
14
+ cd ..
15
+ ```
16
+
17
+ ### Prepare the original datasets
18
+
19
+ Second, download the original datasets in `./data/original_datasets/`.
20
+ ```
21
+ mkdir -p data/original_datasets
22
+ ```
23
+
24
+ ##### ARKitScenes
25
+
26
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27
+ The resulting file structure should be like:
28
+ ```
29
+ ./data/original_datasets/ARKitScenes/
30
+ └───Training
31
+ └───40753679
32
+ │ │ ultrawide
33
+ │ │ ...
34
+ └───40753686
35
+
36
+ ...
37
+ ```
38
+
39
+ ##### MegaDepth
40
+
41
+ Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42
+ The resulting file structure should be like:
43
+
44
+ ```
45
+ ./data/original_datasets/MegaDepth/
46
+ └───0000
47
+ │ └───images
48
+ │ │ │ 1000557903_87fa96b8a4_o.jpg
49
+ │ │ └ ...
50
+ │ └─── ...
51
+ └───0001
52
+ │ │
53
+ │ └ ...
54
+ └─── ...
55
+ ```
56
+
57
+ ##### 3DStreetView
58
+
59
+ Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60
+ The resulting file structure should be like:
61
+
62
+ ```
63
+ ./data/original_datasets/3DStreetView/
64
+ └───dataset_aligned
65
+ │ └───0002
66
+ │ │ │ 0000002_0000001_0000002_0000001.jpg
67
+ │ │ └ ...
68
+ │ └─── ...
69
+ └───dataset_unaligned
70
+ │ └───0003
71
+ │ │ │ 0000003_0000001_0000002_0000001.jpg
72
+ │ │ └ ...
73
+ │ └─── ...
74
+ ```
75
+
76
+ ##### IndoorVL
77
+
78
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79
+
80
+ ```
81
+ pip install kapture
82
+ mkdir -p ./data/original_datasets/IndoorVL
83
+ cd ./data/original_datasets/IndoorVL
84
+ kapture_download_dataset.py update
85
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86
+ kapture_download_dataset.py install "GangnamStation_*"
87
+ cd -
88
+ ```
89
+
90
+ ### Extract the crops
91
+
92
+ Now, extract the crops for each of the dataset:
93
+ ```
94
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95
+ do
96
+ python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97
+ done
98
+ ```
99
+
100
+ ##### Note for IndoorVL
101
+
102
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103
+ To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
104
+ The impact on the performance is negligible.
longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import math
4
+ import os
5
+ from multiprocessing import Pool
6
+
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+
10
+
11
def arg_parser():
    """Build the command-line parser for crop-pair extraction."""
    parser = argparse.ArgumentParser(
        "Generate cropped image pairs from image crop list"
    )

    # Required string arguments: (flag, help text).
    for flag, help_text in (
        ("--crops", "crop file"),
        ("--root-dir", "root directory"),
        ("--output-dir", "output directory"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument("--imsize", type=int, default=256, help="size of the crops")
    parser.add_argument(
        "--nthread", type=int, required=True, help="number of simultaneous threads"
    )
    parser.add_argument(
        "--max-subdir-levels",
        type=int,
        default=5,
        help="maximum number of subdirectories",
    )
    parser.add_argument(
        "--ideal-number-pairs-in-dir",
        type=int,
        default=500,
        help="number of pairs stored in a dir",
    )
    return parser
38
+
39
+
40
def main(args):
    """Generate all cropped image pairs described by the crop file.

    Loads the crop metadata, lays out a balanced hex-named directory tree,
    then extracts the crops (optionally in parallel) and records every
    written pair in ``<output_dir>/listing.txt``.

    Args:
        args: Parsed namespace from :func:`arg_parser`.
    """
    listing_path = os.path.join(args.output_dir, "listing.txt")

    print(f"Loading list of crops ... ({args.nthread} threads)")
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
    # Choose directory depth so each directory holds roughly
    # ideal_number_pairs_in_dir entries, capped at max_subdir_levels.
    num_levels = min(
        math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)),
        args.max_subdir_levels,
    )
    num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops  # free the parsed metadata before spawning workers

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)

    # Fix: the worker pool was previously created inline and never
    # closed/joined, leaking worker processes. Manage its lifetime here.
    pool = Pool(args.nthread) if args.nthread > 1 else None
    mmap = pool.imap_unordered if pool is not None else map

    print(f"Generating cropped images to {args.output_dir} ...")
    try:
        with open(listing_path, "w") as listing:
            listing.write("# pair_path\n")
            for results in tqdm(mmap(call, jobs), total=len(jobs)):
                for path in results:
                    listing.write(f"{path}\n")
    finally:
        if pool is not None:
            pool.close()
            pool.join()
    print("Finished writing listing to", listing_path)
67
+
68
+
69
def load_crop_file(path):
    """Parse a crop-metadata file.

    The file interleaves pair-header lines (``img1, img2, rotation``) with
    crop lines of 8 integers (``l1, r1, t1, b1, l2, r2, t2, b2``); lines
    starting with ``#`` are comments.

    Args:
        path: Path to the crop list file.

    Returns:
        ``(pairs, num_crops)`` where pairs is a list of
        ``(img1, img2, rotation, [(rect1, rect2), ...])`` tuples and
        num_crops is the total number of crop rectangles across all pairs.
    """
    # Fix: close the file deterministically (the original leaked the handle
    # returned by open()).
    with open(path) as f:
        data = f.read().splitlines()
    pairs = []
    num_crops_to_generate = 0
    for line in tqdm(data):
        if line.startswith("#"):
            continue
        line = line.split(", ")
        if len(line) < 8:
            # Header line: start a new image pair with an empty crop list.
            img1, img2, rotation = line
            pairs.append((img1, img2, int(rotation), []))
        else:
            # Crop line: rects are stored as (left, top, right, bottom).
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate
86
+
87
+
88
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Turn parsed crop pairs into per-pair jobs with output subpaths.

    Args:
        pairs: Output of load_crop_file.
        num_levels: Depth of the output directory tree.
        num_pairs_in_dir: Branching factor per directory level.

    Returns:
        List of ``((img1, img2), rotation, crops, paths)`` job tuples, one
        per image pair, where paths holds one hex subpath per crop.
    """
    jobs = []
    # powers[level] is the subtree size at each directory level (largest first).
    powers = [num_pairs_in_dir ** level for level in reversed(range(num_levels))]

    def get_path(idx):
        # Build a hex path like "a/3/1f". Intermediate components are the
        # mixed-radix digits of idx; the final component is the full index
        # itself (saved in d before idx is reduced), which keeps leaf names
        # globally unique.
        idx_array = []
        d = idx
        for level in range(num_levels - 1):
            idx_array.append(idx // powers[level])
            idx = idx % powers[level]
        idx_array.append(d)
        return "/".join(map(lambda x: hex(x)[2:], idx_array))

    idx = 0
    for pair_data in tqdm(pairs):
        img1, img2, rotation, crops = pair_data
        # Small rotations are treated as upright.
        if -60 <= rotation and rotation <= 60:
            rotation = 0
        paths = [get_path(idx + k) for k in range(len(crops))]
        idx += len(crops)
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs
110
+
111
+
112
def load_image(path):
    """Open *path* as an RGB PIL image.

    Raises:
        OSError: when the image cannot be opened or decoded. Callers
            (save_image_crops) catch OSError to skip the whole pair.
    """
    try:
        return Image.open(path).convert("RGB")
    except Exception as e:
        print("skipping", path, e)
        # Fix: chain the original exception (the bare `raise OSError()`
        # discarded the root cause and carried no message).
        raise OSError(f"cannot load image: {path}") from e
118
+
119
+
120
def save_image_crops(args, data):
    """Write all crop pairs for one image pair to disk.

    Args:
        args: Parsed CLI namespace (uses root_dir, output_dir, imsize).
        data: ``((img1, img2), rotation, crops, paths)`` job tuple as
            produced by prepare_jobs.

    Returns:
        List of relative output paths written, or an empty list when
        either source image failed to load.
    """
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [
            load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
        ]
    except OSError as e:
        # load_image already printed the reason; skip the whole pair.
        return []

    def area(sz):
        # (width, height) -> pixel count.
        return sz[0] * sz[1]

    tgt_size = (args.imsize, args.imsize)

    def prepare_crop(img, rect, rot=0):
        # Crop, resample to the target size, then apply a coarse rotation.
        img = img.crop(rect)

        # LANCZOS for strong downscaling, BICUBIC otherwise.
        interp = (
            Image.Resampling.LANCZOS
            if area(img.size) > 4 * area(tgt_size)
            else Image.Resampling.BICUBIC
        )
        img = img.resize(tgt_size, resample=interp)

        # Snap the rotation to the nearest multiple of 90 degrees.
        rot90 = (round(rot / 90) % 4) * 90
        if rot90 == 90:
            img = img.transpose(Image.Transpose.ROTATE_90)
        elif rot90 == 180:
            img = img.transpose(Image.Transpose.ROTATE_180)
        elif rot90 == 270:
            img = img.transpose(Image.Transpose.ROTATE_270)
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        crop1 = prepare_crop(img1, rect1)
        # Only the second view carries the pair's rotation.
        crop2 = prepare_crop(img2, rect2, rot)

        fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
        fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Refuse to overwrite existing crops (would indicate a path collision).
        assert not os.path.isfile(fullpath1), fullpath1
        assert not os.path.isfile(fullpath2), fullpath2
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)

    return results
171
+
172
+
173
+ if __name__ == "__main__":
174
+ args = arg_parser().parse_args()
175
+ main(args)
longstream/utils/vendor/croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Generation of synthetic image pairs using Habitat-Sim
2
+
3
+ These instructions allow to generate pre-training pairs from the Habitat simulator.
4
+ As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent.
5
+
6
+ ### Download Habitat-Sim scenes
7
+ Download Habitat-Sim scenes:
8
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
9
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
10
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
11
+ ```
12
+ ./data/
13
+ └──habitat-sim-data/
14
+ └──scene_datasets/
15
+ ├──hm3d/
16
+ ├──gibson/
17
+ ├──habitat-test-scenes/
18
+ ├──replica_cad_baked_lighting/
19
+ ├──replica_cad/
20
+ ├──ReplicaDataset/
21
+ └──scannet/
22
+ ```
23
+
24
+ ### Image pairs generation
25
+ We provide metadata to generate reproducible image pairs for pretraining and validation.
26
+ Experiments described in the paper used similar data, but whose generation was not reproducible at the time.
27
+
28
+ Specifications:
29
+ - 256x256 resolution images, with 60 degrees field of view.
30
+ - Up to 1000 image pairs per scene.
31
+ - Number of scenes considered/number of image pairs per dataset:
32
+ - Scannet: 1097 scenes / 985 209 pairs
33
+ - HM3D:
34
+ - hm3d/train: 800 / 800k pairs
35
+ - hm3d/val: 100 scenes / 100k pairs
36
+ - hm3d/minival: 10 scenes / 10k pairs
37
+ - habitat-test-scenes: 3 scenes / 3k pairs
38
+ - replica_cad_baked_lighting: 13 scenes / 13k pairs
39
+
40
+ - Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes.
41
+
42
+ Download metadata and extract it:
43
+ ```bash
44
+ mkdir -p data/habitat_release_metadata/
45
+ cd data/habitat_release_metadata/
46
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
47
+ tar -xvf multiview_habitat_metadata.tar.gz
48
+ cd ../..
49
+ # Location of the metadata
50
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
51
+ ```
52
+
53
+ Generate image pairs from metadata:
54
+ - The following command will print a list of commandlines to generate image pairs for each scene:
55
+ ```bash
56
+ # Target output directory
57
+ PAIRS_DATASET_DIR="./data/habitat_release/"
58
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
59
+ ```
60
+ - One can launch multiple such commands in parallel, e.g. using GNU Parallel:
61
+ ```bash
62
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
63
+ ```
64
+
65
+ ## Metadata generation
66
+
67
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
68
+ ```bash
69
+ # Print commandlines to generate image pairs from the different scenes available.
70
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
71
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
72
+
73
+ # Once a dataset is generated, pack metadata files for reproducibility.
74
+ METADATA_DIR=MY_CUSTOM_PATH
75
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
76
+ ```
longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
3
+ """
4
+ import argparse
5
+ import json
6
+ import os
7
+
8
+ import cv2
9
+ import PIL.Image
10
+ import quaternion
11
+ from datasets.habitat_sim.multiview_habitat_sim_generator import (
12
+ MultiviewHabitatSimGenerator,
13
+ )
14
+ from datasets.habitat_sim.paths import SCENES_DATASET
15
+ from tqdm import tqdm
16
+
17
+
18
def generate_multiview_images_from_metadata(
    metadata_filename,
    output_dir,
    overload_params=None,
    scene_datasets_paths=None,
    exist_ok=False,
):
    """Re-render the image pairs of one scene from a saved metadata file.

    Args:
        metadata_filename: JSON metadata file describing generator settings
            and the positions/orientations of every viewpoint.
        output_dir: directory where images, optional depth maps, camera
            parameter files and a copy of the metadata are written.
        overload_params: optional dict of metadata keys to override before
            generation (default: no overrides).
        scene_datasets_paths: optional mapping of dataset-label prefixes to
            local dataset roots, used to relocate scene paths.
        exist_ok: forwarded to os.makedirs for output_dir.
    """
    # Avoid a mutable default argument; None means "no overrides".
    if overload_params is None:
        overload_params = {}

    if scene_datasets_paths is not None:
        # Sort labels by decreasing length so the longest (most specific)
        # prefix wins when several labels could match a path.
        scene_datasets_paths = dict(
            sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True)
        )

    with open(metadata_filename, "r") as f:
        input_metadata = json.load(f)

    # Copy the metadata, remapping scene-related paths onto local datasets.
    metadata = {}
    for key, value in input_metadata.items():
        if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
            if scene_datasets_paths is not None:
                for dataset_label, dataset_path in scene_datasets_paths.items():
                    if value.startswith(dataset_label):
                        value = os.path.normpath(
                            os.path.join(
                                dataset_path, os.path.relpath(value, dataset_label)
                            )
                        )
                        break
        metadata[key] = value

    for key, value in overload_params.items():
        metadata[key] = value

    # These keys are consumed here, not by the generator constructor.
    generation_entries = {
        key: value
        for key, value in metadata.items()
        if key not in ("multiviews", "output_dir", "generate_depth")
    }
    generate_depth = metadata["generate_depth"]

    os.makedirs(output_dir, exist_ok=exist_ok)

    generator = MultiviewHabitatSimGenerator(**generation_entries)
    try:
        for idx_label, data in tqdm(metadata["multiviews"].items()):
            positions = data["positions"]
            orientations = data["orientations"]
            for oidx in range(len(positions)):
                observation = generator.render_viewpoint(
                    positions[oidx], quaternion.from_float_array(orientations[oidx])
                )
                # Viewpoints are numbered from 1 in the output filenames.
                observation_label = f"{oidx + 1}"

                img = PIL.Image.fromarray(observation["color"][:, :, :3])
                filename = os.path.join(
                    output_dir, f"{idx_label}_{observation_label}.jpeg"
                )
                img.save(filename)
                if generate_depth:
                    # Depth is stored as half-float EXR to save space.
                    filename = os.path.join(
                        output_dir, f"{idx_label}_{observation_label}_depth.exr"
                    )
                    cv2.imwrite(
                        filename,
                        observation["depth"],
                        [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
                    )

                camera_params = {
                    key: observation[key].tolist()
                    for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")
                }
                filename = os.path.join(
                    output_dir, f"{idx_label}_{observation_label}_camera_params.json"
                )
                with open(filename, "w") as f:
                    json.dump(camera_params, f)

        # Written last, so its presence marks the scene as fully generated.
        with open(os.path.join(output_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)
    finally:
        # Release the simulator even if rendering fails partway through.
        generator.close()
107
+
108
+
109
if __name__ == "__main__":
    # CLI entry point: re-render one scene's pairs from its metadata file.
    cli = argparse.ArgumentParser()
    cli.add_argument("--metadata_filename", required=True)
    cli.add_argument("--output_dir", required=True)
    opts = cli.parse_args()

    generate_multiview_images_from_metadata(
        metadata_filename=opts.metadata_filename,
        output_dir=opts.output_dir,
        scene_datasets_paths=SCENES_DATASET,
        overload_params=dict(),
        exist_ok=True,
    )
longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script generating commandlines to generate image pairs from metadata files.
3
+ """
4
+ import argparse
5
+ import glob
6
+ import os
7
+
8
+ from tqdm import tqdm
9
+
10
if __name__ == "__main__":
    # Print one command line per scene whose pairs still need generating;
    # the output can be piped to `parallel` for concurrent generation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--prefix",
        default="",
        # Fixed typo in user-facing help text ("Commanline").
        help="Commandline prefix, useful e.g. to setup environment.",
    )
    args = parser.parse_args()

    # Lazily iterate every metadata.json anywhere under input_dir.
    input_metadata_filenames = glob.iglob(
        f"{args.input_dir}/**/metadata.json", recursive=True
    )

    for metadata_filename in tqdm(input_metadata_filenames):
        # Mirror the input directory layout under output_dir.
        output_dir = os.path.join(
            args.output_dir,
            os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
        )
        # metadata.json is written last by the generator, so its presence
        # means the scene is already fully generated -- skip it.
        if os.path.exists(os.path.join(output_dir, "metadata.json")):
            continue
        commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
        print(commandline)