Spaces:
Running on Zero
Running on Zero
Cc committed on
Commit ·
e340a84
1
Parent(s): 4d8122b
init
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +35 -9
- app.py +5 -0
- configs/longstream_infer.yaml +84 -0
- demo_gradio.py +332 -0
- longstream/.DS_Store +0 -0
- longstream/__init__.py +1 -0
- longstream/core/__init__.py +0 -0
- longstream/core/cli.py +213 -0
- longstream/core/infer.py +451 -0
- longstream/core/model.py +69 -0
- longstream/data/__init__.py +3 -0
- longstream/data/dataloader.py +422 -0
- longstream/demo/__init__.py +11 -0
- longstream/demo/backend.py +495 -0
- longstream/demo/common.py +84 -0
- longstream/demo/export.py +85 -0
- longstream/demo/geometry.py +211 -0
- longstream/demo/viewer.py +134 -0
- longstream/eval/__init__.py +3 -0
- longstream/eval/evaluate.py +551 -0
- longstream/eval/io.py +156 -0
- longstream/eval/metrics.py +116 -0
- longstream/io/__init__.py +0 -0
- longstream/io/save_images.py +38 -0
- longstream/io/save_points.py +71 -0
- longstream/io/save_poses_txt.py +43 -0
- longstream/models/__init__.py +3 -0
- longstream/models/longstream.py +370 -0
- longstream/streaming/__init__.py +0 -0
- longstream/streaming/keyframe_selector.py +80 -0
- longstream/streaming/refresh.py +217 -0
- longstream/streaming/stream_session.py +294 -0
- longstream/utils/__init__.py +0 -0
- longstream/utils/camera.py +50 -0
- longstream/utils/depth.py +36 -0
- longstream/utils/hub.py +42 -0
- longstream/utils/sky_mask.py +100 -0
- longstream/utils/vendor/__init__.py +2 -0
- longstream/utils/vendor/croco/LICENSE +52 -0
- longstream/utils/vendor/croco/NOTICE +21 -0
- longstream/utils/vendor/croco/README.MD +124 -0
- longstream/utils/vendor/croco/assets/arch.jpg +0 -0
- longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb +182 -0
- longstream/utils/vendor/croco/datasets/__init__.py +2 -0
- longstream/utils/vendor/croco/datasets/crops/README.MD +104 -0
- longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py +175 -0
- longstream/utils/vendor/croco/datasets/habitat_sim/README.MD +76 -0
- longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py +2 -0
- longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py +121 -0
- longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py +34 -0
README.md
CHANGED
|
@@ -1,14 +1,40 @@
|
|
| 1 |
---
|
| 2 |
-
title: LongStream
|
| 3 |
-
emoji: 📊
|
| 4 |
-
colorFrom: red
|
| 5 |
-
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
short_description: Demo of LongStream
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: LongStream Demo
|
|
|
|
|
|
|
|
|
|
| 3 |
sdk: gradio
|
| 4 |
+
sdk_version: 5.44.0
|
| 5 |
app_file: app.py
|
| 6 |
+
python_version: "3.10"
|
| 7 |
+
startup_duration_timeout: 1h
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# LongStream Demo
|
| 11 |
+
|
| 12 |
+
This repository is the Hugging Face Space package for LongStream.
|
| 13 |
+
|
| 14 |
+
Project page: `https://3dagentworld.github.io/longstream/`
|
| 15 |
+
|
| 16 |
+
## Space Settings
|
| 17 |
+
|
| 18 |
+
Set these variables in the Space settings before the first run:
|
| 19 |
+
|
| 20 |
+
- `LONGSTREAM_HF_REPO=NicolasCC/LongStream`
|
| 21 |
+
- `LONGSTREAM_HF_FILE=50_longstream.pt`
|
| 22 |
+
- `LONGSTREAM_HF_LOCAL_DIR=checkpoints`
|
| 23 |
+
|
| 24 |
+
Optional:
|
| 25 |
+
|
| 26 |
+
- `LONGSTREAM_HF_REVISION=v0.1.0`
|
| 27 |
+
- `HF_TOKEN=<token>` if the model repo is private
|
| 28 |
+
|
| 29 |
+
## Entrypoints
|
| 30 |
+
|
| 31 |
+
- `app.py`: stable demo
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Included Files
|
| 35 |
+
|
| 36 |
+
- `demo_gradio.py`
|
| 37 |
+
- `demo_gradio_interactive.py`
|
| 38 |
+
- `longstream/`
|
| 39 |
+
- `configs/longstream_infer.yaml`
|
| 40 |
+
|
app.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Space entrypoint.

Delegates to ``demo_gradio.main``, which builds and launches the stable
Gradio demo.
"""
from demo_gradio import main


if __name__ == "__main__":
    main()
|
configs/longstream_infer.yaml
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default inference configuration for LongStream.
# NOTE(review): reconstructed from a diff rendering; nesting levels of a few
# keys (e.g. reinit_camera_head) should be verified against the original file.
device: cuda

model:
  # Local checkpoint path; may be populated from the HF Hub (see `hf` below).
  checkpoint: checkpoints/50_longstream.pt
  strict_load: false
  # Optional Hugging Face Hub source for downloading the checkpoint.
  hf:
    repo_id: null
    filename: null
    revision: null
    local_dir: checkpoints
  longstream_cfg:
    img_size: 518
    patch_size: 14
    embed_dim: 1024
    window_size: 48
    use_role_embedding: false
    enable_scale_token: true
    disable_keyframe_distinction: true
    use_segment_mask: false
    enable_camera_head: false
    freeze: none
    use_rel_pose_head: true
    rel_pose_head_cfg:
      enabled: true
      keyframe_mode: fixed
      keyframe_stride: 8
      reference_source: pred
      detach_reference: false
      trunk_depth: 4
      pose_mode: SE3
      num_heads: 16
      mlp_ratio: 4
      init_values: 0.01
      trans_act: linear
      quat_act: linear
      use_pair_cross_attn: false
      xattn_temperature: 1.0
      use_precat: false
      use_kf_role_embed: false
      kf_role_embed_init_std: 0.02
      fl_act: relu
      use_global_scale: false
    reinit_camera_head: false

inference:
  mode: batch_refresh
  streaming_mode: causal
  # Should match model.longstream_cfg.window_size.
  window_size: 48
  keyframe_mode: fixed
  keyframe_stride: 8
  refresh: 4
  rel_pose_head_cfg:
    num_iterations: 4

data:
  format: generalizable
  data_roots_file: data_roots.txt
  camera: null
  img_path: "path/to/your/image/directory"
  stride: 1
  max_frames: null
  size: 518
  crop: false
  patch_size: 14

output:
  root: outputs
  save_videos: true
  save_points: true
  save_frame_points: true
  save_depth: true
  save_images: true
  mask_sky: true
  max_full_pointcloud_points: 2000000
  max_frame_pointcloud_points: 200000
  # ONNX sky-segmentation model used when mask_sky is enabled.
  skyseg_path: skyseg.onnx

evaluation:
  align_scale: true
  depth_rel_delta_threshold: 1.25
  point_f1_threshold: 0.25
  point_eval_max_points: 100000
  point_eval_voxel_size: null
  point_eval_oversample_factor: 4
demo_gradio.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from longstream.demo import BRANCH_OPTIONS, create_demo_session, load_metadata
|
| 6 |
+
from longstream.demo.backend import load_frame_previews
|
| 7 |
+
from longstream.demo.export import export_glb
|
| 8 |
+
from longstream.demo.viewer import build_interactive_figure
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Defaults for the stable demo; these mirror configs/longstream_infer.yaml.
DEFAULT_KEYFRAME_STRIDE = 8
DEFAULT_REFRESH = 3
DEFAULT_WINDOW_SIZE = 48
# Local checkpoint path, overridable via the LONGSTREAM_CHECKPOINT env var.
DEFAULT_CHECKPOINT = os.environ.get(
    "LONGSTREAM_CHECKPOINT", "checkpoints/50_longstream.pt"
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _run_stable_demo(
    image_dir,
    uploaded_files,
    uploaded_video,
    checkpoint,
    device,
    mode,
    streaming_mode,
    refresh,
    window_size,
    compute_sky,
    branch_label,
    show_cameras,
    mask_sky,
    camera_scale,
    point_size,
    opacity,
    preview_max_points,
    glb_max_points,
):
    """Run a full reconstruction and build every output for the stable demo.

    Returns a tuple (figure, glb path, session dir, rgb preview, depth
    preview, frame label, slider update, status message) matching the
    ``run_btn.click`` outputs wiring in ``main``.
    """
    # At least one input source is required.
    if not (image_dir or uploaded_files or uploaded_video):
        raise gr.Error("Provide an image folder, upload images, or upload a video.")

    session_dir = create_demo_session(
        image_dir=image_dir or "",
        uploaded_files=uploaded_files,
        uploaded_video=uploaded_video,
        checkpoint=checkpoint,
        device=device,
        mode=mode,
        streaming_mode=streaming_mode,
        keyframe_stride=DEFAULT_KEYFRAME_STRIDE,
        refresh=int(refresh),
        window_size=int(window_size),
        compute_sky=bool(compute_sky),
    )

    # Arguments shared by the interactive preview and the GLB export.
    scene_kwargs = dict(
        session_dir=session_dir,
        branch=branch_label,
        display_mode="All Frames",
        frame_index=0,
        show_cameras=bool(show_cameras),
        camera_scale=float(camera_scale),
        mask_sky=bool(mask_sky),
    )
    fig = build_interactive_figure(
        point_size=float(point_size),
        opacity=float(opacity),
        preview_max_points=int(preview_max_points),
        **scene_kwargs,
    )
    glb_path = export_glb(max_points=int(glb_max_points), **scene_kwargs)

    rgb, depth, frame_label = load_frame_previews(session_dir, 0)
    meta = load_metadata(session_dir)
    slider = gr.update(
        minimum=0,
        maximum=max(meta["num_frames"] - 1, 0),
        value=0,
        step=1,
        interactive=meta["num_frames"] > 1,
    )
    sky_msg = ""
    if meta.get("has_sky_masks"):
        removed_pct = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
        sky_msg = f" | sky_removed={removed_pct:.1f}%"
    status = f"Ready: {meta['num_frames']} frames | branch={branch_label}{sky_msg}"
    return fig, glb_path, session_dir, rgb, depth, frame_label, slider, status
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _update_stable_scene(
|
| 101 |
+
session_dir,
|
| 102 |
+
branch_label,
|
| 103 |
+
show_cameras,
|
| 104 |
+
mask_sky,
|
| 105 |
+
camera_scale,
|
| 106 |
+
point_size,
|
| 107 |
+
opacity,
|
| 108 |
+
preview_max_points,
|
| 109 |
+
glb_max_points,
|
| 110 |
+
):
|
| 111 |
+
if not session_dir or not os.path.isdir(session_dir):
|
| 112 |
+
return None, None, "Run reconstruction first."
|
| 113 |
+
fig = build_interactive_figure(
|
| 114 |
+
session_dir=session_dir,
|
| 115 |
+
branch=branch_label,
|
| 116 |
+
display_mode="All Frames",
|
| 117 |
+
frame_index=0,
|
| 118 |
+
point_size=float(point_size),
|
| 119 |
+
opacity=float(opacity),
|
| 120 |
+
preview_max_points=int(preview_max_points),
|
| 121 |
+
show_cameras=bool(show_cameras),
|
| 122 |
+
camera_scale=float(camera_scale),
|
| 123 |
+
mask_sky=bool(mask_sky),
|
| 124 |
+
)
|
| 125 |
+
glb_path = export_glb(
|
| 126 |
+
session_dir=session_dir,
|
| 127 |
+
branch=branch_label,
|
| 128 |
+
display_mode="All Frames",
|
| 129 |
+
frame_index=0,
|
| 130 |
+
mask_sky=bool(mask_sky),
|
| 131 |
+
show_cameras=bool(show_cameras),
|
| 132 |
+
camera_scale=float(camera_scale),
|
| 133 |
+
max_points=int(glb_max_points),
|
| 134 |
+
)
|
| 135 |
+
meta = load_metadata(session_dir)
|
| 136 |
+
sky_msg = ""
|
| 137 |
+
if meta.get("has_sky_masks"):
|
| 138 |
+
removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0
|
| 139 |
+
sky_msg = f" | sky_removed={removed:.1f}%"
|
| 140 |
+
return fig, glb_path, f"Updated preview: {branch_label}{sky_msg}"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _update_frame_preview(session_dir, frame_index):
|
| 144 |
+
if not session_dir or not os.path.isdir(session_dir):
|
| 145 |
+
return None, None, ""
|
| 146 |
+
rgb, depth, label = load_frame_previews(session_dir, int(frame_index))
|
| 147 |
+
return rgb, depth, label
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def main():
    """Build and launch the stable LongStream Gradio demo."""
    with gr.Blocks(title="LongStream Demo") as demo:
        # Hidden state: filesystem path of the active reconstruction session.
        session_dir = gr.Textbox(visible=False)

        gr.Markdown("# LongStream Demo")

        # Input sources: a server-side folder, uploaded images, or a video.
        with gr.Row():
            image_dir = gr.Textbox(
                label="Image Folder", placeholder="/path/to/sequence"
            )
            uploaded_files = gr.File(
                label="Upload Images", file_count="multiple", file_types=["image"]
            )
            uploaded_video = gr.File(
                label="Upload Video", file_count="single", file_types=["video"]
            )

        with gr.Row():
            checkpoint = gr.Textbox(label="Checkpoint", value=DEFAULT_CHECKPOINT)
            device = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cuda")

        with gr.Accordion("Inference", open=False):
            with gr.Row():
                mode = gr.Dropdown(
                    label="Mode",
                    choices=["streaming_refresh", "batch_refresh"],
                    value="batch_refresh",
                )
                streaming_mode = gr.Dropdown(
                    label="Streaming Mode", choices=["causal", "window"], value="causal"
                )
            with gr.Row():
                refresh = gr.Slider(
                    label="Refresh", minimum=2, maximum=9, step=1, value=DEFAULT_REFRESH
                )
                window_size = gr.Slider(
                    label="Window Size",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=DEFAULT_WINDOW_SIZE,
                )
            compute_sky = gr.Checkbox(label="Compute Sky Masks", value=True)

        with gr.Accordion("GLB Settings", open=True):
            with gr.Row():
                branch_label = gr.Dropdown(
                    label="Point Cloud Branch",
                    choices=BRANCH_OPTIONS,
                    value="Point Head + Pose",
                )
                show_cameras = gr.Checkbox(label="Show Cameras", value=True)
                mask_sky = gr.Checkbox(label="Mask Sky", value=True)
            with gr.Row():
                point_size = gr.Slider(
                    label="Point Size",
                    minimum=0.05,
                    maximum=2.0,
                    step=0.05,
                    value=0.3,
                )
                opacity = gr.Slider(
                    label="Opacity",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.75,
                )
                preview_max_points = gr.Slider(
                    label="Preview Max Points",
                    minimum=5000,
                    maximum=1000000,
                    step=10000,
                    value=100000,
                )
            with gr.Row():
                camera_scale = gr.Slider(
                    label="Camera Scale",
                    minimum=0.001,
                    maximum=0.05,
                    step=0.001,
                    value=0.01,
                )
                glb_max_points = gr.Slider(
                    label="GLB Max Points",
                    minimum=20000,
                    maximum=1000000,
                    step=10000,
                    value=400000,
                )

        run_btn = gr.Button("Run Stable Demo", variant="primary")
        status = gr.Markdown("Provide input images, then run reconstruction.")

        plot = gr.Plot(label="Scene Preview")

        glb_file = gr.File(label="Download GLB")

        with gr.Row():
            frame_slider = gr.Slider(
                label="Preview Frame",
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                interactive=False,
            )
            frame_label = gr.Textbox(label="Frame")
        with gr.Row():
            rgb_preview = gr.Image(label="RGB", type="numpy")
            depth_preview = gr.Image(label="Depth Plasma", type="numpy")

        # Viewer-only controls, shared between the run wiring and live updates.
        viewer_controls = [
            branch_label,
            show_cameras,
            mask_sky,
            camera_scale,
            point_size,
            opacity,
            preview_max_points,
            glb_max_points,
        ]

        run_btn.click(
            _run_stable_demo,
            inputs=[
                image_dir,
                uploaded_files,
                uploaded_video,
                checkpoint,
                device,
                mode,
                streaming_mode,
                refresh,
                window_size,
                compute_sky,
            ]
            + viewer_controls,
            outputs=[
                plot,
                glb_file,
                session_dir,
                rgb_preview,
                depth_preview,
                frame_label,
                frame_slider,
                status,
            ],
        )

        # Changing any viewer control re-renders the preview/GLB without
        # re-running inference.
        for component in viewer_controls:
            component.change(
                _update_stable_scene,
                inputs=[session_dir] + viewer_controls,
                outputs=[plot, glb_file, status],
            )

        frame_slider.change(
            _update_frame_preview,
            inputs=[session_dir, frame_slider],
            outputs=[rgb_preview, depth_preview, frame_label],
        )

    demo.launch()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == "__main__":
    # Direct execution: python demo_gradio.py
    main()
|
longstream/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
longstream/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package marker; no names are re-exported — import submodules explicitly.
__all__ = []
|
longstream/core/__init__.py
ADDED
|
File without changes
|
longstream/core/cli.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def default_config_path() -> str:
    """Return the path to the bundled default inference config.

    Resolves three levels up from this module (the repository root) and
    joins ``configs/longstream_infer.yaml``.
    """
    package_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    return os.path.join(package_root, "configs", "longstream_infer.yaml")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def add_runtime_arguments(parser):
    """Register all LongStream runtime CLI options on *parser*.

    Options default to ``None`` (or ``False`` for flags) so that only
    explicitly passed values override the YAML config in
    ``load_config_with_overrides``. Returns the same parser for chaining.
    """
    parser.add_argument(
        "--config",
        default=default_config_path(),
        help="Path to longstream config yaml.",
    )
    parser.add_argument(
        "--dataset",
        default=None,
        help="Optional dataset hint. Generic format works without it.",
    )
    parser.add_argument("--img-path", default=None)
    parser.add_argument(
        "--seq-list",
        default=None,
        help="Comma-separated sequence names. Default: auto-detect all sequences.",
    )
    parser.add_argument("--format", default=None, help="generalizable")
    # Plain string overrides with no help text.
    for option in (
        "--data-roots-file",
        "--camera",
        "--output-root",
        "--device",
        "--checkpoint",
        "--hf-repo",
        "--hf-file",
    ):
        parser.add_argument(option, default=None)
    parser.add_argument(
        "--mode", default=None, help="batch_refresh | streaming_refresh"
    )
    parser.add_argument("--streaming-mode", default=None, help="causal | window")
    parser.add_argument("--window-size", type=int, default=None)
    parser.add_argument("--keyframe-stride", type=int, default=None)
    parser.add_argument(
        "--refresh",
        type=int,
        default=None,
        help=(
            "Number of keyframes per refresh span, inclusive of both ends and "
            "including the segment start keyframe."
        ),
    )
    # Hidden legacy alias; load_config_with_overrides maps it to refresh = value + 1.
    parser.add_argument(
        "--keyframes-per-batch",
        dest="keyframes_per_batch_legacy",
        type=int,
        default=None,
        help=argparse.SUPPRESS,
    )
    parser.add_argument("--max-frames", type=int, default=None)
    parser.add_argument("--depth-rel-delta-threshold", type=float, default=None)
    parser.add_argument("--point-f1-threshold", type=float, default=None)
    parser.add_argument("--eval-max-points", type=int, default=None)
    parser.add_argument("--eval-voxel-size", type=float, default=None)
    parser.add_argument("--max-full-pointcloud-points", type=int, default=None)
    parser.add_argument("--max-frame-pointcloud-points", type=int, default=None)
    # Boolean on/off flags; the "no-" variants win in load_config_with_overrides.
    for flag in (
        "--save-frame-points",
        "--no-save-frame-points",
        "--no-align-scale",
        "--mask-sky",
        "--no-mask-sky",
    ):
        parser.add_argument(flag, action="store_true")
    return parser
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def parse_runtime_args(parser):
    """Parse CLI arguments, dropping blank/whitespace-only argv tokens.

    Empty tokens (presumably injected by some launch environments — verify
    against the deployment) would otherwise be treated as positionals.
    """
    cleaned = [token for token in sys.argv[1:] if token.strip()]
    return parser.parse_args(cleaned)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_config_with_overrides(args):
    """Load the YAML config at ``args.config`` and apply CLI overrides.

    Only options the user actually passed (non-``None`` values, or ``True``
    for flags) override the file's values. Returns the merged config dict.

    Fix: removed the dead no-op branch that reassigned
    ``cfg["model"]["checkpoint"] = None`` when it was already ``None``.
    """

    def _section(name):
        # Fetch (creating if needed) a top-level mapping in the config.
        return cfg.setdefault(name, {})

    with open(args.config, "r") as f:
        cfg = yaml.safe_load(f) or {}
    cfg.setdefault("model", {})

    if args.device is not None:
        cfg["device"] = args.device
    if args.output_root is not None:
        _section("output")["root"] = args.output_root

    # Data-source overrides.
    if args.dataset is not None:
        _section("data")["dataset"] = args.dataset
    if args.img_path is not None:
        _section("data")["img_path"] = args.img_path
    if args.seq_list is not None:
        _section("data")["seq_list"] = [
            s.strip() for s in args.seq_list.split(",") if s.strip()
        ]
    if args.format is not None:
        _section("data")["format"] = args.format
    if args.data_roots_file is not None:
        _section("data")["data_roots_file"] = args.data_roots_file
    if args.camera is not None:
        _section("data")["camera"] = args.camera
    if args.max_frames is not None:
        _section("data")["max_frames"] = args.max_frames

    # Checkpoint source: local path and/or Hugging Face Hub coordinates.
    if args.checkpoint is not None:
        cfg["model"]["checkpoint"] = args.checkpoint
    if args.hf_repo is not None or args.hf_file is not None:
        hf_cfg = cfg["model"].setdefault("hf", {})
        if args.hf_repo is not None:
            hf_cfg["repo_id"] = args.hf_repo
        if args.hf_file is not None:
            hf_cfg["filename"] = args.hf_file

    # Inference overrides. window_size/keyframe_stride are mirrored into the
    # model config so the architecture matches the inference settings.
    if args.mode is not None:
        _section("inference")["mode"] = args.mode
    if args.streaming_mode is not None:
        _section("inference")["streaming_mode"] = args.streaming_mode
    if args.window_size is not None:
        _section("inference")["window_size"] = args.window_size
        cfg["model"].setdefault("longstream_cfg", {})["window_size"] = args.window_size
    if args.keyframe_stride is not None:
        _section("inference")["keyframe_stride"] = args.keyframe_stride
        longstream_cfg = cfg["model"].setdefault("longstream_cfg", {})
        longstream_cfg.setdefault("rel_pose_head_cfg", {})[
            "keyframe_stride"
        ] = args.keyframe_stride

    # --refresh, with the deprecated --keyframes-per-batch as fallback
    # (the legacy count excluded the segment-start keyframe, hence the +1).
    refresh = args.refresh
    if refresh is None and args.keyframes_per_batch_legacy is not None:
        refresh = args.keyframes_per_batch_legacy + 1
    if refresh is not None:
        _section("inference")["refresh"] = refresh

    # Evaluation overrides.
    if args.depth_rel_delta_threshold is not None:
        _section("evaluation")["depth_rel_delta_threshold"] = (
            args.depth_rel_delta_threshold
        )
    if args.point_f1_threshold is not None:
        _section("evaluation")["point_f1_threshold"] = args.point_f1_threshold
    if args.eval_max_points is not None:
        _section("evaluation")["point_eval_max_points"] = args.eval_max_points
    if args.eval_voxel_size is not None:
        _section("evaluation")["point_eval_voxel_size"] = args.eval_voxel_size
    if args.no_align_scale:
        _section("evaluation")["align_scale"] = False

    # Output overrides; "no-" flags win when both variants are passed.
    if args.max_full_pointcloud_points is not None:
        _section("output")["max_full_pointcloud_points"] = (
            args.max_full_pointcloud_points
        )
    if args.max_frame_pointcloud_points is not None:
        _section("output")["max_frame_pointcloud_points"] = (
            args.max_frame_pointcloud_points
        )
    if args.save_frame_points:
        _section("output")["save_frame_points"] = True
    if args.no_save_frame_points:
        _section("output")["save_frame_points"] = False
    if args.mask_sky:
        _section("output")["mask_sky"] = True
    if args.no_mask_sky:
        _section("output")["mask_sky"] = False

    # Back-compat: older configs may use keyframes_per_batch instead of refresh.
    infer_cfg = _section("inference")
    if "refresh" not in infer_cfg and "keyframes_per_batch" in infer_cfg:
        infer_cfg["refresh"] = int(infer_cfg["keyframes_per_batch"]) + 1

    # This build only supports the generic "generalizable" data format; the
    # forced override intentionally clobbers --format and the file's value.
    cfg.setdefault("data", {})
    cfg["data"]["format"] = "generalizable"
    return cfg
|
longstream/core/infer.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from longstream.core.model import LongStreamModel
|
| 10 |
+
from longstream.data.dataloader import LongStreamDataLoader
|
| 11 |
+
from longstream.streaming.keyframe_selector import KeyframeSelector
|
| 12 |
+
from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh
|
| 13 |
+
from longstream.utils.vendor.models.components.utils.pose_enc import (
|
| 14 |
+
pose_encoding_to_extri_intri,
|
| 15 |
+
)
|
| 16 |
+
from longstream.utils.camera import compose_abs_from_rel
|
| 17 |
+
from longstream.utils.depth import colorize_depth, unproject_depth_to_points
|
| 18 |
+
from longstream.utils.sky_mask import compute_sky_mask
|
| 19 |
+
from longstream.io.save_points import save_pointcloud
|
| 20 |
+
from longstream.io.save_poses_txt import save_w2c_txt, save_intri_txt, save_rel_pose_txt
|
| 21 |
+
from longstream.io.save_images import save_image_sequence, save_video
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _to_uint8_rgb(images):
|
| 25 |
+
imgs = images.detach().cpu().numpy()
|
| 26 |
+
imgs = np.clip(imgs, 0.0, 1.0)
|
| 27 |
+
imgs = (imgs * 255.0).astype(np.uint8)
|
| 28 |
+
return imgs
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _ensure_dir(path):
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _apply_sky_mask(depth, mask):
|
| 36 |
+
if mask is None:
|
| 37 |
+
return depth
|
| 38 |
+
m = (mask > 0).astype(np.float32)
|
| 39 |
+
return depth * m
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _camera_points_to_world(points, extri):
|
| 43 |
+
pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
|
| 44 |
+
R = np.asarray(extri[:3, :3], dtype=np.float64)
|
| 45 |
+
t = np.asarray(extri[:3, 3], dtype=np.float64)
|
| 46 |
+
world = (R.T @ (pts.T - t[:, None])).T
|
| 47 |
+
return world.astype(np.float32, copy=False)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _mask_points_and_colors(points, colors, mask):
|
| 51 |
+
pts = points.reshape(-1, 3)
|
| 52 |
+
cols = None if colors is None else colors.reshape(-1, 3)
|
| 53 |
+
if mask is None:
|
| 54 |
+
return pts, cols
|
| 55 |
+
valid = mask.reshape(-1) > 0
|
| 56 |
+
pts = pts[valid]
|
| 57 |
+
if cols is not None:
|
| 58 |
+
cols = cols[valid]
|
| 59 |
+
return pts, cols
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _resize_long_edge(arr, long_edge_size, interpolation):
    """Resize *arr* so its longer side equals *long_edge_size*, preserving aspect ratio."""
    h, w = arr.shape[:2]
    factor = float(long_edge_size) / float(max(h, w))
    new_size = (int(round(w * factor)), int(round(h * factor)))
    return cv2.resize(arr, new_size, interpolation=interpolation)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _prepare_mask_for_model(
    mask, size, crop, patch_size, target_shape, square_ok=False
):
    """Resize/crop a raw mask so it matches the model's input resolution.

    Mirrors the image preprocessing: resize by long edge, then either
    center-crop or squash to a patch-aligned size, and finally force the
    result to *target_shape* (H, W) with nearest-neighbor interpolation so
    mask values stay binary.
    """
    if mask is None:
        return None
    # For the 224 pipeline the long edge is scaled by the aspect ratio so the
    # SHORT edge lands on `size`; otherwise the long edge is `size` itself.
    long_edge = (
        round(size * max(mask.shape[1] / mask.shape[0], mask.shape[0] / mask.shape[1]))
        if size == 224
        else size
    )
    mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)

    h, w = mask.shape[:2]
    cx, cy = w // 2, h // 2
    if size == 224:
        # Square output: crop (or squash) around the center.
        half = min(cx, cy)
        target_w = 2 * half
        target_h = 2 * half
        if crop:
            mask = mask[cy - half : cy + half, cx - half : cx + half]
        else:
            mask = cv2.resize(
                mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
            )
    else:
        # Patch-aligned output: each half-extent is a multiple of patch_size/2.
        halfw = ((2 * cx) // patch_size) * (patch_size // 2)
        halfh = ((2 * cy) // patch_size) * (patch_size // 2)
        if not square_ok and w == h:
            # Square inputs are forced to a 4:3 aspect unless explicitly allowed.
            halfh = int(3 * halfw / 4)
        target_w = 2 * halfw
        target_h = 2 * halfh
        if crop:
            mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
        else:
            mask = cv2.resize(
                mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
            )

    # Safety net: guarantee the exact (H, W) the model produced.
    if mask.shape[:2] != tuple(target_shape):
        mask = cv2.resize(
            mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
        )
    return mask
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _save_full_pointcloud(path, point_chunks, color_chunks, max_points=None, seed=0):
    """Concatenate per-frame point chunks and write one full point cloud.

    Writes both a raw ``.npy`` next to *path* and the point cloud itself via
    ``save_pointcloud``. If *max_points* is given, points (and their colors)
    are subsampled uniformly without replacement using *seed*.
    """
    if not point_chunks:
        # Nothing accumulated (e.g. everything was sky-masked) — skip output.
        return
    points = np.concatenate(point_chunks, axis=0)
    colors = None
    # Colors are only usable when every point chunk has a matching color chunk.
    if color_chunks and len(color_chunks) == len(point_chunks):
        colors = np.concatenate(color_chunks, axis=0)
    if max_points is not None and len(points) > max_points:
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(points), size=max_points, replace=False)
        points = points[keep]
        if colors is not None:
            colors = colors[keep]
    np.save(os.path.splitext(path)[0] + ".npy", points.astype(np.float32, copy=False))
    # Subsampling already happened above, so disable it inside save_pointcloud.
    save_pointcloud(path, points, colors=colors, max_points=None, seed=seed)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def run_inference_cfg(cfg: dict):
    """Run LongStream inference for every sequence described by *cfg*.

    The config dict has sections ``model``, ``data``, ``inference`` and
    ``output``. For each sequence loaded by LongStreamDataLoader this:
    selects keyframes, runs batch/streaming refresh inference, then writes
    poses, RGB frames, depth maps and point clouds under ``output.root``.
    """
    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    device_type = torch.device(device).type
    model_cfg = cfg.get("model", {})
    data_cfg = cfg.get("data", {})
    infer_cfg = cfg.get("inference", {})
    output_cfg = cfg.get("output", {})

    print(f"[longstream] device={device}", flush=True)
    model = LongStreamModel(model_cfg).to(device)
    model.eval()
    print("[longstream] model ready", flush=True)

    loader = LongStreamDataLoader(data_cfg)

    keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
    keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
    # `refresh` counts both keyframe endpoints; default derives from
    # keyframes_per_batch (same rule as the CLI override path).
    refresh = int(
        infer_cfg.get("refresh", int(infer_cfg.get("keyframes_per_batch", 3)) + 1)
    )
    if refresh < 2:
        raise ValueError(
            "refresh must be >= 2 because it counts both keyframe endpoints"
        )
    mode = infer_cfg.get("mode", "streaming_refresh")
    if mode == "streaming":
        # Legacy alias accepted for backward compatibility.
        mode = "streaming_refresh"
    streaming_mode = infer_cfg.get("streaming_mode", "causal")
    window_size = int(infer_cfg.get("window_size", 5))

    selector = KeyframeSelector(
        min_interval=keyframe_stride,
        max_interval=keyframe_stride,
        force_first=True,
        mode="random" if keyframe_mode == "random" else "fixed",
    )

    # Output switches; everything defaults to "save it all".
    out_root = output_cfg.get("root", "outputs")
    _ensure_dir(out_root)
    save_videos = bool(output_cfg.get("save_videos", True))
    save_points = bool(output_cfg.get("save_points", True))
    save_frame_points = bool(output_cfg.get("save_frame_points", True))
    save_depth = bool(output_cfg.get("save_depth", True))
    save_images = bool(output_cfg.get("save_images", True))
    mask_sky = bool(output_cfg.get("mask_sky", True))
    max_full_pointcloud_points = output_cfg.get("max_full_pointcloud_points", None)
    if max_full_pointcloud_points is not None:
        max_full_pointcloud_points = int(max_full_pointcloud_points)
    max_frame_pointcloud_points = output_cfg.get("max_frame_pointcloud_points", None)
    if max_frame_pointcloud_points is not None:
        max_frame_pointcloud_points = int(max_frame_pointcloud_points)
    # Default sky-segmentation ONNX model lives at the repository root.
    skyseg_path = output_cfg.get(
        "skyseg_path",
        os.path.join(os.path.dirname(__file__), "..", "..", "skyseg.onnx"),
    )

    with torch.no_grad():
        for seq in loader:
            images = seq.images
            B, S, C, H, W = images.shape
            print(
                f"[longstream] sequence {seq.name}: inference start ({S} frames)",
                flush=True,
            )

            is_keyframe, keyframe_indices = selector.select_keyframes(
                S, B, images.device
            )

            rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})

            if mode == "batch_refresh":
                outputs = run_batch_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    keyframe_stride,
                    refresh,
                    rel_pose_cfg,
                )
            elif mode == "streaming_refresh":
                outputs = run_streaming_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    window_size,
                    refresh,
                    rel_pose_cfg,
                )
            else:
                raise ValueError(f"Unsupported inference mode: {mode}")
            print(f"[longstream] sequence {seq.name}: inference done", flush=True)
            if device_type == "cuda":
                # Free inference scratch memory before the export stage.
                torch.cuda.empty_cache()

            seq_dir = os.path.join(out_root, seq.name)
            _ensure_dir(seq_dir)

            frame_ids = list(range(S))
            # (S, C, H, W) -> (S, H, W, C) uint8 for image/point-color export.
            rgb = _to_uint8_rgb(images[0].permute(0, 2, 3, 1))

            # --- Pose export: prefer relative poses composed into absolute ones.
            if "rel_pose_enc" in outputs:
                rel_pose_enc = outputs["rel_pose_enc"][0]
                abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
                extri, intri = pose_encoding_to_extri_intri(
                    abs_pose_enc[None], image_size_hw=(H, W)
                )
                extri_np = extri[0].detach().cpu().numpy()
                intri_np = intri[0].detach().cpu().numpy()

                pose_dir = os.path.join(seq_dir, "poses")
                _ensure_dir(pose_dir)
                save_w2c_txt(
                    os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
                )
                save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)
                save_rel_pose_txt(
                    os.path.join(pose_dir, "rel_pose.txt"), rel_pose_enc, frame_ids
                )
            elif "pose_enc" in outputs:
                pose_enc = outputs["pose_enc"][0]
                extri, intri = pose_encoding_to_extri_intri(
                    pose_enc[None], image_size_hw=(H, W)
                )
                extri_np = extri[0].detach().cpu().numpy()
                intri_np = intri[0].detach().cpu().numpy()

                pose_dir = os.path.join(seq_dir, "poses")
                _ensure_dir(pose_dir)
                save_w2c_txt(
                    os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
                )
                save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)

            # --- RGB frame / video export.
            if save_images:
                print(f"[longstream] sequence {seq.name}: saving rgb", flush=True)
                rgb_dir = os.path.join(seq_dir, "images", "rgb")
                save_image_sequence(rgb_dir, list(rgb))
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "images", "rgb.mp4"),
                        os.path.join(rgb_dir, "frame_*.png"),
                    )

            # --- Optional sky segmentation, resized to match model resolution.
            sky_masks = None
            if mask_sky:
                raw_sky_masks = compute_sky_mask(
                    seq.image_paths, skyseg_path, os.path.join(seq_dir, "sky_masks")
                )
                if raw_sky_masks is not None:
                    sky_masks = [
                        _prepare_mask_for_model(
                            mask,
                            size=int(data_cfg.get("size", 518)),
                            crop=bool(data_cfg.get("crop", False)),
                            patch_size=int(data_cfg.get("patch_size", 14)),
                            target_shape=(H, W),
                        )
                        for mask in raw_sky_masks
                    ]

            # --- Depth export (raw .npy plus plasma-colored previews).
            if save_depth and "depth" in outputs:
                print(f"[longstream] sequence {seq.name}: saving depth", flush=True)
                depth = outputs["depth"][0, :, :, :, 0].detach().cpu().numpy()
                depth_dir = os.path.join(seq_dir, "depth", "dpt")
                _ensure_dir(depth_dir)
                color_dir = os.path.join(seq_dir, "depth", "dpt_plasma")
                _ensure_dir(color_dir)

                color_frames = []
                for i in range(S):
                    d = depth[i]
                    if sky_masks is not None and sky_masks[i] is not None:
                        d = _apply_sky_mask(d, sky_masks[i])
                    np.save(os.path.join(depth_dir, f"frame_{i:06d}.npy"), d)
                    colored = colorize_depth(d, cmap="plasma")
                    Image.fromarray(colored).save(
                        os.path.join(color_dir, f"frame_{i:06d}.png")
                    )
                    color_frames.append(colored)
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "depth", "dpt_plasma.mp4"),
                        os.path.join(color_dir, "frame_*.png"),
                    )

            # --- Point cloud export (two variants).
            if save_points:
                print(
                    f"[longstream] sequence {seq.name}: saving point clouds", flush=True
                )
                # Variant 1: points predicted directly by the point head.
                # NOTE(review): despite the "world_points" name these are
                # transformed camera->world below — presumably the head emits
                # camera-frame points; confirm against the model.
                if "world_points" in outputs:
                    if "rel_pose_enc" in outputs:
                        abs_pose_enc = compose_abs_from_rel(
                            outputs["rel_pose_enc"][0], keyframe_indices[0]
                        )
                        extri, intri = pose_encoding_to_extri_intri(
                            abs_pose_enc[None], image_size_hw=(H, W)
                        )
                    else:
                        extri, intri = pose_encoding_to_extri_intri(
                            outputs["pose_enc"][0][None], image_size_hw=(H, W)
                        )
                    extri = extri[0]
                    intri = intri[0]

                    pts_dir = os.path.join(seq_dir, "points", "point_head")
                    _ensure_dir(pts_dir)
                    pts = outputs["world_points"][0].detach().cpu().numpy()
                    full_pts = []
                    full_cols = []
                    for i in range(S):
                        pts_world = _camera_points_to_world(
                            pts[i], extri[i].detach().cpu().numpy()
                        )
                        pts_world = pts_world.reshape(pts[i].shape)
                        pts_i, cols_i = _mask_points_and_colors(
                            pts_world,
                            rgb[i],
                            None if sky_masks is None else sky_masks[i],
                        )
                        if save_frame_points:
                            save_pointcloud(
                                os.path.join(pts_dir, f"frame_{i:06d}.ply"),
                                pts_i,
                                colors=cols_i,
                                max_points=max_frame_pointcloud_points,
                                seed=i,
                            )
                        if len(pts_i):
                            full_pts.append(pts_i)
                            full_cols.append(cols_i)
                    _save_full_pointcloud(
                        os.path.join(seq_dir, "points", "point_head_full.ply"),
                        full_pts,
                        full_cols,
                        max_points=max_full_pointcloud_points,
                        seed=0,
                    )

                # Variant 2: depth maps unprojected through the intrinsics,
                # then moved to world frame with the estimated extrinsics.
                if "depth" in outputs and (
                    "rel_pose_enc" in outputs or "pose_enc" in outputs
                ):
                    depth = outputs["depth"][0, :, :, :, 0]
                    if "rel_pose_enc" in outputs:
                        abs_pose_enc = compose_abs_from_rel(
                            outputs["rel_pose_enc"][0], keyframe_indices[0]
                        )
                        extri, intri = pose_encoding_to_extri_intri(
                            abs_pose_enc[None], image_size_hw=(H, W)
                        )
                    else:
                        extri, intri = pose_encoding_to_extri_intri(
                            outputs["pose_enc"][0][None], image_size_hw=(H, W)
                        )

                    extri = extri[0]
                    intri = intri[0]
                    dpt_pts_dir = os.path.join(seq_dir, "points", "dpt_unproj")
                    _ensure_dir(dpt_pts_dir)
                    full_pts = []
                    full_cols = []

                    for i in range(S):
                        d = depth[i]
                        pts_cam = unproject_depth_to_points(d[None], intri[i : i + 1])[
                            0
                        ]
                        R = extri[i, :3, :3]
                        t = extri[i, :3, 3]
                        # Invert w2c extrinsic: x_world = R^T (x_cam - t).
                        pts_world = (
                            R.t() @ (pts_cam.reshape(-1, 3).t() - t[:, None])
                        ).t()
                        pts_world = pts_world.cpu().numpy().reshape(-1, 3)
                        pts_i, cols_i = _mask_points_and_colors(
                            pts_world,
                            rgb[i],
                            None if sky_masks is None else sky_masks[i],
                        )
                        if save_frame_points:
                            save_pointcloud(
                                os.path.join(dpt_pts_dir, f"frame_{i:06d}.ply"),
                                pts_i,
                                colors=cols_i,
                                max_points=max_frame_pointcloud_points,
                                seed=i,
                            )
                        if len(pts_i):
                            full_pts.append(pts_i)
                            full_cols.append(cols_i)
                    _save_full_pointcloud(
                        os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
                        full_pts,
                        full_cols,
                        max_points=max_full_pointcloud_points,
                        seed=1,
                    )
            # Release per-sequence tensors before the next iteration.
            del outputs
            if device_type == "cuda":
                torch.cuda.empty_cache()
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def run_inference(config_path: str):
    """Load a YAML config from *config_path* and run inference with it."""
    with open(config_path, "r") as handle:
        parsed = yaml.safe_load(handle)
    run_inference_cfg(parsed)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def main():
    """CLI entry point: parse ``--config`` and run inference."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True)
    opts = cli.parse_args()
    run_inference(opts.config)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# Allow this module to be executed directly as a script.
if __name__ == "__main__":
    main()
|
longstream/core/model.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
|
| 5 |
+
from longstream.models.longstream import LongStream
|
| 6 |
+
from longstream.utils.hub import resolve_checkpoint_path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LongStreamModel(torch.nn.Module):
    """Thin wrapper around LongStream that handles config and checkpoint loading."""

    def __init__(self, cfg: Dict[str, Any] | None):
        """Build the inner LongStream model from *cfg* and optionally load weights.

        Recognized keys: ``checkpoint``/``hf`` (weight source),
        ``longstream_cfg`` (kwargs for LongStream), ``rel_pose_head_cfg``,
        ``use_rel_pose_head`` and ``strict_load``.
        """
        super().__init__()
        cfg = cfg or {}

        ckpt_path = resolve_checkpoint_path(
            cfg.get("checkpoint", None), cfg.get("hf", None)
        )

        # Copy so the pops below don't mutate the caller's config dict.
        stream_cfg = dict(cfg.get("longstream_cfg", {}) or {})
        # rel_pose_head_cfg may live inside longstream_cfg or at the top level.
        rel_pose_cfg = stream_cfg.pop(
            "rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None)
        )
        use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False))
        # Only forward the rel-pose head config when the head is enabled.
        if use_rel_pose_head and rel_pose_cfg is not None:
            stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg
        self.longstream = LongStream(**stream_cfg)

        if ckpt_path:
            self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True)))

    def load_checkpoint(self, ckpt_path: str, strict: bool = True):
        """Load weights from *ckpt_path*, tolerating common checkpoint layouts.

        Accepts raw state dicts as well as dicts nested under ``model`` or
        ``state_dict``, and strips a leading ``sampler.`` prefix from training
        checkpoints. With ``strict=True`` any key mismatch raises; otherwise
        the mismatch is only printed.
        """
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(ckpt_path)
        # NOTE(review): weights_only=False uses full pickle loading — only
        # load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict):
            if "model" in ckpt and isinstance(ckpt["model"], dict):
                state = ckpt["model"]
            elif "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
                state = ckpt["state_dict"]
            else:
                state = ckpt
        else:
            raise TypeError("Unsupported checkpoint format")

        if state:
            # Training checkpoints may prefix keys with "sampler."; detect via
            # the first key and strip the prefix from every entry.
            first_key = next(iter(state.keys()))
            if first_key.startswith("sampler.longstream."):
                state = {k.replace("sampler.", "", 1): v for k, v in state.items()}

        # Always load non-strictly, then decide from the mismatch lists so the
        # error message can report counts for both directions.
        missing, unexpected = self.load_state_dict(state, strict=False)
        if missing or unexpected:
            msg = f"checkpoint mismatch: missing={len(missing)} unexpected={len(unexpected)}"
            if strict:
                raise RuntimeError(msg)
            print(msg)

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped LongStream model."""
        return self.longstream(*args, **kwargs)

    @property
    def aggregator(self):
        # Required submodule; attribute access raises if absent.
        return self.longstream.aggregator

    @property
    def camera_head(self):
        # Optional head: None when the wrapped model has no camera head.
        return getattr(self.longstream, "camera_head", None)

    @property
    def rel_pose_head(self):
        # Optional head: None when the wrapped model has no rel-pose head.
        return getattr(self.longstream, "rel_pose_head", None)
|
longstream/data/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dataloader import LongStreamDataLoader, LongStreamSequence, LongStreamSequenceInfo
|
| 2 |
+
|
| 3 |
+
__all__ = ["LongStreamDataLoader", "LongStreamSequence", "LongStreamSequenceInfo"]
|
longstream/data/dataloader.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Dict, Any, Iterator, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval
|
| 9 |
+
|
| 10 |
+
# Per-dataset layout table. Each entry describes where images, masks and
# ground-truth trajectories live on disk and how to derive per-sequence paths:
#   dir_path_func(img_path, seq)            -> image directory for a sequence
#   gt_traj_func(img_path, anno_path, seq)  -> GT trajectory path (or None)
#   mask_path_seq_func(mask_path, seq)      -> mask directory (or None)
#   traj_format                             -> trajectory file convention
#   seq_list / full_seq                     -> which sequences / frame coverage
dataset_metadata: Dict[str, Dict[str, Any]] = {
    "davis": {
        "img_path": "data/davis/DAVIS/JPEGImages/480p",
        "mask_path": "data/davis/DAVIS/masked_images/480p",
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        # DAVIS has no camera ground truth.
        "gt_traj_func": lambda img_path, anno_path, seq: None,
        "traj_format": None,
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq),
        "skip_condition": None,
        "process_func": None,
    },
    "kitti": {
        "img_path": "data/kitti/sequences",
        "anno_path": "data/kitti/poses",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "image_2"),
        # Pose file may be absent for some sequences; fall back to None.
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            anno_path, f"{seq}.txt"
        )
        if os.path.exists(os.path.join(anno_path, f"{seq}.txt"))
        else None,
        "traj_format": "kitti",
        "seq_list": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"],
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "bonn": {
        "img_path": "data/bonn/rgbd_bonn_dataset",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "rgb_110"
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt"
        ),
        "traj_format": "tum",
        "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    # NYU is a flat image directory; sequence/trajectory helpers don't apply.
    "nyu": {
        "img_path": "data/nyu-v2/val/nyu_images",
        "mask_path": None,
        "process_func": None,
    },
    "scannet": {
        "img_path": "data/scannetv2",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "pose_90.txt"
        ),
        "traj_format": "replica",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "tum": {
        "img_path": "data/tum",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path, seq, "groundtruth_90.txt"
        ),
        "traj_format": "tum",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "sintel": {
        "img_path": "data/sintel/training/final",
        "anno_path": "data/sintel/training/camdata_left",
        "mask_path": None,
        "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq),
        "traj_format": None,
        "seq_list": [
            "alley_2",
            "ambush_4",
            "ambush_5",
            "ambush_6",
            "cave_2",
            "cave_4",
            "market_2",
            "market_5",
            "market_6",
            "shaman_3",
            "sleeping_1",
            "sleeping_2",
            "temple_2",
            "temple_3",
        ],
        "full_seq": False,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
    "waymo": {
        # NOTE(review): hard-coded absolute cluster path — adjust per deployment.
        "img_path": "/horizon-bucket/saturn_v_4dlabel/004_vision/01_users/tao02.xie/datasets/scatt3r_evaluation/waymo_open_dataset_v1_4_3",
        "anno_path": None,
        "mask_path": None,
        # Sequence names may carry a "_cam<idx>" suffix selecting the camera.
        "dir_path_func": lambda img_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "images",
            seq.split("_cam")[1] if "_cam" in seq else "00",
        ),
        "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(
            img_path,
            seq.split("_cam")[0] if "_cam" in seq else seq,
            "cameras",
            seq.split("_cam")[1] if "_cam" in seq else "00",
            "extri.yml",
        ),
        "traj_format": "waymo",
        "seq_list": None,
        "full_seq": True,
        "mask_path_seq_func": lambda mask_path, seq: None,
        "skip_condition": None,
        "process_func": None,
    },
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@dataclass
class LongStreamSequenceInfo:
    """Lightweight on-disk metadata for one image sequence (no pixel data)."""

    # Sequence identifier (typically the scene/sequence directory name).
    name: str
    # Root directory of the scene on disk.
    scene_root: str
    # Directory holding the sequence's image files.
    image_dir: str
    # Ordered image file paths for the sequence.
    image_paths: List[str]
    # Camera identifier, when the dataset distinguishes cameras; else None.
    camera: Optional[str]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LongStreamSequence:
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
name: str,
|
| 157 |
+
images: torch.Tensor,
|
| 158 |
+
image_paths: List[str],
|
| 159 |
+
scene_root: Optional[str] = None,
|
| 160 |
+
image_dir: Optional[str] = None,
|
| 161 |
+
camera: Optional[str] = None,
|
| 162 |
+
):
|
| 163 |
+
self.name = name
|
| 164 |
+
self.images = images
|
| 165 |
+
self.image_paths = image_paths
|
| 166 |
+
self.scene_root = scene_root
|
| 167 |
+
self.image_dir = image_dir
|
| 168 |
+
self.camera = camera
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _read_list_file(path: str) -> List[str]:
|
| 172 |
+
with open(path, "r") as f:
|
| 173 |
+
lines = []
|
| 174 |
+
for line in f.readlines():
|
| 175 |
+
line = line.strip()
|
| 176 |
+
if not line:
|
| 177 |
+
continue
|
| 178 |
+
if line.startswith("#"):
|
| 179 |
+
continue
|
| 180 |
+
lines.append(line)
|
| 181 |
+
return lines
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _is_generalizable_scene_root(path: str) -> bool:
|
| 185 |
+
return os.path.isdir(os.path.join(path, "images"))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _direct_image_files(dir_path: str) -> List[str]:
|
| 189 |
+
filelist = sorted(glob.glob(os.path.join(dir_path, "*.png")))
|
| 190 |
+
if not filelist:
|
| 191 |
+
filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpg")))
|
| 192 |
+
if not filelist:
|
| 193 |
+
filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpeg")))
|
| 194 |
+
return filelist
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class LongStreamDataLoader:
|
| 198 |
+
def __init__(self, cfg: Dict[str, Any]):
|
| 199 |
+
self.cfg = cfg
|
| 200 |
+
self.dataset = cfg.get("dataset", None)
|
| 201 |
+
meta = dataset_metadata.get(self.dataset, {})
|
| 202 |
+
self.img_path = cfg.get("img_path", meta.get("img_path"))
|
| 203 |
+
self.mask_path = cfg.get("mask_path", meta.get("mask_path"))
|
| 204 |
+
self.dir_path_func = meta.get("dir_path_func", lambda p, s: os.path.join(p, s))
|
| 205 |
+
self.mask_path_seq_func = meta.get("mask_path_seq_func", lambda p, s: None)
|
| 206 |
+
self.full_seq = bool(cfg.get("full_seq", meta.get("full_seq", True)))
|
| 207 |
+
self.seq_list = cfg.get("seq_list", None)
|
| 208 |
+
self.stride = int(cfg.get("stride", 1))
|
| 209 |
+
self.max_frames = cfg.get("max_frames", None)
|
| 210 |
+
self.size = int(cfg.get("size", 518))
|
| 211 |
+
self.crop = bool(cfg.get("crop", False))
|
| 212 |
+
self.patch_size = int(cfg.get("patch_size", 14))
|
| 213 |
+
self.format = cfg.get("format", "auto")
|
| 214 |
+
self.data_roots_file = cfg.get("data_roots_file", None)
|
| 215 |
+
self.split = cfg.get("split", None)
|
| 216 |
+
self.camera = cfg.get("camera", None)
|
| 217 |
+
|
| 218 |
+
def _infer_format(self) -> str:
|
| 219 |
+
if self.format in ["relpose", "generalizable"]:
|
| 220 |
+
return self.format
|
| 221 |
+
if self.img_path is None:
|
| 222 |
+
return "relpose"
|
| 223 |
+
if _is_generalizable_scene_root(self.img_path):
|
| 224 |
+
return "generalizable"
|
| 225 |
+
default_list = self.data_roots_file or "data_roots.txt"
|
| 226 |
+
if os.path.exists(os.path.join(self.img_path, default_list)):
|
| 227 |
+
return "generalizable"
|
| 228 |
+
return "relpose"
|
| 229 |
+
|
| 230 |
+
def _resolve_seq_list_generalizable(self) -> List[str]:
|
| 231 |
+
if self.seq_list is not None:
|
| 232 |
+
return list(self.seq_list)
|
| 233 |
+
if self.img_path is None or not os.path.isdir(self.img_path):
|
| 234 |
+
return []
|
| 235 |
+
|
| 236 |
+
if _is_generalizable_scene_root(self.img_path):
|
| 237 |
+
return [self.img_path]
|
| 238 |
+
|
| 239 |
+
candidates = []
|
| 240 |
+
if isinstance(self.data_roots_file, str) and self.data_roots_file:
|
| 241 |
+
candidates.append(self.data_roots_file)
|
| 242 |
+
if isinstance(self.split, str) and self.split:
|
| 243 |
+
split_name = self.split.lower()
|
| 244 |
+
if split_name in ["val", "valid", "validate"]:
|
| 245 |
+
split_name = "validate"
|
| 246 |
+
candidates.append(f"{split_name}_data_roots.txt")
|
| 247 |
+
candidates.append("data_roots.txt")
|
| 248 |
+
candidates.append("train_data_roots.txt")
|
| 249 |
+
candidates.append("validate_data_roots.txt")
|
| 250 |
+
|
| 251 |
+
for fname in candidates:
|
| 252 |
+
path = os.path.join(self.img_path, fname)
|
| 253 |
+
if os.path.exists(path):
|
| 254 |
+
return _read_list_file(path)
|
| 255 |
+
|
| 256 |
+
img_dirs = sorted(
|
| 257 |
+
glob.glob(os.path.join(self.img_path, "**", "images"), recursive=True)
|
| 258 |
+
)
|
| 259 |
+
scene_roots = [os.path.dirname(p) for p in img_dirs]
|
| 260 |
+
|
| 261 |
+
rels = []
|
| 262 |
+
for p in scene_roots:
|
| 263 |
+
try:
|
| 264 |
+
rels.append(os.path.relpath(p, self.img_path))
|
| 265 |
+
except ValueError:
|
| 266 |
+
rels.append(p)
|
| 267 |
+
return sorted(set(rels))
|
| 268 |
+
|
| 269 |
+
def _resolve_seq_list_relpose(self) -> List[str]:
|
| 270 |
+
if self.seq_list is not None:
|
| 271 |
+
return list(self.seq_list)
|
| 272 |
+
meta = dataset_metadata.get(self.dataset, {})
|
| 273 |
+
if self.full_seq:
|
| 274 |
+
if self.img_path is None or not os.path.isdir(self.img_path):
|
| 275 |
+
return []
|
| 276 |
+
seqs = [
|
| 277 |
+
s
|
| 278 |
+
for s in os.listdir(self.img_path)
|
| 279 |
+
if os.path.isdir(os.path.join(self.img_path, s))
|
| 280 |
+
]
|
| 281 |
+
return sorted(seqs)
|
| 282 |
+
seqs = meta.get("seq_list", []) or []
|
| 283 |
+
return list(seqs)
|
| 284 |
+
|
| 285 |
+
def _resolve_seq_list(self) -> List[str]:
|
| 286 |
+
fmt = self._infer_format()
|
| 287 |
+
if fmt == "generalizable":
|
| 288 |
+
return self._resolve_seq_list_generalizable()
|
| 289 |
+
return self._resolve_seq_list_relpose()
|
| 290 |
+
|
| 291 |
+
def _resolve_scene_root(self, seq_entry: str) -> Tuple[str, str]:
|
| 292 |
+
if os.path.isabs(seq_entry) or os.path.sep in seq_entry:
|
| 293 |
+
scene_root = seq_entry
|
| 294 |
+
name = os.path.basename(os.path.normpath(seq_entry))
|
| 295 |
+
else:
|
| 296 |
+
scene_root = os.path.join(self.img_path, seq_entry)
|
| 297 |
+
name = seq_entry
|
| 298 |
+
return name, scene_root
|
| 299 |
+
|
| 300 |
+
def _resolve_image_dir_generalizable(self, scene_root: str) -> Optional[str]:
|
| 301 |
+
images_root = os.path.join(scene_root, "images")
|
| 302 |
+
if not os.path.isdir(images_root):
|
| 303 |
+
return None
|
| 304 |
+
|
| 305 |
+
if isinstance(self.camera, str) and self.camera:
|
| 306 |
+
cam_dir = os.path.join(images_root, self.camera)
|
| 307 |
+
if os.path.isdir(cam_dir):
|
| 308 |
+
return cam_dir
|
| 309 |
+
|
| 310 |
+
if _direct_image_files(images_root):
|
| 311 |
+
return images_root
|
| 312 |
+
|
| 313 |
+
cams = [
|
| 314 |
+
d
|
| 315 |
+
for d in os.listdir(images_root)
|
| 316 |
+
if os.path.isdir(os.path.join(images_root, d))
|
| 317 |
+
]
|
| 318 |
+
if not cams:
|
| 319 |
+
return None
|
| 320 |
+
cams = sorted(cams)
|
| 321 |
+
|
| 322 |
+
frame_dirs = []
|
| 323 |
+
for name in cams:
|
| 324 |
+
child_dir = os.path.join(images_root, name)
|
| 325 |
+
child_images = _direct_image_files(child_dir)
|
| 326 |
+
if child_images:
|
| 327 |
+
frame_dirs.append((name, len(child_images)))
|
| 328 |
+
|
| 329 |
+
if (
|
| 330 |
+
len(cams) > 10
|
| 331 |
+
and len(frame_dirs) == len(cams)
|
| 332 |
+
and max(count for _, count in frame_dirs) == 1
|
| 333 |
+
):
|
| 334 |
+
return images_root
|
| 335 |
+
|
| 336 |
+
return os.path.join(images_root, cams[0])
|
| 337 |
+
|
| 338 |
+
def _camera_from_image_dir(self, image_dir: str) -> Optional[str]:
|
| 339 |
+
parent = os.path.basename(os.path.dirname(image_dir))
|
| 340 |
+
if parent != "images":
|
| 341 |
+
return None
|
| 342 |
+
return os.path.basename(image_dir)
|
| 343 |
+
|
| 344 |
+
def _collect_filelist(self, dir_path: str) -> List[str]:
|
| 345 |
+
filelist = _direct_image_files(dir_path)
|
| 346 |
+
if not filelist:
|
| 347 |
+
nested = []
|
| 348 |
+
child_dirs = sorted(
|
| 349 |
+
d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d)
|
| 350 |
+
)
|
| 351 |
+
for child_dir in child_dirs:
|
| 352 |
+
child_images = _direct_image_files(child_dir)
|
| 353 |
+
if child_images:
|
| 354 |
+
nested.append(child_images[0])
|
| 355 |
+
filelist = nested
|
| 356 |
+
if self.stride > 1:
|
| 357 |
+
filelist = filelist[:: self.stride]
|
| 358 |
+
if self.max_frames is not None:
|
| 359 |
+
filelist = filelist[: self.max_frames]
|
| 360 |
+
return filelist
|
| 361 |
+
|
| 362 |
+
def _load_images(self, filelist: List[str]) -> torch.Tensor:
|
| 363 |
+
views = load_images_for_eval(
|
| 364 |
+
filelist,
|
| 365 |
+
size=self.size,
|
| 366 |
+
verbose=False,
|
| 367 |
+
crop=self.crop,
|
| 368 |
+
patch_size=self.patch_size,
|
| 369 |
+
)
|
| 370 |
+
imgs = torch.cat([view["img"] for view in views], dim=0)
|
| 371 |
+
images = imgs.unsqueeze(0)
|
| 372 |
+
images = (images + 1.0) / 2.0
|
| 373 |
+
return images
|
| 374 |
+
|
| 375 |
+
def iter_sequence_infos(self) -> Iterator[LongStreamSequenceInfo]:
|
| 376 |
+
fmt = self._infer_format()
|
| 377 |
+
seqs = self._resolve_seq_list()
|
| 378 |
+
for seq_entry in seqs:
|
| 379 |
+
if fmt == "generalizable":
|
| 380 |
+
seq, scene_root = self._resolve_scene_root(seq_entry)
|
| 381 |
+
dir_path = self._resolve_image_dir_generalizable(scene_root)
|
| 382 |
+
if dir_path is None or not os.path.isdir(dir_path):
|
| 383 |
+
continue
|
| 384 |
+
camera = self._camera_from_image_dir(dir_path)
|
| 385 |
+
else:
|
| 386 |
+
seq = seq_entry
|
| 387 |
+
scene_root = os.path.join(self.img_path, seq)
|
| 388 |
+
dir_path = self.dir_path_func(self.img_path, seq)
|
| 389 |
+
if not os.path.isdir(dir_path):
|
| 390 |
+
continue
|
| 391 |
+
camera = None
|
| 392 |
+
|
| 393 |
+
filelist = self._collect_filelist(dir_path)
|
| 394 |
+
if not filelist:
|
| 395 |
+
continue
|
| 396 |
+
yield LongStreamSequenceInfo(
|
| 397 |
+
name=seq,
|
| 398 |
+
scene_root=scene_root,
|
| 399 |
+
image_dir=dir_path,
|
| 400 |
+
image_paths=filelist,
|
| 401 |
+
camera=camera,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
def __iter__(self) -> Iterator[LongStreamSequence]:
|
| 405 |
+
for info in self.iter_sequence_infos():
|
| 406 |
+
print(
|
| 407 |
+
f"[longstream] loading sequence {info.name}: {len(info.image_paths)} frames",
|
| 408 |
+
flush=True,
|
| 409 |
+
)
|
| 410 |
+
images = self._load_images(info.image_paths)
|
| 411 |
+
print(
|
| 412 |
+
f"[longstream] loaded sequence {info.name}: {tuple(images.shape)}",
|
| 413 |
+
flush=True,
|
| 414 |
+
)
|
| 415 |
+
yield LongStreamSequence(
|
| 416 |
+
info.name,
|
| 417 |
+
images,
|
| 418 |
+
info.image_paths,
|
| 419 |
+
scene_root=info.scene_root,
|
| 420 |
+
image_dir=info.image_dir,
|
| 421 |
+
camera=info.camera,
|
| 422 |
+
)
|
longstream/demo/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .backend import create_demo_session, load_frame_previews
|
| 2 |
+
from .common import BRANCH_OPTIONS, DISPLAY_MODE_OPTIONS, branch_key, load_metadata
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"BRANCH_OPTIONS",
|
| 6 |
+
"DISPLAY_MODE_OPTIONS",
|
| 7 |
+
"branch_key",
|
| 8 |
+
"create_demo_session",
|
| 9 |
+
"load_frame_previews",
|
| 10 |
+
"load_metadata",
|
| 11 |
+
]
|
longstream/demo/backend.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
import tempfile
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Iterable, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import yaml
|
| 13 |
+
|
| 14 |
+
from longstream.core.cli import default_config_path
|
| 15 |
+
from longstream.core.model import LongStreamModel
|
| 16 |
+
from longstream.streaming.keyframe_selector import KeyframeSelector
|
| 17 |
+
from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh
|
| 18 |
+
from longstream.utils.camera import compose_abs_from_rel
|
| 19 |
+
from longstream.utils.depth import colorize_depth
|
| 20 |
+
from longstream.utils.hub import resolve_checkpoint_path
|
| 21 |
+
from longstream.utils.sky_mask import compute_sky_mask
|
| 22 |
+
from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval
|
| 23 |
+
from longstream.utils.vendor.models.components.utils.pose_enc import (
|
| 24 |
+
pose_encoding_to_extri_intri,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
from .common import load_metadata, session_file
|
| 28 |
+
|
| 29 |
+
_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
| 30 |
+
_MODEL_CACHE = {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _resolve_file_path(item) -> str:
|
| 34 |
+
if item is None:
|
| 35 |
+
return ""
|
| 36 |
+
if isinstance(item, str):
|
| 37 |
+
return item
|
| 38 |
+
if isinstance(item, dict) and "name" in item:
|
| 39 |
+
return item["name"]
|
| 40 |
+
if hasattr(item, "name"):
|
| 41 |
+
return item.name
|
| 42 |
+
return str(item)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _natural_sort_key(path: str):
|
| 46 |
+
name = os.path.basename(path)
|
| 47 |
+
stem, _ = os.path.splitext(name)
|
| 48 |
+
parts = re.split(r"(\d+)", stem)
|
| 49 |
+
key = []
|
| 50 |
+
for part in parts:
|
| 51 |
+
if not part:
|
| 52 |
+
continue
|
| 53 |
+
if part.isdigit():
|
| 54 |
+
key.append((0, int(part)))
|
| 55 |
+
else:
|
| 56 |
+
key.append((1, part.lower()))
|
| 57 |
+
return key, name.lower()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _sorted_image_paths(image_dir: str) -> List[str]:
|
| 61 |
+
files = []
|
| 62 |
+
for name in os.listdir(image_dir):
|
| 63 |
+
if name.lower().endswith(_IMAGE_EXTS):
|
| 64 |
+
files.append(os.path.join(image_dir, name))
|
| 65 |
+
return sorted(files, key=_natural_sort_key)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _session_root() -> str:
|
| 69 |
+
root = os.path.join(tempfile.gettempdir(), "longstream_demo_sessions")
|
| 70 |
+
os.makedirs(root, exist_ok=True)
|
| 71 |
+
return root
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _new_session_dir() -> str:
|
| 75 |
+
stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S_%f")
|
| 76 |
+
return tempfile.mkdtemp(prefix=f"longstream_{stamp}_", dir=_session_root())
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _copy_uploaded_images(uploaded_files: Iterable, session_dir: str) -> List[str]:
|
| 80 |
+
input_dir = os.path.join(session_dir, "input_images")
|
| 81 |
+
os.makedirs(input_dir, exist_ok=True)
|
| 82 |
+
copied = []
|
| 83 |
+
sources = sorted(
|
| 84 |
+
(_resolve_file_path(x) for x in uploaded_files if x),
|
| 85 |
+
key=_natural_sort_key,
|
| 86 |
+
)
|
| 87 |
+
for src in sources:
|
| 88 |
+
if not src or not os.path.isfile(src):
|
| 89 |
+
continue
|
| 90 |
+
dst = os.path.join(input_dir, os.path.basename(src))
|
| 91 |
+
shutil.copy2(src, dst)
|
| 92 |
+
copied.append(dst)
|
| 93 |
+
return copied
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _extract_uploaded_video(uploaded_video, session_dir: str) -> List[str]:
|
| 97 |
+
src = _resolve_file_path(uploaded_video)
|
| 98 |
+
if not src:
|
| 99 |
+
return []
|
| 100 |
+
if not os.path.isfile(src):
|
| 101 |
+
raise FileNotFoundError(src)
|
| 102 |
+
|
| 103 |
+
input_dir = os.path.join(session_dir, "input_images")
|
| 104 |
+
os.makedirs(input_dir, exist_ok=True)
|
| 105 |
+
cap = cv2.VideoCapture(src)
|
| 106 |
+
if not cap.isOpened():
|
| 107 |
+
raise ValueError(f"unable to open video: {src}")
|
| 108 |
+
|
| 109 |
+
image_paths = []
|
| 110 |
+
frame_id = 0
|
| 111 |
+
while True:
|
| 112 |
+
ok, frame = cap.read()
|
| 113 |
+
if not ok:
|
| 114 |
+
break
|
| 115 |
+
dst = os.path.join(input_dir, f"{frame_id:06d}.png")
|
| 116 |
+
if not cv2.imwrite(dst, frame):
|
| 117 |
+
cap.release()
|
| 118 |
+
raise ValueError(f"failed to write extracted frame: {dst}")
|
| 119 |
+
image_paths.append(dst)
|
| 120 |
+
frame_id += 1
|
| 121 |
+
cap.release()
|
| 122 |
+
|
| 123 |
+
if not image_paths:
|
| 124 |
+
raise ValueError(f"no frames extracted from video: {src}")
|
| 125 |
+
return image_paths
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _resize_long_edge(arr, long_edge_size, interpolation):
|
| 129 |
+
h, w = arr.shape[:2]
|
| 130 |
+
scale = float(long_edge_size) / float(max(h, w))
|
| 131 |
+
new_w = int(round(w * scale))
|
| 132 |
+
new_h = int(round(h * scale))
|
| 133 |
+
return cv2.resize(arr, (new_w, new_h), interpolation=interpolation)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _prepare_mask_for_model(
|
| 137 |
+
mask, size, crop, patch_size, target_shape, square_ok=False
|
| 138 |
+
):
|
| 139 |
+
if mask is None:
|
| 140 |
+
return None
|
| 141 |
+
h0, w0 = mask.shape[:2]
|
| 142 |
+
long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size
|
| 143 |
+
mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)
|
| 144 |
+
|
| 145 |
+
h, w = mask.shape[:2]
|
| 146 |
+
cx, cy = w // 2, h // 2
|
| 147 |
+
if size == 224:
|
| 148 |
+
half = min(cx, cy)
|
| 149 |
+
if crop:
|
| 150 |
+
mask = mask[cy - half : cy + half, cx - half : cx + half]
|
| 151 |
+
else:
|
| 152 |
+
mask = cv2.resize(
|
| 153 |
+
mask, (2 * half, 2 * half), interpolation=cv2.INTER_NEAREST
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
halfw = ((2 * cx) // patch_size) * (patch_size // 2)
|
| 157 |
+
halfh = ((2 * cy) // patch_size) * (patch_size // 2)
|
| 158 |
+
if not square_ok and w == h:
|
| 159 |
+
halfh = int(3 * halfw / 4)
|
| 160 |
+
if crop:
|
| 161 |
+
mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
|
| 162 |
+
else:
|
| 163 |
+
mask = cv2.resize(
|
| 164 |
+
mask, (2 * halfw, 2 * halfh), interpolation=cv2.INTER_NEAREST
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if mask.shape[:2] != tuple(target_shape):
|
| 168 |
+
mask = cv2.resize(
|
| 169 |
+
mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
|
| 170 |
+
)
|
| 171 |
+
return mask.astype(np.uint8, copy=False)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _load_base_config(config_path: Optional[str] = None) -> dict:
|
| 175 |
+
path = config_path or default_config_path()
|
| 176 |
+
with open(path, "r") as f:
|
| 177 |
+
return yaml.safe_load(f) or {}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _resolve_demo_checkpoint(checkpoint: str) -> str:
|
| 181 |
+
local_candidates = []
|
| 182 |
+
for candidate in [checkpoint, os.getenv("LONGSTREAM_CHECKPOINT", "")]:
|
| 183 |
+
if isinstance(candidate, str) and candidate:
|
| 184 |
+
local_candidates.append(candidate)
|
| 185 |
+
|
| 186 |
+
for candidate in local_candidates:
|
| 187 |
+
if os.path.exists(candidate):
|
| 188 |
+
return os.path.abspath(candidate)
|
| 189 |
+
|
| 190 |
+
hf_cfg = {
|
| 191 |
+
"repo_id": os.getenv("LONGSTREAM_HF_REPO"),
|
| 192 |
+
"filename": os.getenv("LONGSTREAM_HF_FILE"),
|
| 193 |
+
"revision": os.getenv("LONGSTREAM_HF_REVISION"),
|
| 194 |
+
"local_dir": os.getenv("LONGSTREAM_HF_LOCAL_DIR", "checkpoints"),
|
| 195 |
+
}
|
| 196 |
+
resolved = resolve_checkpoint_path(None, hf_cfg)
|
| 197 |
+
if resolved and os.path.exists(resolved):
|
| 198 |
+
return os.path.abspath(resolved)
|
| 199 |
+
|
| 200 |
+
if hf_cfg["repo_id"] and hf_cfg["filename"]:
|
| 201 |
+
raise FileNotFoundError(
|
| 202 |
+
"checkpoint not found locally and Hugging Face resolution failed: "
|
| 203 |
+
f"repo_id={hf_cfg['repo_id']} filename={hf_cfg['filename']}"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
searched = ", ".join(local_candidates) if local_candidates else "<none>"
|
| 207 |
+
raise FileNotFoundError(
|
| 208 |
+
"checkpoint not found. "
|
| 209 |
+
f"searched local paths: {searched}. "
|
| 210 |
+
"You can also set LONGSTREAM_HF_REPO and LONGSTREAM_HF_FILE."
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _model_device(device: str) -> str:
|
| 215 |
+
if device == "cuda" and not torch.cuda.is_available():
|
| 216 |
+
return "cpu"
|
| 217 |
+
return device
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _cache_key(checkpoint: str, device: str, model_cfg: dict) -> Tuple[str, str, str]:
|
| 221 |
+
rel_cfg = json.dumps(model_cfg.get("longstream_cfg", {}), sort_keys=True)
|
| 222 |
+
return checkpoint, device, rel_cfg
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_or_load_model(checkpoint: str, device: str, model_cfg: dict) -> LongStreamModel:
|
| 226 |
+
device = _model_device(device)
|
| 227 |
+
cfg = json.loads(json.dumps(model_cfg))
|
| 228 |
+
cfg["checkpoint"] = checkpoint
|
| 229 |
+
key = _cache_key(checkpoint, device, cfg)
|
| 230 |
+
model = _MODEL_CACHE.get(key)
|
| 231 |
+
if model is None:
|
| 232 |
+
model = LongStreamModel(cfg).to(device)
|
| 233 |
+
model.eval()
|
| 234 |
+
_MODEL_CACHE.clear()
|
| 235 |
+
_MODEL_CACHE[key] = model
|
| 236 |
+
return model
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _load_images(
|
| 240 |
+
image_paths: List[str], size: int, crop: bool, patch_size: int
|
| 241 |
+
) -> torch.Tensor:
|
| 242 |
+
views = load_images_for_eval(
|
| 243 |
+
image_paths, size=size, verbose=False, crop=crop, patch_size=patch_size
|
| 244 |
+
)
|
| 245 |
+
imgs = torch.cat([view["img"] for view in views], dim=0)
|
| 246 |
+
images = (imgs.unsqueeze(0) + 1.0) / 2.0
|
| 247 |
+
return images
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _select_keyframes(images: torch.Tensor, keyframe_stride: int, keyframe_mode: str):
|
| 251 |
+
selector = KeyframeSelector(
|
| 252 |
+
min_interval=keyframe_stride,
|
| 253 |
+
max_interval=keyframe_stride,
|
| 254 |
+
force_first=True,
|
| 255 |
+
mode="random" if keyframe_mode == "random" else "fixed",
|
| 256 |
+
)
|
| 257 |
+
return selector.select_keyframes(images.shape[1], images.shape[0], images.device)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def _run_model(images: torch.Tensor, model: LongStreamModel, infer_cfg: dict):
|
| 261 |
+
keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
|
| 262 |
+
keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
|
| 263 |
+
refresh = int(infer_cfg.get("refresh", 4))
|
| 264 |
+
mode = infer_cfg.get("mode", "streaming_refresh")
|
| 265 |
+
streaming_mode = infer_cfg.get("streaming_mode", "causal")
|
| 266 |
+
window_size = int(infer_cfg.get("window_size", 48))
|
| 267 |
+
rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})
|
| 268 |
+
|
| 269 |
+
is_keyframe, keyframe_indices = _select_keyframes(
|
| 270 |
+
images, keyframe_stride, keyframe_mode
|
| 271 |
+
)
|
| 272 |
+
if mode == "batch_refresh":
|
| 273 |
+
outputs = run_batch_refresh(
|
| 274 |
+
model,
|
| 275 |
+
images,
|
| 276 |
+
is_keyframe,
|
| 277 |
+
keyframe_indices,
|
| 278 |
+
streaming_mode,
|
| 279 |
+
keyframe_stride,
|
| 280 |
+
refresh,
|
| 281 |
+
rel_pose_cfg,
|
| 282 |
+
)
|
| 283 |
+
elif mode == "streaming_refresh":
|
| 284 |
+
outputs = run_streaming_refresh(
|
| 285 |
+
model,
|
| 286 |
+
images,
|
| 287 |
+
is_keyframe,
|
| 288 |
+
keyframe_indices,
|
| 289 |
+
streaming_mode,
|
| 290 |
+
window_size,
|
| 291 |
+
refresh,
|
| 292 |
+
rel_pose_cfg,
|
| 293 |
+
)
|
| 294 |
+
else:
|
| 295 |
+
raise ValueError(f"Unsupported demo inference mode: {mode}")
|
| 296 |
+
return outputs, keyframe_indices
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _compute_pose_outputs(
|
| 300 |
+
outputs: dict, keyframe_indices: torch.Tensor, image_hw: Tuple[int, int]
|
| 301 |
+
):
|
| 302 |
+
if "rel_pose_enc" in outputs:
|
| 303 |
+
rel_pose_enc = outputs["rel_pose_enc"][0]
|
| 304 |
+
abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
|
| 305 |
+
extri, intri = pose_encoding_to_extri_intri(
|
| 306 |
+
abs_pose_enc[None], image_size_hw=image_hw
|
| 307 |
+
)
|
| 308 |
+
return (
|
| 309 |
+
rel_pose_enc.detach().cpu().numpy(),
|
| 310 |
+
extri[0].detach().cpu().numpy(),
|
| 311 |
+
intri[0].detach().cpu().numpy(),
|
| 312 |
+
)
|
| 313 |
+
if "pose_enc" in outputs:
|
| 314 |
+
pose_enc = outputs["pose_enc"][0]
|
| 315 |
+
extri, intri = pose_encoding_to_extri_intri(
|
| 316 |
+
pose_enc[None], image_size_hw=image_hw
|
| 317 |
+
)
|
| 318 |
+
return None, extri[0].detach().cpu().numpy(), intri[0].detach().cpu().numpy()
|
| 319 |
+
raise RuntimeError("Model outputs contain neither rel_pose_enc nor pose_enc")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _compute_sky_masks(
|
| 323 |
+
image_paths: List[str],
|
| 324 |
+
target_shape: Tuple[int, int],
|
| 325 |
+
data_cfg: dict,
|
| 326 |
+
skyseg_path: str,
|
| 327 |
+
session_dir: str,
|
| 328 |
+
):
|
| 329 |
+
raw_masks = compute_sky_mask(
|
| 330 |
+
image_paths, skyseg_path, os.path.join(session_dir, "sky_masks_raw")
|
| 331 |
+
)
|
| 332 |
+
if raw_masks is None:
|
| 333 |
+
return None
|
| 334 |
+
masks = []
|
| 335 |
+
for mask in raw_masks:
|
| 336 |
+
masks.append(
|
| 337 |
+
_prepare_mask_for_model(
|
| 338 |
+
mask,
|
| 339 |
+
size=int(data_cfg.get("size", 518)),
|
| 340 |
+
crop=bool(data_cfg.get("crop", False)),
|
| 341 |
+
patch_size=int(data_cfg.get("patch_size", 14)),
|
| 342 |
+
target_shape=target_shape,
|
| 343 |
+
)
|
| 344 |
+
)
|
| 345 |
+
return np.stack(masks, axis=0)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def create_demo_session(
|
| 349 |
+
image_dir: str,
|
| 350 |
+
uploaded_files,
|
| 351 |
+
uploaded_video,
|
| 352 |
+
checkpoint: str,
|
| 353 |
+
device: str,
|
| 354 |
+
mode: str,
|
| 355 |
+
streaming_mode: str,
|
| 356 |
+
keyframe_stride: int,
|
| 357 |
+
refresh: int,
|
| 358 |
+
window_size: int,
|
| 359 |
+
compute_sky: bool,
|
| 360 |
+
config_path: Optional[str] = None,
|
| 361 |
+
) -> str:
|
| 362 |
+
checkpoint = _resolve_demo_checkpoint(checkpoint)
|
| 363 |
+
|
| 364 |
+
session_dir = _new_session_dir()
|
| 365 |
+
base_cfg = _load_base_config(config_path)
|
| 366 |
+
data_cfg = dict(base_cfg.get("data", {}))
|
| 367 |
+
model_cfg = dict(base_cfg.get("model", {}))
|
| 368 |
+
infer_cfg = dict(base_cfg.get("inference", {}))
|
| 369 |
+
|
| 370 |
+
if image_dir:
|
| 371 |
+
image_dir = os.path.abspath(image_dir)
|
| 372 |
+
if not os.path.isdir(image_dir):
|
| 373 |
+
raise FileNotFoundError(f"image_dir not found: {image_dir}")
|
| 374 |
+
image_paths = _sorted_image_paths(image_dir)
|
| 375 |
+
input_root = image_dir
|
| 376 |
+
elif uploaded_video:
|
| 377 |
+
image_paths = _extract_uploaded_video(uploaded_video, session_dir)
|
| 378 |
+
input_root = _resolve_file_path(uploaded_video)
|
| 379 |
+
else:
|
| 380 |
+
image_paths = _copy_uploaded_images(uploaded_files or [], session_dir)
|
| 381 |
+
input_root = os.path.dirname(image_paths[0]) if image_paths else ""
|
| 382 |
+
|
| 383 |
+
if not image_paths:
|
| 384 |
+
raise ValueError("No input images found")
|
| 385 |
+
|
| 386 |
+
data_cfg["size"] = int(data_cfg.get("size", 518))
|
| 387 |
+
data_cfg["crop"] = bool(data_cfg.get("crop", False))
|
| 388 |
+
data_cfg["patch_size"] = int(data_cfg.get("patch_size", 14))
|
| 389 |
+
|
| 390 |
+
device = _model_device(device)
|
| 391 |
+
model = get_or_load_model(checkpoint, device, model_cfg)
|
| 392 |
+
|
| 393 |
+
images = _load_images(
|
| 394 |
+
image_paths, data_cfg["size"], data_cfg["crop"], data_cfg["patch_size"]
|
| 395 |
+
)
|
| 396 |
+
infer_cfg.update(
|
| 397 |
+
{
|
| 398 |
+
"mode": mode,
|
| 399 |
+
"streaming_mode": streaming_mode,
|
| 400 |
+
"keyframe_stride": int(keyframe_stride),
|
| 401 |
+
"refresh": int(refresh),
|
| 402 |
+
"window_size": int(window_size),
|
| 403 |
+
}
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
with torch.no_grad():
|
| 407 |
+
outputs, keyframe_indices = _run_model(images, model, infer_cfg)
|
| 408 |
+
h, w = images.shape[-2:]
|
| 409 |
+
rel_pose_enc, extri, intri = _compute_pose_outputs(
|
| 410 |
+
outputs, keyframe_indices, (h, w)
|
| 411 |
+
)
|
| 412 |
+
point_head = (
|
| 413 |
+
outputs["world_points"][0]
|
| 414 |
+
.detach()
|
| 415 |
+
.cpu()
|
| 416 |
+
.numpy()
|
| 417 |
+
.astype(np.float32, copy=False)
|
| 418 |
+
)
|
| 419 |
+
depth = (
|
| 420 |
+
outputs["depth"][0, :, :, :, 0]
|
| 421 |
+
.detach()
|
| 422 |
+
.cpu()
|
| 423 |
+
.numpy()
|
| 424 |
+
.astype(np.float32, copy=False)
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if device == "cuda":
|
| 428 |
+
torch.cuda.empty_cache()
|
| 429 |
+
|
| 430 |
+
images_uint8 = np.clip(
|
| 431 |
+
images[0].permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255
|
| 432 |
+
).astype(np.uint8)
|
| 433 |
+
sky_masks = None
|
| 434 |
+
if compute_sky:
|
| 435 |
+
skyseg_path = os.path.join(
|
| 436 |
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "skyseg.onnx"
|
| 437 |
+
)
|
| 438 |
+
sky_masks = _compute_sky_masks(
|
| 439 |
+
image_paths, (h, w), data_cfg, skyseg_path, session_dir
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
np.save(session_file(session_dir, "images.npy"), images_uint8)
|
| 443 |
+
np.save(session_file(session_dir, "depth.npy"), depth)
|
| 444 |
+
np.save(session_file(session_dir, "point_head.npy"), point_head)
|
| 445 |
+
np.save(session_file(session_dir, "w2c.npy"), extri)
|
| 446 |
+
np.save(session_file(session_dir, "intri.npy"), intri)
|
| 447 |
+
if rel_pose_enc is not None:
|
| 448 |
+
np.save(
|
| 449 |
+
session_file(session_dir, "rel_pose_enc.npy"),
|
| 450 |
+
rel_pose_enc.astype(np.float32, copy=False),
|
| 451 |
+
)
|
| 452 |
+
if sky_masks is not None:
|
| 453 |
+
np.save(
|
| 454 |
+
session_file(session_dir, "sky_masks.npy"),
|
| 455 |
+
sky_masks.astype(np.uint8, copy=False),
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
sky_removed_ratio = None
|
| 459 |
+
if sky_masks is not None:
|
| 460 |
+
sky_removed_ratio = float(1.0 - (sky_masks > 0).mean())
|
| 461 |
+
|
| 462 |
+
metadata = {
|
| 463 |
+
"session_dir": session_dir,
|
| 464 |
+
"created_at": datetime.utcnow().isoformat() + "Z",
|
| 465 |
+
"checkpoint": os.path.abspath(checkpoint),
|
| 466 |
+
"device": device,
|
| 467 |
+
"mode": mode,
|
| 468 |
+
"streaming_mode": streaming_mode,
|
| 469 |
+
"keyframe_stride": int(keyframe_stride),
|
| 470 |
+
"refresh": int(refresh),
|
| 471 |
+
"window_size": int(window_size),
|
| 472 |
+
"num_frames": int(images_uint8.shape[0]),
|
| 473 |
+
"height": int(images_uint8.shape[1]),
|
| 474 |
+
"width": int(images_uint8.shape[2]),
|
| 475 |
+
"input_root": input_root,
|
| 476 |
+
"image_paths": image_paths,
|
| 477 |
+
"has_sky_masks": bool(sky_masks is not None),
|
| 478 |
+
"sky_removed_ratio": sky_removed_ratio,
|
| 479 |
+
}
|
| 480 |
+
with open(session_file(session_dir, "metadata.json"), "w") as f:
|
| 481 |
+
json.dump(metadata, f, indent=2)
|
| 482 |
+
|
| 483 |
+
del outputs
|
| 484 |
+
return session_dir
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def load_frame_previews(session_dir: str, frame_index: int):
    """Load the RGB image and a colorized depth map for one cached frame.

    Returns (rgb_uint8, depth_colored, label) where label is a 1-based
    "Frame i/N" string; *frame_index* is clamped into the valid range.
    """
    meta = load_metadata(session_dir)
    total = meta["num_frames"]
    idx = int(np.clip(frame_index, 0, total - 1))
    # Memory-map the cached stacks so only the requested frame is read.
    rgb_stack = np.load(session_file(session_dir, "images.npy"), mmap_mode="r")
    depth_stack = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r")
    rgb = np.array(rgb_stack[idx])
    depth_vis = colorize_depth(np.array(depth_stack[idx]), cmap="plasma")
    return rgb, depth_vis, f"Frame {idx + 1}/{total}"
|
longstream/demo/common.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
BRANCH_OPTIONS = [
|
| 8 |
+
"Point Head + Pose",
|
| 9 |
+
"Depth Projection + Pose",
|
| 10 |
+
]
|
| 11 |
+
BRANCH_TO_KEY = {
|
| 12 |
+
"Point Head + Pose": "point_head",
|
| 13 |
+
"Depth Projection + Pose": "depth_projection",
|
| 14 |
+
}
|
| 15 |
+
DISPLAY_MODE_OPTIONS = [
|
| 16 |
+
"Current Frame",
|
| 17 |
+
"Accumulate to Frame",
|
| 18 |
+
"All Frames",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def branch_key(label: str) -> str:
    """Map a UI branch label to its internal key; unknown labels fall back to "point_head"."""
    try:
        return BRANCH_TO_KEY[label]
    except KeyError:
        return "point_head"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def session_file(session_dir: str, name: str) -> str:
    """Build the path of a named artifact inside a session directory."""
    artifact_path = os.path.join(session_dir, name)
    return artifact_path
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_metadata(session_dir: str) -> dict:
    """Read and parse the session's metadata.json."""
    meta_path = session_file(session_dir, "metadata.json")
    with open(meta_path, "r") as handle:
        return json.load(handle)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def selected_frame_indices(
    num_frames: int, frame_index: int, display_mode: str
) -> List[int]:
    """Resolve which frame indices the display mode selects.

    "Current Frame" -> just the clamped index; "Accumulate to Frame" ->
    every frame up to and including it; any other mode -> all frames.
    Returns [] when there are no frames.
    """
    if num_frames <= 0:
        return []
    clamped = int(np.clip(frame_index, 0, num_frames - 1))
    per_mode = {
        "Current Frame": [clamped],
        "Accumulate to Frame": list(range(clamped + 1)),
    }
    return per_mode.get(display_mode, list(range(num_frames)))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def as_4x4(w2c):
    """Promote a 3x4 world-to-camera matrix to homogeneous 4x4 (no-op for 4x4)."""
    mat = np.asarray(w2c, dtype=np.float64)
    if mat.shape != (4, 4):
        full = np.eye(4, dtype=np.float64)
        full[:3, :4] = mat
        mat = full
    return mat
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
_VIEW_ROT = np.array(
|
| 58 |
+
[
|
| 59 |
+
[1.0, 0.0, 0.0],
|
| 60 |
+
[0.0, 0.0, 1.0],
|
| 61 |
+
[0.0, -1.0, 0.0],
|
| 62 |
+
],
|
| 63 |
+
dtype=np.float64,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def world_to_view(points):
    """Apply the fixed viewer rotation ``_VIEW_ROT`` to (N, 3) world points."""
    pts = np.asarray(points, dtype=np.float64)
    return pts.dot(_VIEW_ROT.T)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def camera_center_from_w2c(w2c):
    """Camera center in world coordinates (translation column of the inverted pose)."""
    return np.linalg.inv(as_4x4(w2c))[:3, 3]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def c2w_in_view_space(w2c, origin_shift=None):
    """Invert *w2c* and re-express the camera-to-world pose in viewer space.

    When *origin_shift* is given it is subtracted from the translation, so a
    chosen camera (typically the first) can serve as the scene origin.
    """
    cam_to_world = np.linalg.inv(as_4x4(w2c))
    pose = np.eye(4, dtype=np.float64)
    pose[:3, :3] = _VIEW_ROT @ cam_to_world[:3, :3]
    pose[:3, 3] = world_to_view(cam_to_world[:3, 3][None])[0]
    if origin_shift is not None:
        pose[:3, 3] = pose[:3, 3] - np.asarray(origin_shift, dtype=np.float64)
    return pose
|
longstream/demo/export.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from .geometry import camera_geometry, collect_points
|
| 6 |
+
|
| 7 |
+
_CAMERA_COLORS = np.array(
|
| 8 |
+
[
|
| 9 |
+
[239, 68, 68, 255],
|
| 10 |
+
[14, 165, 233, 255],
|
| 11 |
+
[34, 197, 94, 255],
|
| 12 |
+
[245, 158, 11, 255],
|
| 13 |
+
],
|
| 14 |
+
dtype=np.uint8,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _camera_mesh(center, corners, color):
    """Build a small pyramid mesh (apex at *center*) visualizing one camera frustum."""
    import trimesh

    verts = np.vstack([center[None], corners]).astype(np.float32)
    # Four side faces fanning out from the apex, plus two triangles closing the base.
    tri_faces = np.array(
        [[0, 1, 2], [0, 2, 3], [0, 3, 4], [0, 4, 1], [1, 2, 3], [1, 3, 4]],
        dtype=np.int64,
    )
    frustum = trimesh.Trimesh(vertices=verts, faces=tri_faces, process=False)
    frustum.visual.face_colors = np.tile(color[None], (tri_faces.shape[0], 1))
    return frustum
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def export_glb(
    session_dir: str,
    branch: str,
    display_mode: str,
    frame_index: int,
    mask_sky: bool,
    show_cameras: bool,
    camera_scale: float,
    max_points: int,
) -> str:
    """Export the currently selected point cloud to a GLB file.

    Collects the (optionally sky-masked, subsampled) points for the chosen
    branch/display mode, optionally adds one pyramid mesh per camera frustum,
    and writes the scene under ``<session_dir>/exports``.

    Returns:
        Path of the written .glb file.

    Raises:
        ValueError: if the selection yields no valid points.
    """
    import trimesh

    points, colors, _ = collect_points(
        session_dir=session_dir,
        branch=branch,
        display_mode=display_mode,
        frame_index=frame_index,
        mask_sky=mask_sky,
        max_points=max_points,
        seed=13,  # fixed seed so repeated exports pick identical subsamples
    )
    if len(points) == 0:
        raise ValueError("No valid points to export")

    scene = trimesh.Scene()
    scene.add_geometry(trimesh.PointCloud(vertices=points, colors=colors))

    if show_cameras:
        _, frustums, _ = camera_geometry(
            session_dir=session_dir,
            display_mode=display_mode,
            frame_index=frame_index,
            camera_scale_ratio=camera_scale,
            points_hint=points,  # lets frustum size adapt to the cloud extent
        )
        for idx, (center, corners) in enumerate(frustums):
            # Cycle through the small fixed palette for successive cameras.
            scene.add_geometry(
                _camera_mesh(center, corners, _CAMERA_COLORS[idx % len(_CAMERA_COLORS)])
            )

    export_dir = os.path.join(session_dir, "exports")
    os.makedirs(export_dir, exist_ok=True)
    # Encode all selection options into the file name for traceability.
    branch_slug = branch.lower().replace(" + ", "_").replace(" ", "_")
    mode_slug = display_mode.replace(" ", "_").lower()
    filename = f"{branch_slug}_{mode_slug}_{frame_index:04d}_sky{int(mask_sky)}_cam{int(show_cameras)}.glb"
    path = os.path.join(export_dir, filename)
    scene.export(path)
    return path
|
longstream/demo/geometry.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from .common import (
|
| 7 |
+
branch_key,
|
| 8 |
+
c2w_in_view_space,
|
| 9 |
+
load_metadata,
|
| 10 |
+
selected_frame_indices,
|
| 11 |
+
session_file,
|
| 12 |
+
world_to_view,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _origin_shift(w2c_all) -> np.ndarray:
    """Viewer-space position of the first camera, used as the scene origin."""
    return c2w_in_view_space(w2c_all[0])[:3, 3].copy()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _sample_flat_indices(
|
| 22 |
+
valid_indices: np.ndarray, budget: Optional[int], rng: np.random.Generator
|
| 23 |
+
) -> np.ndarray:
|
| 24 |
+
if budget is None or budget <= 0 or valid_indices.size <= budget:
|
| 25 |
+
return valid_indices
|
| 26 |
+
keep = rng.choice(valid_indices.size, size=int(budget), replace=False)
|
| 27 |
+
return valid_indices[keep]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _depth_points_from_flat(depth, intri, w2c, flat_indices):
|
| 31 |
+
h, w = depth.shape
|
| 32 |
+
ys = flat_indices // w
|
| 33 |
+
xs = flat_indices % w
|
| 34 |
+
z = depth.reshape(-1)[flat_indices].astype(np.float64)
|
| 35 |
+
fx = float(intri[0, 0])
|
| 36 |
+
fy = float(intri[1, 1])
|
| 37 |
+
cx = float(intri[0, 2])
|
| 38 |
+
cy = float(intri[1, 2])
|
| 39 |
+
x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12)
|
| 40 |
+
y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12)
|
| 41 |
+
pts_cam = np.stack([x, y, z], axis=1)
|
| 42 |
+
R = w2c[:3, :3].astype(np.float64)
|
| 43 |
+
t = w2c[:3, 3].astype(np.float64)
|
| 44 |
+
return (R.T @ (pts_cam.T - t[:, None])).T.astype(np.float32, copy=False)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _camera_points_to_world(points, w2c):
|
| 48 |
+
pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
|
| 49 |
+
R = w2c[:3, :3].astype(np.float64)
|
| 50 |
+
t = w2c[:3, 3].astype(np.float64)
|
| 51 |
+
return (R.T @ (pts.T - t[:, None])).T.astype(np.float32, copy=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def collect_points(
    session_dir: str,
    branch: str,
    display_mode: str,
    frame_index: int,
    mask_sky: bool,
    max_points: Optional[int],
    seed: int = 0,
):
    """Gather a colored point cloud for the selected frames of a session.

    Depending on *branch*, points come from the cached point-head maps
    ("point_head") or from back-projected depth ("depth_projection"); either
    way they are rotated into the viewer frame and shifted so the first camera
    sits at the origin.

    Returns:
        (points float32 (N, 3), colors uint8 (N, 3), origin_shift float64 (3,)).
        points/colors are empty when no frame yields valid pixels.
    """
    branch = branch_key(branch)
    meta = load_metadata(session_dir)
    frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
    if not frame_ids:
        return (
            np.empty((0, 3), dtype=np.float32),
            np.empty((0, 3), dtype=np.uint8),
            np.zeros(3, dtype=np.float64),
        )

    # Memory-map cached arrays so only the selected frames are actually read.
    images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r")
    w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
    origin_shift = _origin_shift(w2c)
    sky = None
    if mask_sky and os.path.exists(session_file(session_dir, "sky_masks.npy")):
        sky = np.load(session_file(session_dir, "sky_masks.npy"), mmap_mode="r")

    if branch == "point_head":
        point_head = np.load(session_file(session_dir, "point_head.npy"), mmap_mode="r")
        source = point_head
        depth = None
        intri = None
    else:
        source = None
        depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r")
        intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")

    # Split the global point budget evenly across the selected frames.
    per_frame_budget = None
    if max_points is not None and max_points > 0:
        per_frame_budget = max(int(max_points) // max(len(frame_ids), 1), 1)

    rng = np.random.default_rng(seed)
    points = []
    colors = []
    for idx in frame_ids:
        rgb_flat = images[idx].reshape(-1, 3)
        if branch == "point_head":
            pts_map = source[idx]
            # Keep only pixels with finite 3D coordinates (and non-sky if masked).
            valid = np.isfinite(pts_map).all(axis=-1).reshape(-1)
            if sky is not None:
                valid &= sky[idx].reshape(-1) > 0
            flat = np.flatnonzero(valid)
            if flat.size == 0:
                continue
            flat = _sample_flat_indices(flat, per_frame_budget, rng)
            pts_cam = pts_map.reshape(-1, 3)[flat]
            pts_world = _camera_points_to_world(pts_cam, w2c[idx])
        else:
            depth_i = depth[idx]
            # Keep only finite, strictly positive depths (and non-sky if masked).
            valid = (np.isfinite(depth_i) & (depth_i > 0)).reshape(-1)
            if sky is not None:
                valid &= sky[idx].reshape(-1) > 0
            flat = np.flatnonzero(valid)
            if flat.size == 0:
                continue
            flat = _sample_flat_indices(flat, per_frame_budget, rng)
            pts_world = _depth_points_from_flat(depth_i, intri[idx], w2c[idx], flat)

        # Rotate into viewer space and recenter on the first camera.
        pts_view = world_to_view(pts_world) - origin_shift[None]
        points.append(pts_view.astype(np.float32, copy=False))
        colors.append(rgb_flat[flat].astype(np.uint8, copy=False))

    if not points:
        return (
            np.empty((0, 3), dtype=np.float32),
            np.empty((0, 3), dtype=np.uint8),
            origin_shift,
        )
    return np.concatenate(points, axis=0), np.concatenate(colors, axis=0), origin_shift
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _frustum_corners_camera(intri, image_hw, depth_scale):
|
| 135 |
+
h, w = image_hw
|
| 136 |
+
fx = float(intri[0, 0])
|
| 137 |
+
fy = float(intri[1, 1])
|
| 138 |
+
cx = float(intri[0, 2])
|
| 139 |
+
cy = float(intri[1, 2])
|
| 140 |
+
corners = np.array(
|
| 141 |
+
[
|
| 142 |
+
[
|
| 143 |
+
(0.0 - cx) * depth_scale / max(fx, 1e-12),
|
| 144 |
+
(0.0 - cy) * depth_scale / max(fy, 1e-12),
|
| 145 |
+
depth_scale,
|
| 146 |
+
],
|
| 147 |
+
[
|
| 148 |
+
((w - 1.0) - cx) * depth_scale / max(fx, 1e-12),
|
| 149 |
+
(0.0 - cy) * depth_scale / max(fy, 1e-12),
|
| 150 |
+
depth_scale,
|
| 151 |
+
],
|
| 152 |
+
[
|
| 153 |
+
((w - 1.0) - cx) * depth_scale / max(fx, 1e-12),
|
| 154 |
+
((h - 1.0) - cy) * depth_scale / max(fy, 1e-12),
|
| 155 |
+
depth_scale,
|
| 156 |
+
],
|
| 157 |
+
[
|
| 158 |
+
(0.0 - cx) * depth_scale / max(fx, 1e-12),
|
| 159 |
+
((h - 1.0) - cy) * depth_scale / max(fy, 1e-12),
|
| 160 |
+
depth_scale,
|
| 161 |
+
],
|
| 162 |
+
],
|
| 163 |
+
dtype=np.float64,
|
| 164 |
+
)
|
| 165 |
+
return corners
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def camera_geometry(
    session_dir: str,
    display_mode: str,
    frame_index: int,
    camera_scale_ratio: float,
    points_hint=None,
):
    """Compute viewer-space camera centers and frustum corner sets.

    Frustum depth equals *camera_scale_ratio* times the scene extent, which
    is estimated from the camera trajectory span and (optionally) the 5-95
    percentile bounds of *points_hint*, floored at 1.0.

    Returns:
        (centers float64 (K, 3), list of (center, corners (4, 3)) tuples,
        origin_shift float64 (3,)).
    """
    meta = load_metadata(session_dir)
    frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
    w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
    intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")
    origin_shift = _origin_shift(w2c)

    center_points = np.array(
        [c2w_in_view_space(w2c[idx], origin_shift)[:3, 3] for idx in frame_ids],
        dtype=np.float64,
    )
    center_extent = 1.0
    if len(center_points) > 1:
        center_extent = float(
            np.linalg.norm(center_points.max(axis=0) - center_points.min(axis=0))
        )

    point_extent = 0.0
    if points_hint is not None and len(points_hint) > 0:
        # Percentile bounds make the extent estimate robust to outlier points.
        lo = np.percentile(points_hint, 5, axis=0)
        hi = np.percentile(points_hint, 95, axis=0)
        point_extent = float(np.linalg.norm(hi - lo))

    extent = max(center_extent, point_extent, 1.0)
    depth_scale = extent * float(camera_scale_ratio)

    centers = []
    frustums = []
    for idx in frame_ids:
        c2w_view = c2w_in_view_space(w2c[idx], origin_shift)
        center = c2w_view[:3, 3]
        corners_cam = _frustum_corners_camera(
            intri[idx], (meta["height"], meta["width"]), depth_scale
        )
        # Rotate the corner rays into viewer space and translate to the center.
        corners_world = (c2w_view[:3, :3] @ corners_cam.T).T + center[None]
        centers.append(center)
        frustums.append((center, corners_world))
    return np.asarray(centers, dtype=np.float64), frustums, origin_shift
|
longstream/demo/viewer.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import plotly.graph_objects as go
|
| 3 |
+
|
| 4 |
+
from longstream.demo.backend import load_frame_previews
|
| 5 |
+
|
| 6 |
+
from .common import load_metadata
|
| 7 |
+
from .geometry import camera_geometry, collect_points
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _empty_figure(message: str):
    """Return a blank 3D figure carrying *message* as a centered annotation."""
    placeholder = go.Figure()
    annotation = dict(
        text=message, x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False
    )
    placeholder.add_annotation(**annotation)
    placeholder.update_layout(
        template="plotly_white",
        margin=dict(l=0, r=0, t=40, b=0),
        scene=dict(aspectmode="data"),
    )
    return placeholder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _camera_lines(frustums):
|
| 24 |
+
xs, ys, zs = [], [], []
|
| 25 |
+
for center, corners in frustums:
|
| 26 |
+
order = [(0, 1), (1, 2), (2, 3), (3, 0)]
|
| 27 |
+
for a, b in order:
|
| 28 |
+
xs.extend([corners[a, 0], corners[b, 0], None])
|
| 29 |
+
ys.extend([corners[a, 1], corners[b, 1], None])
|
| 30 |
+
zs.extend([corners[a, 2], corners[b, 2], None])
|
| 31 |
+
for corner in corners:
|
| 32 |
+
xs.extend([center[0], corner[0], None])
|
| 33 |
+
ys.extend([center[1], corner[1], None])
|
| 34 |
+
zs.extend([center[2], corner[2], None])
|
| 35 |
+
return xs, ys, zs
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def build_interactive_figure(
    session_dir: str,
    branch: str,
    display_mode: str,
    frame_index: int,
    point_size: float,
    opacity: float,
    preview_max_points: int,
    show_cameras: bool,
    camera_scale: float,
    mask_sky: bool,
):
    """Build the interactive Plotly 3D figure for the point-cloud preview.

    Adds a Scatter3d point cloud (subsampled to *preview_max_points*) and,
    when *show_cameras* is set, the camera trajectory plus wireframe frustums.
    Returns a placeholder figure when the selection yields no points.
    """
    # NOTE(review): `meta` is loaded but never used below — confirm before removing.
    meta = load_metadata(session_dir)
    points, colors, _ = collect_points(
        session_dir=session_dir,
        branch=branch,
        display_mode=display_mode,
        frame_index=frame_index,
        mask_sky=mask_sky,
        max_points=preview_max_points,
        seed=frame_index,  # per-frame seed keeps the preview subsample stable
    )
    if len(points) == 0:
        return _empty_figure("No valid points for the current selection")

    fig = go.Figure()
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode="markers",
            marker=dict(
                size=float(point_size),
                color=[f"rgb({r},{g},{b})" for r, g, b in colors],
                opacity=float(opacity),
            ),
            hoverinfo="skip",
            name="points",
        )
    )

    if show_cameras:
        centers, frustums, _ = camera_geometry(
            session_dir=session_dir,
            display_mode=display_mode,
            frame_index=frame_index,
            camera_scale_ratio=camera_scale,
            points_hint=points,
        )
        if len(centers) > 0:
            fig.add_trace(
                go.Scatter3d(
                    x=centers[:, 0],
                    y=centers[:, 1],
                    z=centers[:, 2],
                    mode="lines",
                    line=dict(color="#16a34a", width=2),
                    name="trajectory",
                    hoverinfo="skip",
                )
            )
        xs, ys, zs = _camera_lines(frustums)
        fig.add_trace(
            go.Scatter3d(
                x=xs,
                y=ys,
                z=zs,
                mode="lines",
                line=dict(color="#22c55e", width=1.5),
                name="cameras",
                hoverinfo="skip",
            )
        )

    fig.update_layout(
        template="plotly_white",
        margin=dict(l=0, r=0, t=40, b=0),
        scene=dict(
            aspectmode="data",
            # Axis labels reflect the viewer convention used by world_to_view.
            xaxis_title="x_right",
            yaxis_title="z_forward",
            zaxis_title="y_up",
            bgcolor="#f8fafc",
            camera=dict(
                up=dict(x=0.0, y=0.0, z=1.0),
                eye=dict(x=-1.0, y=-1.8, z=0.9),
            ),
        ),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0.0),
    )
    return fig
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def build_frame_outputs(session_dir: str, frame_index: int):
    """Fetch the RGB preview, depth visualization, and label for one frame."""
    return load_frame_previews(session_dir, frame_index)
|
longstream/eval/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluate import evaluate_predictions_cfg
|
| 2 |
+
|
| 3 |
+
__all__ = ["evaluate_predictions_cfg"]
|
longstream/eval/evaluate.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib
|
| 7 |
+
|
| 8 |
+
matplotlib.use("Agg")
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
from longstream.data import LongStreamDataLoader
|
| 12 |
+
from longstream.eval.io import (
|
| 13 |
+
frame_stems,
|
| 14 |
+
read_depth,
|
| 15 |
+
read_opencv_camera_yml,
|
| 16 |
+
read_pointcloud_xyz,
|
| 17 |
+
read_pred_w2c_txt,
|
| 18 |
+
)
|
| 19 |
+
from longstream.eval.metrics import ate_rmse, chamfer_and_f1, transform_points
|
| 20 |
+
from longstream.utils.sky_mask import sky_mask_filename
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _ensure_dir(path):
|
| 24 |
+
os.makedirs(path, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _sequence_output_dir(output_root, seq_name):
|
| 28 |
+
return os.path.join(output_root, seq_name)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _sequence_metrics_path(output_root, seq_name):
|
| 32 |
+
return os.path.join(output_root, "metrics", f"{seq_name}.json")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _sequence_plot_path(output_root, seq_name):
|
| 36 |
+
return os.path.join(output_root, "plots", f"{seq_name}_traj_3d.png")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _world_xyz_to_plot_xyz(xyz):
|
| 40 |
+
xyz = np.asarray(xyz, dtype=np.float64)
|
| 41 |
+
return np.stack([xyz[:, 0], xyz[:, 2], -xyz[:, 1]], axis=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _set_equal_3d_axes(ax, xyz):
|
| 45 |
+
mins = xyz.min(axis=0)
|
| 46 |
+
maxs = xyz.max(axis=0)
|
| 47 |
+
center = 0.5 * (mins + maxs)
|
| 48 |
+
radius = 0.5 * np.max(np.maximum(maxs - mins, 1e-6))
|
| 49 |
+
ax.set_xlim(center[0] - radius, center[0] + radius)
|
| 50 |
+
ax.set_ylim(center[1] - radius, center[1] + radius)
|
| 51 |
+
ax.set_zlim(center[2] - radius, center[2] + radius)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _load_gt_pose_data(seq_info):
|
| 55 |
+
if seq_info.camera is not None:
|
| 56 |
+
cam_dir = os.path.join(seq_info.scene_root, "cameras", seq_info.camera)
|
| 57 |
+
extri_path = os.path.join(cam_dir, "extri.yml")
|
| 58 |
+
intri_path = os.path.join(cam_dir, "intri.yml")
|
| 59 |
+
if os.path.exists(extri_path):
|
| 60 |
+
extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path)
|
| 61 |
+
return extri, intri, image_sizes
|
| 62 |
+
|
| 63 |
+
extri_path = os.path.join(seq_info.scene_root, "extri.yml")
|
| 64 |
+
intri_path = os.path.join(seq_info.scene_root, "intri.yml")
|
| 65 |
+
if not os.path.exists(extri_path):
|
| 66 |
+
return None, None, None
|
| 67 |
+
extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path)
|
| 68 |
+
return extri, intri, image_sizes
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _resolve_gt_depth_root(seq_info):
|
| 72 |
+
if seq_info.camera is not None:
|
| 73 |
+
camera_depth_root = os.path.join(seq_info.scene_root, "depths", seq_info.camera)
|
| 74 |
+
if os.path.isdir(camera_depth_root):
|
| 75 |
+
return camera_depth_root
|
| 76 |
+
depth_root = os.path.join(seq_info.scene_root, "depths")
|
| 77 |
+
if os.path.isdir(depth_root):
|
| 78 |
+
return depth_root
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _resolve_gt_depth_path(seq_info, depth_root, image_path, stem):
|
| 83 |
+
rel_path = os.path.relpath(image_path, seq_info.image_dir)
|
| 84 |
+
rel_stem = os.path.splitext(rel_path)[0]
|
| 85 |
+
file_stem = os.path.splitext(os.path.basename(image_path))[0]
|
| 86 |
+
candidates = [
|
| 87 |
+
os.path.join(depth_root, f"{stem}.exr"),
|
| 88 |
+
os.path.join(depth_root, rel_stem + ".exr"),
|
| 89 |
+
os.path.join(depth_root, stem, f"{file_stem}.exr"),
|
| 90 |
+
]
|
| 91 |
+
for candidate in candidates:
|
| 92 |
+
if os.path.exists(candidate):
|
| 93 |
+
return candidate
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _resize_long_edge(arr, long_edge_size, interpolation):
    """Resize so the longer image edge equals *long_edge_size*, keeping aspect ratio."""
    h, w = arr.shape[:2]
    factor = float(long_edge_size) / float(max(h, w))
    new_size = (int(round(w * factor)), int(round(h * factor)))
    return cv2.resize(arr, new_size, interpolation=interpolation)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _prepare_map_for_eval(
    arr, size, crop, patch_size, target_shape, interpolation, square_ok=False
):
    """Resize/crop *arr* so it matches the model's input preprocessing.

    Presumably mirrors the dataloader's resize-and-crop pipeline (long-edge
    resize, then a centered crop or stretch, then a final resize to
    *target_shape*) so GT maps align pixel-for-pixel with predictions —
    TODO confirm against the dataloader implementation.
    """
    h0, w0 = arr.shape[:2]
    # In 224 mode the *short* edge must end up at `size`, hence the aspect factor.
    long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size
    arr = _resize_long_edge(arr, long_edge, interpolation)

    h, w = arr.shape[:2]
    cx, cy = w // 2, h // 2

    if size == 224:
        # Square output: crop (or stretch) to the largest centered square.
        half = min(cx, cy)
        target_w = 2 * half
        target_h = 2 * half
        if crop:
            arr = arr[cy - half : cy + half, cx - half : cx + half]
        else:
            arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation)
    else:
        # Align both edges to multiples of the ViT patch size.
        halfw = ((2 * cx) // patch_size) * (patch_size // 2)
        halfh = ((2 * cy) // patch_size) * (patch_size // 2)
        if not square_ok and w == h:
            # Avoid exactly-square inputs unless explicitly allowed (3:4 height).
            halfh = int(3 * halfw / 4)
        target_w = 2 * halfw
        target_h = 2 * halfh
        if crop:
            arr = arr[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
        else:
            arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation)

    # Final safety resize so the map matches the model's output resolution.
    if arr.shape[:2] != tuple(target_shape):
        arr = cv2.resize(
            arr, (target_shape[1], target_shape[0]), interpolation=interpolation
        )
    return arr
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _sky_mask_path(seq_dir, image_path):
    """Path of the cached sky mask corresponding to *image_path*."""
    mask_name = sky_mask_filename(image_path)
    return os.path.join(seq_dir, "sky_masks", mask_name)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _sample_frame_points(points, max_points, rng):
|
| 147 |
+
if max_points is None or len(points) <= max_points:
|
| 148 |
+
return points
|
| 149 |
+
keep = rng.choice(len(points), size=max_points, replace=False)
|
| 150 |
+
return points[keep]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _depth_to_world_points(depth, intri, extri, valid_mask):
|
| 154 |
+
ys, xs = np.nonzero(valid_mask)
|
| 155 |
+
if ys.size == 0:
|
| 156 |
+
return np.empty((0, 3), dtype=np.float32)
|
| 157 |
+
|
| 158 |
+
z = depth[ys, xs].astype(np.float64)
|
| 159 |
+
fx = float(intri[0, 0])
|
| 160 |
+
fy = float(intri[1, 1])
|
| 161 |
+
cx = float(intri[0, 2])
|
| 162 |
+
cy = float(intri[1, 2])
|
| 163 |
+
|
| 164 |
+
x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12)
|
| 165 |
+
y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12)
|
| 166 |
+
pts_cam = np.stack([x, y, z], axis=1)
|
| 167 |
+
|
| 168 |
+
R = extri[:3, :3]
|
| 169 |
+
t = extri[:3, 3]
|
| 170 |
+
pts_world = (R.T @ (pts_cam.T - t[:, None])).T
|
| 171 |
+
return pts_world.astype(np.float32, copy=False)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg):
    """Build a GT world-space point cloud by unprojecting GT depth maps.

    Frames without depth, pose, or intrinsics are skipped; sky pixels are
    removed when a sky mask exists. Per-frame points are subsampled to an
    oversampled share of the total evaluation budget. Returns an (N, 3)
    array or None when nothing could be collected.
    """
    if not gt_extri or not gt_intri:
        return None

    gt_dir = _resolve_gt_depth_root(seq_info)
    if gt_dir is None:
        return None

    eval_max_points = int(eval_cfg.get("point_eval_max_points", 100000))
    oversample_factor = int(eval_cfg.get("point_eval_oversample_factor", 4))
    # oversample so the final metric-time subsample still has enough coverage
    per_frame_budget = max(
        (eval_max_points * oversample_factor) // max(len(seq_info.image_paths), 1), 1
    )
    rng = np.random.default_rng(0)  # fixed seed -> deterministic subsampling
    chunks = []

    for image_path, stem in zip(
        seq_info.image_paths, frame_stems(seq_info.image_paths)
    ):
        depth_path = _resolve_gt_depth_path(seq_info, gt_dir, image_path, stem)
        if depth_path is None or stem not in gt_extri or stem not in gt_intri:
            continue

        depth = read_depth(depth_path)
        valid = np.isfinite(depth) & (depth > 0)
        if not np.any(valid):
            continue

        # optional sky mask: zero means sky, so keep only mask > 0
        sky_path = _sky_mask_path(seq_dir, image_path)
        if os.path.exists(sky_path):
            sky_mask = cv2.imread(sky_path, cv2.IMREAD_GRAYSCALE)
            if sky_mask is not None:
                if sky_mask.shape[:2] != depth.shape[:2]:
                    # nearest-neighbour keeps the mask binary after resize
                    sky_mask = cv2.resize(
                        sky_mask,
                        (depth.shape[1], depth.shape[0]),
                        interpolation=cv2.INTER_NEAREST,
                    )
                valid &= sky_mask > 0
                if not np.any(valid):
                    continue

        pts_world = _depth_to_world_points(depth, gt_intri[stem], gt_extri[stem], valid)
        if len(pts_world) == 0:
            continue
        chunks.append(_sample_frame_points(pts_world, per_frame_budget, rng))

    if not chunks:
        return None
    return np.concatenate(chunks, axis=0)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _evaluate_pointclouds(seq_info, seq_dir, eval_cfg, pose_align, gt_cloud):
    """Chamfer/F1 for each predicted point-cloud branch against *gt_cloud*.

    *pose_align* is the (scale, R, t) Sim3 estimated from pose alignment;
    predicted clouds are mapped into the GT frame before comparison.
    Returns {branch: metrics} or None when nothing could be evaluated.
    """
    if pose_align is None or gt_cloud is None:
        return None

    scale, R, t = pose_align
    # candidate files per branch, tried in order; first existing one wins
    point_paths = {
        "point_head": [
            os.path.join(seq_dir, "points", "point_head_full.npy"),
            os.path.join(seq_dir, "points", "point_head_full.npz"),
            os.path.join(seq_dir, "points", "point_head_full.ply"),
        ],
        "dpt_unproj": [
            os.path.join(seq_dir, "points", "dpt_unproj_full.npy"),
            os.path.join(seq_dir, "points", "dpt_unproj_full.npz"),
            os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
        ],
    }
    threshold = float(eval_cfg.get("point_f1_threshold", 0.25))
    max_points = int(eval_cfg.get("point_eval_max_points", 100000))
    # normalize "unset" sentinels coming from YAML configs
    voxel_size = eval_cfg.get("point_eval_voxel_size", None)
    voxel_size = None if voxel_size in (None, "", "null") else float(voxel_size)

    metrics_by_branch = {}
    for branch, candidates in point_paths.items():
        path = next(
            (candidate for candidate in candidates if os.path.exists(candidate)), None
        )
        if path is None:
            continue
        pred_cloud = read_pointcloud_xyz(path)
        pred_cloud = transform_points(pred_cloud, scale, R, t)
        metrics = chamfer_and_f1(
            pred_cloud,
            gt_cloud,
            threshold=threshold,
            max_points=max_points,
            voxel_size=voxel_size,
            # distinct seeds per branch decorrelate the random subsampling
            seed=0 if branch == "point_head" else 1,
        )
        if metrics is not None:
            metrics_by_branch[branch] = metrics
    return metrics_by_branch or None
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg):
    """Pixel-wise depth metrics (AbsRel, delta<thr) for the DPT depth branch.

    GT depths (and sky masks) are resampled with the same resize/crop rules
    as inference so they line up with the predicted maps. Accumulates over
    all frames and returns aggregate metrics, or None without GT/preds.
    """
    pred_dir = os.path.join(seq_dir, "depth", "dpt")
    gt_dir = _resolve_gt_depth_root(seq_info)
    if not os.path.isdir(pred_dir) or gt_dir is None:
        return None

    size = int(data_cfg.get("size", 518))
    crop = bool(data_cfg.get("crop", False))
    patch_size = int(data_cfg.get("patch_size", 14))
    rel_delta_threshold = float(eval_cfg.get("depth_rel_delta_threshold", 1.25))

    # running sums over every valid pixel of every evaluated frame
    abs_rel_sum = 0.0
    rel_delta_hits = 0
    valid_pixels = 0
    evaluated_frames = 0

    stems = frame_stems(seq_info.image_paths)
    for frame_id, stem in enumerate(stems):
        pred_path = os.path.join(pred_dir, f"frame_{frame_id:06d}.npy")
        gt_path = _resolve_gt_depth_path(
            seq_info, gt_dir, seq_info.image_paths[frame_id], stem
        )
        if not os.path.exists(pred_path) or gt_path is None:
            continue

        pred = np.load(pred_path).astype(np.float32)
        gt = read_depth(gt_path)
        # nearest-neighbour keeps GT depth values unblended during resize
        gt = _prepare_map_for_eval(
            gt,
            size=size,
            crop=crop,
            patch_size=patch_size,
            target_shape=pred.shape,
            interpolation=cv2.INTER_NEAREST,
        )

        valid = np.isfinite(gt) & (gt > 0)
        if not np.any(valid):
            continue

        # optional sky mask: zero means sky, keep mask > 0 only
        sky_mask_path = _sky_mask_path(seq_dir, seq_info.image_paths[frame_id])
        if os.path.exists(sky_mask_path):
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_GRAYSCALE)
            if sky_mask is not None:
                sky_mask = _prepare_map_for_eval(
                    sky_mask,
                    size=size,
                    crop=crop,
                    patch_size=patch_size,
                    target_shape=pred.shape,
                    interpolation=cv2.INTER_NEAREST,
                )
                valid &= sky_mask > 0

        valid &= np.isfinite(pred)
        if not np.any(valid):
            continue

        pred_valid = pred[valid].astype(np.float64)
        gt_valid = gt[valid].astype(np.float64)
        # clip away zeros before dividing
        pred_safe = np.clip(pred_valid, 1e-6, None)
        gt_safe = np.clip(gt_valid, 1e-6, None)

        abs_rel_sum += np.sum(np.abs(pred_valid - gt_valid) / gt_safe)
        # symmetric ratio for the delta-accuracy metric
        rel_ratio = np.maximum(gt_safe / pred_safe, pred_safe / gt_safe)
        rel_delta_hits += int(np.sum(rel_ratio < rel_delta_threshold))
        valid_pixels += int(gt_valid.size)
        evaluated_frames += 1

    if valid_pixels == 0:
        return None

    return {
        "abs_rel": float(abs_rel_sum / valid_pixels),
        "rel_delta": float(rel_delta_hits / valid_pixels),
        "rel_delta_threshold": rel_delta_threshold,
        "num_valid_pixels": int(valid_pixels),
        "num_frames": int(evaluated_frames),
    }
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _extract_pose_pairs(seq_info, pred_pose_path, gt_extri):
    """Collect matched predicted/GT camera centers for trajectory alignment.

    Returns (pred_xyz, gt_xyz) float64 arrays, or None when the prediction
    file is empty or fewer than three matched frames exist.
    """
    frame_ids, pred_w2c = read_pred_w2c_txt(pred_pose_path)
    if not pred_w2c:
        return None

    stems = frame_stems(seq_info.image_paths)
    pred_centers = []
    gt_centers = []

    for frame_id, w2c in zip(frame_ids, pred_w2c):
        if not 0 <= frame_id < len(stems):
            continue
        gt_w2c = gt_extri.get(stems[frame_id])
        if gt_w2c is None:
            continue
        # camera center = translation column of the camera-to-world matrix
        pred_centers.append(np.linalg.inv(w2c)[:3, 3])
        gt_centers.append(np.linalg.inv(gt_w2c)[:3, 3])

    if len(pred_centers) < 3:
        return None
    return (
        np.asarray(pred_centers, dtype=np.float64),
        np.asarray(gt_centers, dtype=np.float64),
    )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _save_traj_plot_3d(path, pred_xyz, gt_xyz):
    """Render aligned predicted vs GT trajectories to a 3D PNG at *path*.

    Both trajectories are mapped via _world_xyz_to_plot_xyz (axis labels
    suggest a right/forward/up plotting frame — confirm against that
    helper) and shifted so the first GT point sits at the origin.
    """
    _ensure_dir(os.path.dirname(path))
    pred_plot = _world_xyz_to_plot_xyz(pred_xyz)
    gt_plot = _world_xyz_to_plot_xyz(gt_xyz)
    origin = gt_plot[:1]  # first GT position becomes the plot origin
    pred_plot = pred_plot - origin
    gt_plot = gt_plot - origin
    all_plot = np.concatenate([pred_plot, gt_plot], axis=0)

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(
        gt_plot[:, 0],
        gt_plot[:, 1],
        gt_plot[:, 2],
        label="gt",
        linewidth=2.0,
        color="#1f77b4",
    )
    ax.plot(
        pred_plot[:, 0],
        pred_plot[:, 1],
        pred_plot[:, 2],
        label="pred",
        linewidth=2.0,
        color="#d62728",
    )
    # equal axis scaling so trajectory shape is not distorted
    _set_equal_3d_axes(ax, all_plot)
    ax.view_init(elev=24, azim=-118)
    ax.set_xlabel("x_right")
    ax.set_ylabel("z_forward")
    ax.set_zlabel("y_up")
    ax.legend(loc="best")
    ax.set_title("Trajectory 3D (Sim3-aligned view)")
    fig.tight_layout()
    fig.savefig(path, dpi=180)
    plt.close(fig)  # release the figure to avoid matplotlib memory growth
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg):
    """Evaluate one sequence: pose ATE, DPT depth metrics, point-cloud CD/F1.

    Each metric family is computed only when the corresponding GT exists;
    the returned dict records which GT was available and is marked
    "skipped" when none was.
    """
    seq_dir = _sequence_output_dir(output_root, seq_info.name)
    result = {
        "sequence": seq_info.name,
        "output_dir": seq_dir,
        "has_gt": False,
        "has_gt_pose": False,
        "has_gt_depth": False,
    }

    gt_extri, gt_intri, _ = _load_gt_pose_data(seq_info)
    pose_align = None  # (scale, R, t) Sim3, reused for point-cloud eval
    if gt_extri:
        result["has_gt"] = True
        result["has_gt_pose"] = True

        pred_pose_path = os.path.join(seq_dir, "poses", "abs_pose.txt")
        pairs = _extract_pose_pairs(seq_info, pred_pose_path, gt_extri)
        if pairs is not None:
            pred_xyz, gt_xyz = pairs
            pose_metrics = ate_rmse(
                pred_xyz, gt_xyz, align_scale=bool(eval_cfg.get("align_scale", True))
            )
            sim3_scale = float(pose_metrics.get("sim3_scale", 1.0))
            pred_xyz_aligned = transform_points(
                pred_xyz,
                sim3_scale,
                np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64),
                np.asarray(pose_metrics["sim3_translation"], dtype=np.float64),
            )
            pose_align = (
                sim3_scale,
                np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64),
                np.asarray(pose_metrics["sim3_translation"], dtype=np.float64),
            )
            plot_path = _sequence_plot_path(output_root, seq_info.name)
            _save_traj_plot_3d(plot_path, pred_xyz_aligned, gt_xyz)
            # scale already applied; drop it from the reported metrics
            pose_metrics.pop("sim3_scale", None)
            pose_metrics["traj_3d_plot"] = plot_path
            result["pose"] = pose_metrics

    video_dpt_metrics = _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg)
    if video_dpt_metrics is not None:
        result["has_gt"] = True
        result["has_gt_depth"] = True
        result["video_dpt"] = video_dpt_metrics

    gt_cloud = _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg)
    pointcloud_metrics = _evaluate_pointclouds(
        seq_info, seq_dir, eval_cfg, pose_align, gt_cloud
    )
    if pointcloud_metrics is not None:
        result["has_gt"] = True
        result["has_gt_depth"] = True
        result["pointcloud"] = pointcloud_metrics

    if not result["has_gt"]:
        result["skipped"] = "missing_gt"

    return result
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _mean_metric(sequence_results, group_name, metric_name):
|
| 478 |
+
values = []
|
| 479 |
+
for item in sequence_results:
|
| 480 |
+
group = item
|
| 481 |
+
for key in group_name.split("."):
|
| 482 |
+
if not isinstance(group, dict):
|
| 483 |
+
group = None
|
| 484 |
+
break
|
| 485 |
+
group = group.get(key)
|
| 486 |
+
if not isinstance(group, dict):
|
| 487 |
+
continue
|
| 488 |
+
if metric_name in group:
|
| 489 |
+
values.append(float(group[metric_name]))
|
| 490 |
+
if not values:
|
| 491 |
+
return None
|
| 492 |
+
return float(np.mean(values))
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def evaluate_predictions_cfg(cfg):
    """Evaluate every sequence described by *cfg* and write metric JSON files.

    Writes one metrics JSON per sequence plus <output_root>/summary.json
    with cross-sequence means, and returns the summary dict.
    """
    data_cfg = dict(cfg.get("data", {}))
    data_cfg["format"] = "generalizable"  # evaluation always uses this layout
    output_cfg = cfg.get("output", {})
    eval_cfg = cfg.get("evaluation", {})
    output_root = output_cfg.get("root", "outputs")
    _ensure_dir(output_root)

    loader = LongStreamDataLoader(data_cfg)
    sequence_results = []
    for seq_info in loader.iter_sequence_infos():
        print(f"[longstream] eval {seq_info.name}: start", flush=True)
        metrics = evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg)
        sequence_results.append(metrics)
        metrics_path = _sequence_metrics_path(output_root, seq_info.name)
        _ensure_dir(os.path.dirname(metrics_path))
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)
        print(f"[longstream] eval {seq_info.name}: wrote {metrics_path}", flush=True)

    # cross-sequence means are None when no sequence had the metric
    summary = {
        "num_sequences": len(sequence_results),
        "num_sequences_with_gt": sum(1 for x in sequence_results if x.get("has_gt")),
        "num_sequences_with_pose_gt": sum(
            1 for x in sequence_results if x.get("has_gt_pose")
        ),
        "num_sequences_with_depth_gt": sum(
            1 for x in sequence_results if x.get("has_gt_depth")
        ),
        "ate_mean": _mean_metric(sequence_results, "pose", "ate_mean"),
        "ate_rmse_mean": _mean_metric(sequence_results, "pose", "ate_rmse"),
        "video_dpt_abs_rel_mean": _mean_metric(
            sequence_results, "video_dpt", "abs_rel"
        ),
        "video_dpt_rel_delta_mean": _mean_metric(
            sequence_results, "video_dpt", "rel_delta"
        ),
        "point_head_cd_mean": _mean_metric(
            sequence_results, "pointcloud.point_head", "cd"
        ),
        "point_head_f1_mean": _mean_metric(
            sequence_results, "pointcloud.point_head", "f1"
        ),
        "dpt_unproj_cd_mean": _mean_metric(
            sequence_results, "pointcloud.dpt_unproj", "cd"
        ),
        "dpt_unproj_f1_mean": _mean_metric(
            sequence_results, "pointcloud.dpt_unproj", "f1"
        ),
        "sequences": sequence_results,
    }

    summary_path = os.path.join(output_root, "summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[longstream] eval: wrote {summary_path}", flush=True)
    return summary
|
longstream/eval/io.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def frame_stems(image_paths):
    """Derive a per-frame identifier for each image path.

    Prefers file-name stems; falls back to parent directory names when the
    stems collide (multi-camera layouts); returns the stems regardless
    when neither choice is unique.
    """
    stems = [os.path.splitext(os.path.basename(p))[0] for p in image_paths]
    if len(set(stems)) == len(stems):
        return stems
    parents = [os.path.basename(os.path.dirname(p)) for p in image_paths]
    return parents if len(set(parents)) == len(parents) else stems
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def read_pred_w2c_txt(path):
    """Read predicted world-to-camera poses from a plain-text file.

    Each data line is ``frame_id r00..r22 tx ty tz`` (13 numbers,
    whitespace-separated). Blank lines, ``#`` comment lines, lines with
    the wrong field count, and lines with non-numeric tokens are skipped
    instead of aborting the whole read.

    Returns:
        (frames, poses): list of int frame ids and list of 4x4 float64
        w2c matrices; both empty when *path* does not exist.
    """
    frames = []
    poses = []
    if not os.path.exists(path):
        return frames, poses
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            try:
                vals = [float(x) for x in line.split()]
            except ValueError:
                # tolerate malformed rows rather than losing the whole file
                continue
            if len(vals) != 13:
                continue
            mat = np.eye(4, dtype=np.float64)
            mat[:3, :3] = np.asarray(vals[1:10], dtype=np.float64).reshape(3, 3)
            mat[:3, 3] = np.asarray(vals[10:13], dtype=np.float64)
            frames.append(int(vals[0]))
            poses.append(mat)
    return frames, poses
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def read_opencv_camera_yml(extri_path, intri_path=None):
    """Read camera extrinsics/intrinsics from OpenCV FileStorage YAML files.

    The extrinsics file holds a ``names`` list plus ``Rot_<name>`` /
    ``T_<name>`` matrices per camera; the optional intrinsics file holds
    ``K_<name>`` (and optionally ``H_<name>``/``W_<name>`` image sizes).

    Returns (extri, intri, image_sizes) dicts keyed by camera name:
    4x4 w2c-style matrices, 3x3 K matrices, and (H, W) tuples. All three
    are empty when the extrinsics file is missing.
    """
    if not os.path.exists(extri_path):
        return {}, {}, {}

    fs_extri = cv2.FileStorage(extri_path, cv2.FILE_STORAGE_READ)
    names_node = fs_extri.getNode("names")
    names = []
    for i in range(names_node.size()):
        names.append(names_node.at(i).string())

    extri = {}
    for name in names:
        rot = fs_extri.getNode(f"Rot_{name}").mat()
        t = fs_extri.getNode(f"T_{name}").mat()
        if rot is None or t is None:
            continue  # skip cameras with incomplete extrinsics
        mat = np.eye(4, dtype=np.float64)
        mat[:3, :3] = np.asarray(rot, dtype=np.float64)
        mat[:3, 3] = np.asarray(t, dtype=np.float64).reshape(3)
        extri[name] = mat
    fs_extri.release()

    intri = {}
    image_sizes = {}
    if intri_path is not None and os.path.exists(intri_path):
        fs_intri = cv2.FileStorage(intri_path, cv2.FILE_STORAGE_READ)
        for name in names:
            K = fs_intri.getNode(f"K_{name}").mat()
            if K is None:
                continue
            intri[name] = np.asarray(K, dtype=np.float64)
            h_node = fs_intri.getNode(f"H_{name}")
            w_node = fs_intri.getNode(f"W_{name}")
            if not h_node.empty() and not w_node.empty():
                image_sizes[name] = (int(h_node.real()), int(w_node.real()))
        fs_intri.release()

    return extri, intri, image_sizes
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def read_depth(path):
    """Load a depth map (any bit depth, incl. EXR) as float32 via OpenCV.

    Raises FileNotFoundError when OpenCV cannot read the file.
    """
    raw = cv2.imread(path, cv2.IMREAD_ANYDEPTH)
    if raw is None:
        raise FileNotFoundError(path)
    return raw.astype(np.float32)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def read_ply_xyz(path):
    """Read x/y/z coordinates from a binary little-endian PLY file.

    Only float32 and uint8 vertex properties are supported; any extra
    properties (e.g. colors) are parsed but discarded. Returns an (N, 3)
    float32 array.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    header = []
    with open(path, "rb") as f:
        # read the ASCII header line by line up to "end_header"
        while True:
            line = f.readline()
            if not line:
                raise ValueError(f"Invalid PLY header: {path}")
            text = line.decode("ascii").strip()
            header.append(text)
            if text == "end_header":
                break

        if "format binary_little_endian 1.0" not in header:
            raise ValueError(f"Unsupported PLY format: {path}")

        # collect properties belonging to the vertex element only
        vertex_count = None
        property_specs = []
        in_vertex_block = False
        for line in header:
            if line.startswith("element vertex "):
                vertex_count = int(line.split()[-1])
                in_vertex_block = True
                continue
            if line.startswith("element ") and not line.startswith("element vertex "):
                in_vertex_block = False
            if in_vertex_block and line.startswith("property "):
                _, dtype_name, prop_name = line.split()
                property_specs.append((dtype_name, prop_name))

        if vertex_count is None:
            raise ValueError(f"Missing vertex count in PLY: {path}")

        dtype_map = {
            "float": "<f4",
            "float32": "<f4",
            "uchar": "u1",
            "uint8": "u1",
        }
        vertex_dtype = []
        for dtype_name, prop_name in property_specs:
            if dtype_name not in dtype_map:
                raise ValueError(
                    f"Unsupported PLY property type {dtype_name} in {path}"
                )
            vertex_dtype.append((prop_name, dtype_map[dtype_name]))

        # the binary payload follows the header; f must still be open here
        data = np.fromfile(f, dtype=np.dtype(vertex_dtype), count=vertex_count)
    return np.stack([data["x"], data["y"], data["z"]], axis=1).astype(
        np.float32, copy=False
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def read_pointcloud_xyz(path):
    """Load an (N, 3) float32 point cloud from .npy, .npz, or binary PLY.

    For .npz the "points" array is used when present, otherwise the first
    stored array. Any other extension is handed to read_ply_xyz.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        return np.asarray(np.load(path), dtype=np.float32).reshape(-1, 3)
    if ext == ".npz":
        # np.load on an .npz returns an NpzFile holding an open zip handle;
        # use it as a context manager so the handle is closed deterministically
        with np.load(path) as data:
            key = "points" if "points" in data else next(iter(data.files))
            points = np.asarray(data[key], dtype=np.float32)
        return points.reshape(-1, 3)
    return read_ply_xyz(path)
|
longstream/eval/metrics.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.spatial import cKDTree
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def similarity_align(src, dst, with_scale=True):
    """Umeyama least-squares similarity alignment of two Nx3 point sets.

    Finds (scale, R, t) minimizing ||scale * R @ src + t - dst||^2.
    Returns the identity transform (1.0, I, 0) for fewer than three
    correspondences; raises ValueError on mismatched/non-Nx3 inputs.
    """
    src = np.asarray(src, dtype=np.float64)
    dst = np.asarray(dst, dtype=np.float64)
    if src.shape != dst.shape or src.ndim != 2 or src.shape[1] != 3:
        raise ValueError("Expected Nx3 source and target point sets")
    n = len(src)
    if n < 3:
        return 1.0, np.eye(3), np.zeros(3)

    mu_src = src.mean(axis=0)
    mu_dst = dst.mean(axis=0)
    src_c = src - mu_src
    dst_c = dst - mu_dst

    cov = (dst_c.T @ src_c) / n
    U, D, Vt = np.linalg.svd(cov)
    S = np.eye(3)
    if np.linalg.det(U @ Vt) < 0:
        S[-1, -1] = -1.0  # reflection guard: keep R a proper rotation
    R = U @ S @ Vt

    scale = 1.0
    if with_scale:
        src_var = np.mean(np.sum(src_c ** 2, axis=1))
        scale = float(np.trace(np.diag(D) @ S) / max(src_var, 1e-12))
    t = mu_dst - scale * (R @ mu_src)
    return scale, R, t
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def transform_points(points, scale, R, t):
    """Apply the similarity transform p -> scale * R @ p + t to Nx3 points."""
    rotated = R @ points.T
    scaled = scale * rotated
    return scaled.T + t[None]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def ate_rmse(pred_xyz, gt_xyz, align_scale=True):
    """Sim(3)-align predicted positions to GT and report ATE statistics.

    Returns a dict with RMSE/mean/median translation error plus the
    estimated Sim3 parameters (scale, rotation, translation) used for
    the alignment.
    """
    scale, R, t = similarity_align(pred_xyz, gt_xyz, with_scale=align_scale)
    aligned = transform_points(pred_xyz, scale, R, t)
    residuals = np.linalg.norm(aligned - gt_xyz, axis=1)
    return {
        "ate_rmse": float(np.sqrt(np.mean(residuals ** 2))),
        "ate_mean": float(np.mean(residuals)),
        "ate_median": float(np.median(residuals)),
        "num_pose_pairs": int(len(residuals)),
        "align_scale": bool(align_scale),
        "sim3_scale": float(scale),
        "sim3_rotation": R.tolist(),
        "sim3_translation": t.tolist(),
    }
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _voxel_downsample(points, voxel_size):
|
| 55 |
+
if voxel_size is None:
|
| 56 |
+
return points
|
| 57 |
+
voxel_size = float(voxel_size)
|
| 58 |
+
if voxel_size <= 0 or len(points) == 0:
|
| 59 |
+
return points
|
| 60 |
+
coords = np.floor(points / voxel_size).astype(np.int64)
|
| 61 |
+
_, keep = np.unique(coords, axis=0, return_index=True)
|
| 62 |
+
keep.sort()
|
| 63 |
+
return points[keep]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _sample_points(points, max_points, seed):
|
| 67 |
+
if max_points is None or len(points) <= int(max_points):
|
| 68 |
+
return points
|
| 69 |
+
rng = np.random.default_rng(seed)
|
| 70 |
+
keep = rng.choice(len(points), size=int(max_points), replace=False)
|
| 71 |
+
return points[keep]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def prepare_pointcloud(points, max_points=None, voxel_size=None, seed=0):
    """Normalize a cloud for metric computation.

    Flattens to Nx3 float64, drops rows with any non-finite coordinate,
    voxel-downsamples, then randomly subsamples to *max_points*.
    """
    pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    if len(pts) == 0:
        return pts
    finite_rows = np.isfinite(pts).all(axis=1)
    pts = pts[finite_rows]
    pts = _voxel_downsample(pts, voxel_size)
    return _sample_points(pts, max_points, seed)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def chamfer_and_f1(
    pred_points, gt_points, threshold=0.25, max_points=None, voxel_size=None, seed=0
):
    """Chamfer distance and F1 score between two point clouds.

    Both clouds are cleaned/downsampled via prepare_pointcloud (GT uses
    seed + 1 so the two random subsamples are decorrelated). Returns a
    metrics dict, or None when either cloud is empty after preparation.
    """
    pred = prepare_pointcloud(
        pred_points, max_points=max_points, voxel_size=voxel_size, seed=seed
    )
    gt = prepare_pointcloud(
        gt_points, max_points=max_points, voxel_size=voxel_size, seed=seed + 1
    )
    if len(pred) == 0 or len(gt) == 0:
        return None

    # nearest-neighbour distances in both directions via KD-trees
    pred_tree = cKDTree(pred)
    gt_tree = cKDTree(gt)
    dist_pred_to_gt, _ = gt_tree.query(pred, k=1)
    dist_gt_to_pred, _ = pred_tree.query(gt, k=1)

    acc = float(np.mean(dist_pred_to_gt))   # accuracy: pred -> gt
    comp = float(np.mean(dist_gt_to_pred))  # completeness: gt -> pred
    precision = float(np.mean(dist_pred_to_gt < threshold))
    recall = float(np.mean(dist_gt_to_pred < threshold))
    denom = precision + recall
    f1 = 0.0 if denom <= 0 else float(2.0 * precision * recall / denom)
    return {
        "cd": float(acc + comp),
        "acc": acc,
        "comp": comp,
        "f1": f1,
        "f1_threshold": float(threshold),
        "num_pred_points": int(len(pred)),
        "num_gt_points": int(len(gt)),
    }
|
longstream/io/__init__.py
ADDED
|
File without changes
|
longstream/io/save_images.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def save_image_sequence(
    path, images: List[np.ndarray], prefix: str = "frame", ext: str = "png"
):
    """Write each array in *images* to ``<path>/<prefix>_<index:06d>.<ext>``."""
    os.makedirs(path, exist_ok=True)
    for index, frame in enumerate(images):
        out_path = os.path.join(path, f"{prefix}_{index:06d}.{ext}")
        Image.fromarray(frame).save(out_path)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def save_video(output_path, pattern, fps=30):
    """Encode frames matching glob *pattern* into an H.264 video via ffmpeg.

    Raises CalledProcessError when ffmpeg fails (check=True).
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
    cmd += ["-framerate", str(fps)]
    cmd += ["-pattern_type", "glob", "-i", pattern]
    cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", output_path]
    subprocess.run(cmd, check=True)
|
longstream/io/save_points.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _maybe_downsample(points, colors=None, max_points=None, seed=0):
|
| 6 |
+
pts = np.asarray(points).reshape(-1, 3)
|
| 7 |
+
cols = None if colors is None else np.asarray(colors).reshape(-1, 3)
|
| 8 |
+
if max_points is None or pts.shape[0] <= int(max_points):
|
| 9 |
+
return pts, cols
|
| 10 |
+
rng = np.random.default_rng(seed)
|
| 11 |
+
keep = rng.choice(pts.shape[0], size=int(max_points), replace=False)
|
| 12 |
+
pts = pts[keep]
|
| 13 |
+
if cols is not None:
|
| 14 |
+
cols = cols[keep]
|
| 15 |
+
return pts, cols
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def save_pointcloud(path, points, colors=None, max_points=None, seed=0):
    """Write points (and optional per-point RGB) as a binary little-endian PLY.

    Colors with max <= 1.0 are treated as floats in [0, 1] and scaled to
    uint8; anything else is cast to uint8 directly. When *max_points* is
    set, the cloud is randomly (seeded) downsampled first.
    """
    out_dir = os.path.dirname(path)
    if out_dir:  # a bare filename has no directory component to create
        os.makedirs(out_dir, exist_ok=True)
    pts, cols = _maybe_downsample(
        points, colors=colors, max_points=max_points, seed=seed
    )
    pts = pts.astype(np.float32, copy=False)
    has_color = colors is not None
    if has_color:
        # guard cols.size: .max() raises ValueError on an empty array
        if cols.size and cols.max() <= 1.0:
            cols = (cols * 255.0).astype(np.uint8)
        else:
            cols = cols.astype(np.uint8)
    else:
        cols = None

    with open(path, "wb") as f:
        # PLY header
        f.write(b"ply\n")
        f.write(b"format binary_little_endian 1.0\n")
        f.write(f"element vertex {pts.shape[0]}\n".encode("ascii"))
        f.write(b"property float x\n")
        f.write(b"property float y\n")
        f.write(b"property float z\n")
        if has_color:
            f.write(b"property uchar red\n")
            f.write(b"property uchar green\n")
            f.write(b"property uchar blue\n")
        f.write(b"end_header\n")
        # binary payload: one packed record per vertex
        if has_color:
            vertex_dtype = np.dtype(
                [
                    ("x", "<f4"),
                    ("y", "<f4"),
                    ("z", "<f4"),
                    ("red", "u1"),
                    ("green", "u1"),
                    ("blue", "u1"),
                ]
            )
            vertex_data = np.empty(pts.shape[0], dtype=vertex_dtype)
            vertex_data["x"] = pts[:, 0]
            vertex_data["y"] = pts[:, 1]
            vertex_data["z"] = pts[:, 2]
            vertex_data["red"] = cols[:, 0]
            vertex_data["green"] = cols[:, 1]
            vertex_data["blue"] = cols[:, 2]
            vertex_data.tofile(f)
        else:
            vertex_dtype = np.dtype([("x", "<f4"), ("y", "<f4"), ("z", "<f4")])
            vertex_data = np.empty(pts.shape[0], dtype=vertex_dtype)
            vertex_data["x"] = pts[:, 0]
            vertex_data["y"] = pts[:, 1]
            vertex_data["z"] = pts[:, 2]
            vertex_data.tofile(f)
|
longstream/io/save_poses_txt.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _ensure_dir(path):
|
| 6 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def save_w2c_txt(path, extri, frames):
    """Write world-to-camera extrinsics to a text file.

    After a ``# w2c`` header, each line is
    ``<frame> r00 r01 r02 r10 r11 r12 r20 r21 r22 tx ty tz``,
    taken from the 3x4 upper block of the per-frame 4x4 matrix.

    Args:
        path: Output file path; parent directories are created as needed.
        extri: Indexable of 4x4 world-to-camera matrices, one per frame.
        frames: Frame identifiers, written verbatim as the first field.
    """
    _ensure_dir(path)
    with open(path, "w") as f:
        f.write("# w2c\n")
        for idx, frame in enumerate(frames):
            matrix = extri[idx]
            rotation = matrix[:3, :3].reshape(-1).tolist()
            translation = matrix[:3, 3].reshape(-1).tolist()
            fields = [frame] + rotation + translation
            f.write(" ".join(str(field) for field in fields) + "\n")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def save_intri_txt(path, intri, frames):
    """Write per-frame pinhole intrinsics as ``<frame> fx fy cx cy`` lines.

    Args:
        path: Output file path; parent directories are created as needed.
        intri: Indexable of 3x3 intrinsic matrices, one per frame.
        frames: Frame identifiers, written verbatim as the first field.
    """
    _ensure_dir(path)
    with open(path, "w") as f:
        f.write("# fx fy cx cy\n")
        for idx, frame in enumerate(frames):
            K = intri[idx]
            fx, fy = float(K[0, 0]), float(K[1, 1])
            cx, cy = float(K[0, 2]), float(K[1, 2])
            f.write(f"{frame} {fx} {fy} {cx} {cy}\n")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def save_rel_pose_txt(path, rel_pose_enc, frames):
    """Write relative-pose encodings, one ``<frame> tx ty tz qx qy qz qw fov_h fov_w`` line per frame.

    Accepts a numpy array or a torch tensor (anything with ``detach`` is
    detached and moved to CPU before conversion).

    Args:
        path: Output file path; parent directories are created as needed.
        rel_pose_enc: Per-frame pose-encoding rows (frames x 9 expected
            from the header — not enforced here).
        frames: Frame identifiers, written verbatim as the first field.
    """
    _ensure_dir(path)
    data = rel_pose_enc
    if hasattr(data, "detach"):
        # torch tensor -> plain numpy on CPU
        data = data.detach().cpu().numpy()
    with open(path, "w") as f:
        f.write("# tx ty tz qx qy qz qw fov_h fov_w\n")
        for idx, frame in enumerate(frames):
            fields = [frame] + data[idx].tolist()
            f.write(" ".join(str(field) for field in fields) + "\n")
|
longstream/models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from longstream.models.longstream import LongStream
|
| 2 |
+
|
| 3 |
+
__all__ = ["LongStream"]
|
longstream/models/longstream.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Optional, Dict
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 5 |
+
|
| 6 |
+
from longstream.utils.vendor.dust3r.utils.misc import freeze_all_params
|
| 7 |
+
from longstream.utils.vendor.models.components.aggregator.streamaggregator import (
|
| 8 |
+
STreamAggregator,
|
| 9 |
+
)
|
| 10 |
+
from longstream.utils.vendor.models.components.heads.camera_head import (
|
| 11 |
+
CameraHead,
|
| 12 |
+
RelPoseHead,
|
| 13 |
+
)
|
| 14 |
+
from longstream.utils.vendor.models.components.heads.dpt_head import DPTHead
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LongStream(nn.Module, PyTorchModelHubMixin):
    """Streaming multi-frame 3D reconstruction model.

    Wraps a :class:`STreamAggregator` backbone with several prediction heads:
    camera pose (``CameraHead``), world points and depth (``DPTHead``), an
    optional relative-pose head (``RelPoseHead``), and an optional learned
    global-scale token that rescales translations, points, and depth.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        freeze="none",  # which submodules to freeze; see set_freeze()
        rel_pose_head_cfg=None,  # dict enabling/configuring RelPoseHead
        use_role_embedding=True,
        enable_scale_token=False,  # adds a learned token predicting a global scale
        scale_token_config=None,
        disable_keyframe_distinction=False,
        enable_camera_head=True,
        use_segment_mask=False,
        use_3d_rope=False,
        rope_freq=100,
        window_size=5000,  # attention window passed to aggregator and heads
    ):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.enable_scale_token = enable_scale_token
        self.enable_camera_head = enable_camera_head
        self.window_size = window_size

        # Backbone producing per-layer aggregated token lists.
        self.aggregator = STreamAggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            use_role_embedding=use_role_embedding,
            disable_keyframe_distinction=disable_keyframe_distinction,
            use_segment_mask=use_segment_mask,
            use_3d_rope=use_3d_rope,
            rope_freq=rope_freq,
            window_size=window_size,
        )

        # Heads all consume concatenated features of width 2 * embed_dim.
        if self.enable_camera_head:
            self.camera_head = CameraHead(dim_in=2 * embed_dim, window_size=window_size)
        else:
            self.camera_head = None
        self.point_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=4,  # xyz + confidence
            activation="inv_log",
            conf_activation="expp1",
        )
        self.depth_head = DPTHead(
            dim_in=2 * embed_dim,
            output_dim=2,  # depth + confidence
            activation="exp",
            conf_activation="expp1",
        )

        self.rel_pose_head = None
        self.reinit_camera_head_when_rel_enabled = False

        if rel_pose_head_cfg is not None:
            enable = rel_pose_head_cfg.get("enabled", True)
            if enable:
                # Assemble RelPoseHead kwargs from config with defaults.
                head_cfg = {
                    "dim_in": 2 * embed_dim,
                    "trunk_depth": rel_pose_head_cfg.get("trunk_depth", 4),
                    "pose_mode": rel_pose_head_cfg.get("pose_mode", "SE3"),
                    "num_heads": rel_pose_head_cfg.get("num_heads", 16),
                    "mlp_ratio": rel_pose_head_cfg.get("mlp_ratio", 4),
                    "init_values": rel_pose_head_cfg.get("init_values", 0.01),
                    "trans_act": rel_pose_head_cfg.get("trans_act", "linear"),
                    "quat_act": rel_pose_head_cfg.get("quat_act", "linear"),
                    "fl_act": rel_pose_head_cfg.get("fl_act", "relu"),
                    "use_global_scale": rel_pose_head_cfg.get(
                        "use_global_scale", False
                    ),
                    "use_pair_cross_attn": rel_pose_head_cfg.get(
                        "use_pair_cross_attn", False
                    ),
                    "detach_reference": rel_pose_head_cfg.get(
                        "detach_reference", False
                    ),
                    "xattn_temperature": rel_pose_head_cfg.get(
                        "xattn_temperature", 1.0
                    ),
                    "use_precat": rel_pose_head_cfg.get("use_precat", False),
                    "use_kf_role_embed": rel_pose_head_cfg.get(
                        "use_kf_role_embed", True
                    ),
                    "kf_role_embed_init_std": rel_pose_head_cfg.get(
                        "kf_role_embed_init_std", 0.02
                    ),
                    "window_size": window_size,
                }
                self.rel_pose_head = RelPoseHead(**head_cfg)

                # Flag is recorded here; actual reinit happens after
                # checkpoint loading via reinitialize_camera_head().
                self.reinit_camera_head_when_rel_enabled = rel_pose_head_cfg.get(
                    "reinit_camera_head", False
                )

                if self.reinit_camera_head_when_rel_enabled:
                    pass

        if self.enable_scale_token:
            self._init_scale_components(scale_token_config or {})

        self.set_freeze(freeze)

    def reinitialize_camera_head(self):
        """
        Reinitialize camera_head with fresh weights.

        This is useful when:
        1. Loading a pretrained checkpoint that has camera_head weights
        2. But we want to train camera_head from scratch with new settings (e.g., quaternion normalization)

        This method should be called AFTER checkpoint loading.
        """
        # Recover the input width from the old head's norm layer.
        old_camera_head = self.camera_head
        dim_in = old_camera_head.token_norm.normalized_shape[0]

        # NOTE(review): window_size is not forwarded here, unlike the
        # constructor's CameraHead(...) call — confirm this is intended.
        self.camera_head = CameraHead(dim_in=dim_in)

        device = next(old_camera_head.parameters()).device
        self.camera_head = self.camera_head.to(device)

    def _init_scale_components(self, config):
        """Create the learnable scale token and the MLP that decodes it."""
        self.scale_token = nn.Parameter(torch.zeros(self.embed_dim))
        torch.nn.init.trunc_normal_(self.scale_token, std=0.02)

        # Small MLP mapping the aggregated scale-token features to one logit.
        self.scale_head = nn.Sequential(
            nn.Linear(2 * self.embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        for m in self.scale_head.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        import math

        # Bias the last layer so exp(logit) starts near a scale of 30.
        nn.init.constant_(self.scale_head[-1].bias, math.log(30.0))

    def set_freeze(self, freeze):
        """Freeze parameter groups by name; only "none" and "encoder" are
        recognized — any other value raises KeyError."""
        self.freeze = freeze

        to_be_frozen = {
            "none": [],
            "encoder": [self.aggregator.patch_embed],
        }
        freeze_all_params(to_be_frozen[freeze])

    def forward(
        self,
        images: torch.Tensor,
        mode: str = "causal",
        aggregator_kv_cache_list: Optional[List[List[torch.Tensor]]] = None,
        camera_head_kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None,
        rel_pose_inputs: Optional[Dict] = None,
        is_keyframe: Optional[torch.Tensor] = None,
    ):
        """Run the backbone and all enabled heads on a frame sequence.

        Args:
            images: 5D (B, S, ...) or 4D (S, ...) tensor; a 4D input is
                promoted to batch size 1. Per-frame shape is assumed to be
                (C, H, W) — not verified here.
            mode: Attention mode forwarded to the aggregator/heads.
            aggregator_kv_cache_list: When given, streaming KV caches for the
                aggregator; the updated caches are echoed in the output.
            camera_head_kv_cache_list: Same, for the camera head.
            rel_pose_inputs: Optional dict with "keyframe_indices",
                "is_keyframe", "num_iterations", "kv_cache_list" for the
                relative-pose head.
            is_keyframe: Optional per-frame keyframe mask.

        Returns:
            Dict of predictions; keys depend on enabled heads (pose_enc,
            world_points(+conf), depth(+conf), rel_pose_enc, scale outputs,
            KV caches, and "images" in eval mode).
        """
        if len(images.shape) == 4:
            images = images.unsqueeze(0)

        batch_size = images.shape[0]

        # Optional learned scale token, broadcast per batch element.
        additional_tokens = None
        if self.enable_scale_token:
            scale_token_base = self.scale_token.unsqueeze(0).repeat(batch_size, 1)
            additional_tokens = scale_token_base.unsqueeze(-1)

        keyframe_indices = None
        if rel_pose_inputs is not None and "keyframe_indices" in rel_pose_inputs:
            keyframe_indices = rel_pose_inputs["keyframe_indices"]

        # Streaming path returns updated KV caches; batch path does not.
        if aggregator_kv_cache_list is not None:
            (
                aggregated_tokens_list,
                patch_start_idx,
                aggregator_kv_cache_list,
                _,
            ) = self.aggregator(
                images,
                mode=mode,
                kv_cache_list=aggregator_kv_cache_list,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )
        else:
            aggregated_tokens_list, patch_start_idx, _ = self.aggregator(
                images,
                mode=mode,
                is_keyframe=is_keyframe,
                keyframe_indices=keyframe_indices,
                additional_tokens=additional_tokens,
                reorder_keyframes_first=False,
            )

        predictions = {}

        # ---- Scale prediction from the extra token --------------------------
        predicted_scale_factor = None
        if self.enable_scale_token and additional_tokens is not None:
            if len(aggregated_tokens_list) > 0:
                last_layer_features = aggregated_tokens_list[-1]

                # The scale token sits right before the patch tokens.
                scale_token_idx = patch_start_idx - 1
                scale_token_output_features = last_layer_features[
                    :, :, scale_token_idx, :
                ]

                # Average over the frame dimension.
                scale_token_output_features = scale_token_output_features.mean(dim=1)

                scale_logits = self.scale_head(scale_token_output_features).squeeze(-1)

                # exp keeps the predicted scale strictly positive.
                predicted_scale_factor = torch.exp(scale_logits)

                predictions["predicted_scale_factor"] = predicted_scale_factor
                predictions["scale_token_features"] = scale_token_output_features

        # ---- Camera pose head ----------------------------------------------
        if self.enable_camera_head and self.camera_head is not None:
            if camera_head_kv_cache_list is not None:
                pose_enc_list, camera_head_kv_cache_list = self.camera_head(
                    aggregated_tokens_list,
                    mode=mode,
                    kv_cache_list=camera_head_kv_cache_list,
                )
            else:
                pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode)

            # Only the translation component (first 3 dims) is rescaled.
            final_pose_enc = pose_enc_list[-1]
            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1)

                scaled_t = final_pose_enc[..., :3] * scale
                scaled_pose_enc = torch.cat([scaled_t, final_pose_enc[..., 3:]], dim=-1)
                predictions["pose_enc"] = scaled_pose_enc
            else:
                predictions["pose_enc"] = final_pose_enc

            if self.training:
                # Keep the full iteration list for training losses.
                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)
                    scaled_pose_enc_list = []
                    for pose_enc in pose_enc_list:

                        scaled_t = pose_enc[..., :3] * scale
                        scaled_pose_enc = torch.cat(
                            [scaled_t, pose_enc[..., 3:]], dim=-1
                        )
                        scaled_pose_enc_list.append(scaled_pose_enc)
                    predictions["pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["pose_enc_list"] = pose_enc_list

        # ---- Relative pose head --------------------------------------------
        if self.rel_pose_head is not None and rel_pose_inputs is not None:

            rel_kwargs = dict(
                aggregated_tokens_list=aggregated_tokens_list,
                keyframe_indices=rel_pose_inputs.get("keyframe_indices"),
                is_keyframe=rel_pose_inputs.get("is_keyframe", is_keyframe),
                num_iterations=rel_pose_inputs.get("num_iterations", 4),
                mode=mode,
                kv_cache_list=rel_pose_inputs.get("kv_cache_list"),
            )

            # Drop None kwargs so the head's own defaults apply.
            rel_kwargs = {k: v for k, v in rel_kwargs.items() if v is not None}

            rel_result = self.rel_pose_head(**rel_kwargs)

            if isinstance(rel_result, dict):

                pose_enc = rel_result["pose_enc"]
                if pose_enc.dtype != torch.float32:
                    pose_enc = pose_enc.float()

                if self.enable_scale_token and predicted_scale_factor is not None:
                    scale = predicted_scale_factor.view(-1, 1, 1)

                    scaled_t = pose_enc[..., :3] * scale
                    scaled_rel_pose_enc = torch.cat(
                        [scaled_t, pose_enc[..., 3:]], dim=-1
                    )
                    predictions["rel_pose_enc"] = scaled_rel_pose_enc

                    if "pose_enc_list" in rel_result:
                        scaled_pose_enc_list = []
                        for iter_pose in rel_result["pose_enc_list"]:
                            scaled_t = iter_pose[..., :3] * scale
                            scaled_iter_pose = torch.cat(
                                [scaled_t, iter_pose[..., 3:]], dim=-1
                            )
                            scaled_pose_enc_list.append(scaled_iter_pose)
                        predictions["rel_pose_enc_list"] = scaled_pose_enc_list
                else:
                    predictions["rel_pose_enc"] = pose_enc

                    if "pose_enc_list" in rel_result:
                        predictions["rel_pose_enc_list"] = rel_result["pose_enc_list"]

                predictions["is_keyframe"] = rel_result.get("is_keyframe")
                predictions["keyframe_indices"] = rel_result.get("keyframe_indices")

                if "global_scale" in rel_result:
                    predictions["global_scale"] = rel_result["global_scale"]

                if "kv_cache_list" in rel_result:
                    predictions["rel_pose_kv_cache_list"] = rel_result["kv_cache_list"]

        # ---- Dense heads: world points and depth ---------------------------
        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )

            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["world_points"] = pts3d * scale
            else:
                predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )

            if self.enable_scale_token and predicted_scale_factor is not None:
                scale = predicted_scale_factor.view(-1, 1, 1, 1, 1)
                predictions["depth"] = depth * scale
            else:
                predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        # Echo updated streaming caches back to the caller.
        if aggregator_kv_cache_list is not None:
            predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list

        if camera_head_kv_cache_list is not None:
            predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list

        if not self.training:
            predictions["images"] = images

        return predictions
|
longstream/streaming/__init__.py
ADDED
|
File without changes
|
longstream/streaming/keyframe_selector.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class KeyframeSelector:
    """Decides which frames of a sequence are keyframes.

    Modes:
      * ``"fixed"``: promote a frame once ``max_interval`` frames have
        elapsed since the previous keyframe (or earlier when a pose-motion
        threshold is exceeded after ``min_interval`` frames).
      * ``"random"``: after each keyframe, draw the next target gap
        uniformly from ``[min_interval, max_interval]``.
    """

    def __init__(
        self,
        min_interval: int = 8,
        max_interval: int = 8,
        force_first: bool = True,
        motion_threshold: Optional[float] = None,
        mode: str = "fixed",
    ):
        self.min_interval = int(min_interval)
        self.max_interval = int(max_interval)
        self.force_first = bool(force_first)
        self.motion_threshold = motion_threshold
        self.mode = mode

    def select_keyframes(
        self,
        sequence_length: int,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
        poses: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(is_keyframe, keyframe_indices)``, each of shape
        (batch_size, sequence_length).

        ``keyframe_indices[b, s]`` is the index of the most recent keyframe
        strictly before frame ``s`` (0 at the start of the sequence).
        """
        device = device or torch.device("cpu")
        flags = torch.zeros(
            batch_size, sequence_length, dtype=torch.bool, device=device
        )
        refs = torch.zeros(
            batch_size, sequence_length, dtype=torch.long, device=device
        )

        for b in range(batch_size):
            prev_kf = 0
            target = None

            # Frame 0 is always a keyframe when forced (or trivially for S=1).
            if self.force_first or sequence_length == 1:
                flags[b, 0] = True
                refs[b, 0] = 0
                if self.mode == "random":
                    target = random.randint(self.min_interval, self.max_interval)

            for step in range(1, sequence_length):
                # Reference index is recorded before any promotion, so a
                # keyframe's own entry points at the *previous* keyframe.
                refs[b, step] = prev_kf
                gap = step - prev_kf

                promote = False
                if self.mode == "random" and target is not None:
                    promote = step >= target
                elif gap >= self.max_interval:
                    promote = True
                elif (
                    gap >= self.min_interval
                    and poses is not None
                    and self.motion_threshold is not None
                ):
                    # Early promotion on large camera translation.
                    displacement = torch.norm(
                        poses[b, step, :3] - poses[b, prev_kf, :3]
                    ).item()
                    promote = displacement > self.motion_threshold

                if promote:
                    flags[b, step] = True
                    prev_kf = step
                    if self.mode == "random":
                        target = step + random.randint(
                            self.min_interval, self.max_interval
                        )

        return flags, refs
|
longstream/streaming/refresh.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Dict, Any, List
|
| 3 |
+
|
| 4 |
+
from longstream.streaming.stream_session import StreamSession
|
| 5 |
+
|
| 6 |
+
# Model outputs with a per-frame axis (shape (B, S, ...)) that are sliced and
# concatenated along dim 1 when stitching refresh segments together.
_SEQUENCE_OUTPUT_KEYS = {
    "pose_enc",
    "rel_pose_enc",
    "world_points",
    "world_points_conf",
    "depth",
    "depth_conf",
}
# Per-sequence scalar outputs; when stitching, the value from the last
# processed segment overwrites earlier ones.
_SCALAR_OUTPUT_KEYS = {
    "predicted_scale_factor",
    "global_scale",
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _refresh_intervals(refresh: int) -> int:
|
| 21 |
+
refresh = int(refresh)
|
| 22 |
+
if refresh < 2:
|
| 23 |
+
raise ValueError("refresh must be >= 2")
|
| 24 |
+
return refresh - 1
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _model_device(model) -> torch.device:
|
| 28 |
+
return next(model.parameters()).device
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _move_scalar_to_cpu(value: Any) -> Any:
|
| 32 |
+
if isinstance(value, torch.Tensor):
|
| 33 |
+
return value.detach().cpu()
|
| 34 |
+
return value
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _append_batch_output(
    stitched_tensors: Dict[str, List[torch.Tensor]],
    stitched_scalars: Dict[str, Any],
    output: Dict[str, Any],
    actual_frames: int,
    slice_start: int,
) -> None:
    """Fold one segment's model output into the running stitch buffers.

    Per-frame tensors (keys in ``_SEQUENCE_OUTPUT_KEYS``) are sliced from
    ``slice_start`` along the frame axis — dropping the overlap frame for
    non-first segments — detached, moved to CPU, and appended. Scalar keys
    simply overwrite any previous value.
    """
    for name in _SEQUENCE_OUTPUT_KEYS:
        tensor = output.get(name)
        # Skip missing keys, non-tensors, and tensors whose frame axis does
        # not match this segment's length.
        if not isinstance(tensor, torch.Tensor):
            continue
        if tensor.ndim < 2 or tensor.shape[1] != actual_frames:
            continue
        chunk = tensor[:, slice_start:].detach().cpu()
        stitched_tensors.setdefault(name, []).append(chunk)

    for name in _SCALAR_OUTPUT_KEYS:
        if name in output:
            stitched_scalars[name] = _move_scalar_to_cpu(output[name])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _finalize_stitched_batches(
|
| 60 |
+
stitched_tensors: Dict[str, List[torch.Tensor]],
|
| 61 |
+
stitched_scalars: Dict[str, Any],
|
| 62 |
+
) -> Dict[str, Any]:
|
| 63 |
+
stitched_output: Dict[str, Any] = {}
|
| 64 |
+
for key, chunks in stitched_tensors.items():
|
| 65 |
+
if not chunks:
|
| 66 |
+
continue
|
| 67 |
+
stitched_output[key] = (
|
| 68 |
+
chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1)
|
| 69 |
+
)
|
| 70 |
+
stitched_output.update(stitched_scalars)
|
| 71 |
+
return stitched_output
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def run_batch_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    keyframe_stride: int,
    refresh: int,
    rel_pose_cfg,
):
    """Run the model over a long sequence in overlapping refresh segments.

    The sequence is split into segments of ``(refresh - 1) * keyframe_stride + 1``
    frames, each overlapping the previous one by a single frame; that overlap
    frame is re-marked as a keyframe for the new segment and its predictions
    are dropped when stitching, so each frame appears exactly once in the
    output.

    Args:
        model: Callable model taking images/mode/rel_pose_inputs/is_keyframe.
        images: (B, S, ...) frame tensor (any device; moved per segment).
        is_keyframe: Optional (B, S) bool mask, or None.
        keyframe_indices: Optional (B, S) long tensor of reference-keyframe
            indices in global frame coordinates, or None.
        mode: Attention mode forwarded to the model.
        keyframe_stride: Frames per keyframe interval.
        refresh: Refresh period in keyframes (must be >= 2).
        rel_pose_cfg: Optional config dict; when given (and keyframe info is
            available), rel_pose_inputs are assembled for the model.

    Returns:
        Dict of stitched per-frame tensors (on CPU) plus scalar outputs.
    """
    B, S = images.shape[:2]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    # Segment length includes one overlap frame carried over from the
    # previous segment; the stride excludes it.
    frames_per_batch = refresh_intervals * keyframe_stride + 1
    step_frames = refresh_intervals * keyframe_stride

    stitched_tensors: Dict[str, List[torch.Tensor]] = {}
    stitched_scalars: Dict[str, Any] = {}
    num_batches = (S + step_frames - 1) // step_frames
    for batch_idx in range(num_batches):
        start_frame = batch_idx * step_frames
        end_frame = min(start_frame + frames_per_batch, S)
        batch_images = images[:, start_frame:end_frame].to(device, non_blocking=True)
        # Clone so per-segment edits below don't mutate the caller's tensors.
        batch_is_keyframe = (
            is_keyframe[:, start_frame:end_frame].clone()
            if is_keyframe is not None
            else None
        )
        batch_keyframe_indices = (
            keyframe_indices[:, start_frame:end_frame].clone()
            if keyframe_indices is not None
            else None
        )

        # The overlap frame anchors the new segment: force it to be a
        # keyframe referencing itself.
        if batch_idx > 0 and batch_is_keyframe is not None:
            batch_is_keyframe[:, 0] = True
            if batch_keyframe_indices is not None:
                batch_keyframe_indices[:, 0] = start_frame

        # Convert global keyframe indices to segment-local ones.
        if batch_keyframe_indices is not None:
            batch_keyframe_indices = batch_keyframe_indices - start_frame
            batch_keyframe_indices = torch.clamp(
                batch_keyframe_indices, 0, end_frame - start_frame - 1
            )

        batch_rel_pose_inputs = None
        if rel_pose_cfg is not None and batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)
            if batch_keyframe_indices is not None:
                batch_keyframe_indices = batch_keyframe_indices.to(
                    device, non_blocking=True
                )
            batch_rel_pose_inputs = {
                "is_keyframe": batch_is_keyframe,
                "keyframe_indices": batch_keyframe_indices,
                "num_iterations": rel_pose_cfg.get("num_iterations", 4),
            }
        elif batch_is_keyframe is not None:
            batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True)

        batch_output = model(
            images=batch_images,
            mode=mode,
            rel_pose_inputs=batch_rel_pose_inputs,
            is_keyframe=batch_is_keyframe,
        )

        # slice_start=1 drops the duplicated overlap frame for later segments.
        _append_batch_output(
            stitched_tensors,
            stitched_scalars,
            batch_output,
            actual_frames=end_frame - start_frame,
            slice_start=0 if batch_idx == 0 else 1,
        )
        # Free GPU memory eagerly between segments.
        del batch_output
        del batch_images
        del batch_is_keyframe
        del batch_keyframe_indices

    return _finalize_stitched_batches(stitched_tensors, stitched_scalars)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_streaming_refresh(
    model,
    images,
    is_keyframe,
    keyframe_indices,
    mode: str,
    window_size: int,
    refresh: int,
    rel_pose_cfg,
):
    """Run the model frame-by-frame through a StreamSession with periodic
    cache refreshes.

    Every ``refresh - 1`` keyframes, the session's KV caches are cleared and
    the current keyframe is immediately re-fed (referencing itself,
    ``record=False``) so it becomes the anchor of the new segment without
    duplicating predictions.

    Args:
        model: The streaming model wrapped by StreamSession.
        images: (B, S, ...) frame tensor (any device; moved per frame).
        is_keyframe: Optional (B, S) bool mask, or None.
        keyframe_indices: Optional (B, S) long tensor of reference-keyframe
            indices in global coordinates, or None.
        mode: Attention mode passed to the session.
        window_size: Attention window for the session.
        refresh: Refresh period in keyframes (must be >= 2).
        rel_pose_cfg: Currently unused here.
            NOTE(review): kept for signature parity with run_batch_refresh —
            confirm whether it should be forwarded to the session.

    Returns:
        The session's accumulated predictions (session.get_all_predictions()).
    """
    B, S = images.shape[:2]
    device = _model_device(model)
    refresh_intervals = _refresh_intervals(refresh)
    session = StreamSession(model, mode=mode, window_size=window_size)
    keyframe_count = 0
    # First frame of the current segment; local keyframe indices are
    # expressed relative to it.
    segment_start = 0
    for s in range(S):
        frame_images = images[:, s : s + 1].to(device, non_blocking=True)
        is_keyframe_s = (
            is_keyframe[:, s : s + 1].to(device, non_blocking=True)
            if is_keyframe is not None
            else None
        )
        if keyframe_indices is not None:
            # Re-base the global keyframe index onto the current segment.
            keyframe_indices_s = keyframe_indices[:, s : s + 1].clone() - segment_start
            keyframe_indices_s = torch.clamp(keyframe_indices_s, min=0)
            keyframe_indices_s = keyframe_indices_s.to(device, non_blocking=True)
        else:
            keyframe_indices_s = None
        session.forward_stream(
            frame_images,
            is_keyframe=is_keyframe_s,
            keyframe_indices=keyframe_indices_s,
            record=True,
        )
        # Non-keyframes (and frame 0) never trigger a refresh.
        # NOTE(review): is_keyframe_s.item() assumes batch size 1 here.
        if is_keyframe_s is None or not bool(is_keyframe_s.item()) or s <= 0:
            del frame_images
            if is_keyframe_s is not None:
                del is_keyframe_s
            if keyframe_indices_s is not None:
                del keyframe_indices_s
            continue
        keyframe_count += 1
        if keyframe_count % refresh_intervals == 0:
            # Drop cached context and re-anchor on the current keyframe.
            session.clear_cache_only()
            segment_start = s
            if keyframe_indices_s is not None:
                # The anchor keyframe references itself (local index 0).
                keyframe_indices_self = torch.zeros_like(keyframe_indices_s)
            else:
                keyframe_indices_self = None
            # record=False: this re-feed only rebuilds the cache; its
            # predictions are not added to the session output.
            session.forward_stream(
                frame_images,
                is_keyframe=is_keyframe_s,
                keyframe_indices=keyframe_indices_self,
                record=False,
            )
        del frame_images
        if is_keyframe_s is not None:
            del is_keyframe_s
        if keyframe_indices_s is not None:
            del keyframe_indices_s
    return session.get_all_predictions()
|
longstream/streaming/stream_session.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class StreamSession:
    """Stateful driver that runs the LongStream model frame-by-frame.

    The session keeps the transformer KV caches between calls so each new
    frame only attends to previously processed context. Two attention modes
    are supported:

    * ``"causal"`` — caches grow without bound; every frame sees all history.
    * ``"window"`` — caches are truncated to roughly the last ``window_size``
      frames, optionally always retaining the very first frame as an anchor.

    Per-frame outputs are moved to CPU and accumulated; retrieve them with
    :meth:`get_all_predictions` or :meth:`get_last_prediction`.
    """

    def __init__(
        self,
        model,
        mode: str,
        window_size: int = 5,
        keep_first_frame_anchor: bool = True,
    ):
        """Initialise the session and reset all caches and recorded outputs.

        Args:
            model: LongStream wrapper module; the core network is taken from
                ``model.longstream`` when present, otherwise ``model`` itself.
            mode: Attention mode, either ``"causal"`` or ``"window"``.
            window_size: Number of frames kept in the KV caches when
                ``mode == "window"``.
            keep_first_frame_anchor: In window mode, keep the first frame's
                cache entries in addition to the most recent window.

        Raises:
            ValueError: If ``mode`` is not ``"causal"`` or ``"window"``.
        """
        self.model = model
        self.core_model = getattr(model, "longstream", model)
        self.mode = mode
        self.window_size = window_size
        self.keep_first_frame_anchor = keep_first_frame_anchor

        if self.mode not in ["causal", "window"]:
            raise ValueError(f"Unsupported attention mode: {self.mode}")

        # One (key, value) cache slot per aggregator transformer layer.
        self.aggregator_kv_cache_depth = self.core_model.aggregator.depth
        self.use_camera_head = self.core_model.camera_head is not None
        if self.use_camera_head:
            self.camera_head_kv_cache_depth = self.core_model.camera_head.trunk_depth
            # NOTE(review): assumes the camera head runs 4 refinement
            # iterations, each with its own cache set — confirm against head.
            self.camera_head_iterations = 4
        else:
            self.camera_head_kv_cache_depth = 0
            self.camera_head_iterations = 0

        self.use_rel_pose_head = (
            hasattr(self.core_model, "rel_pose_head")
            and self.core_model.rel_pose_head is not None
        )
        if self.use_rel_pose_head:
            self.rel_pose_head_trunk_depth = self.core_model.rel_pose_head.trunk_depth
            # NOTE(review): mirrors the camera head's 4 iterations — confirm.
            self.rel_pose_head_iterations = 4

        self.clear()

    def _clear_predictions(self):
        """Drop all recorded per-frame and scalar outputs."""
        self.sequence_predictions = {}
        self.scalar_predictions = {}

    def _update_predictions(self, predictions):
        """Record the outputs of one forward pass.

        Sequence-shaped outputs are detached, moved to CPU, and appended as
        chunks; scalar outputs simply overwrite the previously stored value.
        """
        sequence_keys = [
            "pose_enc",
            "rel_pose_enc",
            "world_points",
            "world_points_conf",
            "depth",
            "depth_conf",
        ]
        scalar_keys = ["predicted_scale_factor", "global_scale"]

        for k in sequence_keys:
            if k in predictions:
                self.sequence_predictions.setdefault(k, []).append(
                    predictions[k].detach().cpu()
                )

        for k in scalar_keys:
            if k in predictions:
                value = predictions[k]
                self.scalar_predictions[k] = (
                    value.detach().cpu() if isinstance(value, torch.Tensor) else value
                )

    def _clear_cache(self):
        """Reset the KV caches of the aggregator and both pose heads.

        Cache layout: aggregator is ``[depth][key/value]``; the camera and
        relative-pose heads are ``[iteration][layer][key/value]``. Empty
        slots are ``None``.
        """
        self.aggregator_kv_cache_list = [
            [None, None] for _ in range(self.aggregator_kv_cache_depth)
        ]
        if self.use_camera_head:
            self.camera_head_kv_cache_list = [
                [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
                for _ in range(self.camera_head_iterations)
            ]
        else:
            self.camera_head_kv_cache_list = None
        if self.use_rel_pose_head:
            self.rel_pose_kv_cache_list = [
                [[None, None] for _ in range(self.rel_pose_head_trunk_depth)]
                for _ in range(self.rel_pose_head_iterations)
            ]
        else:
            self.rel_pose_kv_cache_list = None

    def _update_cache(
        self, aggregator_kv_cache_list, camera_head_kv_cache_list, frame_hw
    ):
        """Install the caches returned by the model, truncating in window mode.

        Args:
            aggregator_kv_cache_list: Updated aggregator caches from the model.
            camera_head_kv_cache_list: Updated camera-head caches, or None.
            frame_hw: (height, width) of the processed frame, used to compute
                the per-frame token count ``P`` for aggregator truncation.

        In causal mode the new caches are adopted as-is. In window mode the
        aggregator cache (``P`` tokens per frame) and the camera-head cache
        (1 token per frame) are cut down to the last ``window_size`` frames,
        with the first frame's tokens optionally kept as an anchor.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            if self.use_camera_head:
                self.camera_head_kv_cache_list = camera_head_kv_cache_list
            return

        if self.mode == "window":
            h, w = frame_hw
            # P = tokens contributed by one frame: patch tokens plus the
            # leading special tokens (patch_start_idx offset).
            P = (
                h
                * w
                // self.core_model.aggregator.patch_size
                // self.core_model.aggregator.patch_size
                + self.core_model.aggregator.patch_start_idx
            )

            for k in range(2):
                for i in range(self.aggregator_kv_cache_depth):
                    cache_size = aggregator_kv_cache_list[i][k].size(2)
                    if self.keep_first_frame_anchor:
                        if cache_size <= P:
                            # Only one frame cached so far — keep everything.
                            self.aggregator_kv_cache_list[i][
                                k
                            ] = aggregator_kv_cache_list[i][k].contiguous()
                        elif cache_size <= self.window_size * P:
                            # Still within the window budget — keep everything.
                            self.aggregator_kv_cache_list[i][
                                k
                            ] = aggregator_kv_cache_list[i][k].contiguous()
                        else:
                            # Anchor = first frame's P tokens; then the most
                            # recent (window_size - 1) frames.
                            anchor = aggregator_kv_cache_list[i][k][:, :, :P]
                            recent_start = cache_size - (self.window_size - 1) * P
                            recent = aggregator_kv_cache_list[i][k][:, :, recent_start:]
                            self.aggregator_kv_cache_list[i][k] = torch.cat(
                                [anchor, recent], dim=2
                            ).contiguous()
                    else:
                        # Plain sliding window: last window_size frames only.
                        start_idx = max(0, cache_size - self.window_size * P)
                        self.aggregator_kv_cache_list[i][k] = aggregator_kv_cache_list[
                            i
                        ][k][:, :, start_idx:].contiguous()

            if camera_head_kv_cache_list is not None:
                # Camera-head caches hold one token per frame, so the same
                # truncation logic applies with P == 1.
                for k in range(2):
                    for i in range(self.camera_head_iterations):
                        for j in range(self.camera_head_kv_cache_depth):
                            cache_size = camera_head_kv_cache_list[i][j][k].size(2)
                            if self.keep_first_frame_anchor:
                                if cache_size <= 1:
                                    self.camera_head_kv_cache_list[i][j][
                                        k
                                    ] = camera_head_kv_cache_list[i][j][k].contiguous()
                                elif cache_size <= self.window_size:
                                    self.camera_head_kv_cache_list[i][j][
                                        k
                                    ] = camera_head_kv_cache_list[i][j][k].contiguous()
                                else:
                                    anchor = camera_head_kv_cache_list[i][j][k][
                                        :, :, :1
                                    ]
                                    recent_start = cache_size - (self.window_size - 1)
                                    recent = camera_head_kv_cache_list[i][j][k][
                                        :, :, recent_start:
                                    ]
                                    self.camera_head_kv_cache_list[i][j][k] = torch.cat(
                                        [anchor, recent], dim=2
                                    ).contiguous()
                            else:
                                start_idx = max(0, cache_size - self.window_size)
                                self.camera_head_kv_cache_list[i][j][
                                    k
                                ] = camera_head_kv_cache_list[i][j][k][
                                    :, :, start_idx:
                                ].contiguous()
            return

        raise ValueError(f"Unsupported attention mode: {self.mode}")

    def _get_cache(self):
        """Return the current (aggregator, camera-head) cache lists."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return all recorded outputs, concatenating chunks along dim 1.

        Sequence outputs are shaped ``[B, S_total, ...]`` after concatenation;
        scalar outputs are merged in unchanged.
        """
        predictions = {}
        for key, chunks in self.sequence_predictions.items():
            if not chunks:
                continue
            predictions[key] = (
                chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1)
            )
        predictions.update(self.scalar_predictions)
        return predictions

    def get_last_prediction(self):
        """Return only the most recent frame's outputs (``[:, -1:]`` slices)."""
        last_predictions = {}
        keys_to_extract = [
            "pose_enc",
            "rel_pose_enc",
            "world_points",
            "world_points_conf",
            "depth",
            "depth_conf",
            "predicted_scale_factor",
        ]
        for k in keys_to_extract:
            if k in self.sequence_predictions and self.sequence_predictions[k]:
                last_predictions[k] = self.sequence_predictions[k][-1][:, -1:]
            elif k in self.scalar_predictions:
                last_predictions[k] = self.scalar_predictions[k]
        return last_predictions

    def clear(self):
        """Reset recorded predictions, KV caches, and rel-pose head state."""
        self._clear_predictions()
        self._clear_cache()
        if self.use_rel_pose_head:
            # The rel-pose head keeps its own per-frame bookkeeping; reset the
            # attributes defensively (they may not exist on every variant).
            if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"):
                self.core_model.rel_pose_head._keyframe_tokens_cache = {}
            if hasattr(self.core_model.rel_pose_head, "_current_frame_id"):
                self.core_model.rel_pose_head._current_frame_id = 0
            if hasattr(self.core_model.rel_pose_head, "_frame_info"):
                self.core_model.rel_pose_head._frame_info = []

    def clear_cache_only(self):
        """Reset KV caches and rel-pose head state but keep recorded outputs.

        Used by the streaming refresh logic to restart attention context at a
        new segment without discarding predictions made so far.
        """
        self._clear_cache()
        if self.use_rel_pose_head:
            if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"):
                self.core_model.rel_pose_head._keyframe_tokens_cache = {}
            if hasattr(self.core_model.rel_pose_head, "_current_frame_id"):
                self.core_model.rel_pose_head._current_frame_id = 0
            if hasattr(self.core_model.rel_pose_head, "_frame_info"):
                self.core_model.rel_pose_head._frame_info = []

    def forward_stream(
        self, images, is_keyframe=None, keyframe_indices=None, record: bool = True
    ):
        """Process one frame (or small chunk) and update session state.

        Args:
            images: Frame tensor for this step (passed through to the model).
            is_keyframe: Optional per-frame keyframe flags.
            keyframe_indices: Optional reference-frame indices for the
                relative-pose head.
            record: When True, the outputs are appended to the session's
                prediction history; when False, only the caches are updated.

        Returns:
            The raw model output dict for this step.
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()

        # The rel-pose head only participates when both keyframe signals are
        # provided for this step.
        rel_pose_inputs = None
        if (
            self.use_rel_pose_head
            and is_keyframe is not None
            and keyframe_indices is not None
        ):
            rel_pose_inputs = {
                "is_keyframe": is_keyframe,
                "keyframe_indices": keyframe_indices,
                "kv_cache_list": self.rel_pose_kv_cache_list,
            }

        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
            rel_pose_inputs=rel_pose_inputs,
            is_keyframe=is_keyframe,
        )

        if record:
            self._update_predictions(outputs)

        camera_head_kv_cache_list = outputs.get("camera_head_kv_cache_list", None)
        # Prefer the predicted depth map's spatial size for the per-frame
        # token count; fall back to the input image size.
        depth_hw = (
            outputs["depth"].shape[2:4] if "depth" in outputs else images.shape[-2:]
        )
        self._update_cache(
            outputs["aggregator_kv_cache_list"], camera_head_kv_cache_list, depth_hw
        )

        if self.use_rel_pose_head and "rel_pose_kv_cache_list" in outputs:
            # Same adopt/truncate policy as the camera head (1 token per frame).
            rel_pose_kv_cache = outputs["rel_pose_kv_cache_list"]
            if self.mode == "causal":
                self.rel_pose_kv_cache_list = rel_pose_kv_cache
            elif self.mode == "window":
                for k in range(2):
                    for i in range(self.rel_pose_head_iterations):
                        for j in range(self.rel_pose_head_trunk_depth):
                            if rel_pose_kv_cache[i][j][k] is None:
                                continue
                            cache_len = rel_pose_kv_cache[i][j][k].size(2)
                            if self.keep_first_frame_anchor:
                                if cache_len <= 1:
                                    self.rel_pose_kv_cache_list[i][j][
                                        k
                                    ] = rel_pose_kv_cache[i][j][k].contiguous()
                                elif cache_len <= self.window_size:
                                    self.rel_pose_kv_cache_list[i][j][
                                        k
                                    ] = rel_pose_kv_cache[i][j][k].contiguous()
                                else:
                                    anchor = rel_pose_kv_cache[i][j][k][:, :, :1]
                                    recent_start = cache_len - (self.window_size - 1)
                                    recent = rel_pose_kv_cache[i][j][k][
                                        :, :, recent_start:
                                    ]
                                    self.rel_pose_kv_cache_list[i][j][k] = torch.cat(
                                        [anchor, recent], dim=2
                                    ).contiguous()
                            else:
                                start_idx = max(0, cache_len - self.window_size)
                                self.rel_pose_kv_cache_list[i][j][
                                    k
                                ] = rel_pose_kv_cache[i][j][k][
                                    :, :, start_idx:
                                ].contiguous()

        return outputs
|
longstream/utils/__init__.py
ADDED
|
File without changes
|
longstream/utils/camera.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from longstream.utils.vendor.models.components.utils.rotation import (
|
| 3 |
+
quat_to_mat,
|
| 4 |
+
mat_to_quat,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def compose_abs_from_rel(
    rel_pose_enc: torch.Tensor, keyframe_indices: torch.Tensor
) -> torch.Tensor:
    """Chain relative pose encodings into absolute pose encodings.

    Each encoding packs translation (3), quaternion rotation (4), and a
    2-vector focal term. Frame 0's relative pose is taken as its absolute
    pose; every later frame ``s`` is composed on top of the absolute pose of
    its reference frame ``keyframe_indices[b, s]``.

    Accepts ``[B, S, D]`` / ``[B, S]`` or unbatched ``[S, D]`` / ``[S]``
    inputs; the output matches the input batching.
    """
    had_no_batch = rel_pose_enc.ndim == 2
    if had_no_batch:
        rel_pose_enc = rel_pose_enc.unsqueeze(0)
    if keyframe_indices.ndim == 1:
        keyframe_indices = keyframe_indices.unsqueeze(0)
    if rel_pose_enc.ndim != 3 or keyframe_indices.ndim != 2:
        raise ValueError(
            f"Expected rel_pose_enc [B,S,D] or [S,D] and keyframe_indices [B,S] or [S], "
            f"got {tuple(rel_pose_enc.shape)} and {tuple(keyframe_indices.shape)}"
        )

    batch, seq, _ = rel_pose_enc.shape
    dev = rel_pose_enc.device
    dt = rel_pose_enc.dtype

    rel_t = rel_pose_enc[..., :3]
    rel_q = rel_pose_enc[..., 3:7]
    rel_f = rel_pose_enc[..., 7:9]
    rel_R = quat_to_mat(rel_q.reshape(-1, 4)).reshape(batch, seq, 3, 3)

    abs_R = torch.zeros(batch, seq, 3, 3, device=dev, dtype=dt)
    abs_t = torch.zeros(batch, seq, 3, device=dev, dtype=dt)

    # Sequential composition: each frame's reference may be any earlier frame,
    # so this recurrence cannot be trivially vectorised over the sequence.
    for b in range(batch):
        abs_R[b, 0] = rel_R[b, 0]
        abs_t[b, 0] = rel_t[b, 0]
        for s in range(1, seq):
            ref = int(keyframe_indices[b, s].item())
            abs_R[b, s] = rel_R[b, s] @ abs_R[b, ref]
            abs_t[b, s] = rel_t[b, s] + rel_R[b, s] @ abs_t[b, ref]

    # Focal terms are per-frame and pass through unchanged.
    abs_f = rel_f.clone()
    abs_q = mat_to_quat(abs_R.reshape(-1, 3, 3)).reshape(batch, seq, 4)
    abs_pose_enc = torch.cat([abs_t, abs_q, abs_f], dim=-1)
    return abs_pose_enc[0] if had_no_batch else abs_pose_enc
|
longstream/utils/depth.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import matplotlib.cm as cm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def colorize_depth(depth: torch.Tensor, cmap: str = "plasma") -> np.ndarray:
    """Map a depth map to an RGB uint8 image with a matplotlib colormap.

    Args:
        depth: Depth values as a torch tensor or numpy array; NaNs are
            ignored when computing the normalisation range.
        cmap: Name of a matplotlib colormap.

    Returns:
        uint8 RGB array with the same leading shape as ``depth`` plus a
        trailing channel dimension of 3.
    """
    if torch.is_tensor(depth):
        depth_np = depth.detach().cpu().numpy()
    else:
        depth_np = depth
    d_min = np.nanmin(depth_np)
    d_max = np.nanmax(depth_np)
    # Guard against a zero range on constant depth maps.
    if d_max - d_min < 1e-6:
        d_max = d_min + 1e-6
    norm = (depth_np - d_min) / (d_max - d_min)
    norm = np.clip(norm, 0.0, 1.0)
    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; use
    # the matplotlib.colormaps registry when available, falling back to the
    # legacy accessor on old installs.
    try:
        from matplotlib import colormaps

        mapper = colormaps[cmap]
    except ImportError:
        mapper = cm.get_cmap(cmap)
    colored = mapper(norm)[..., :3]
    return (colored * 255.0).astype(np.uint8)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def unproject_depth_to_points(depth: torch.Tensor, intri: torch.Tensor) -> torch.Tensor:
    """Back-project per-pixel depths into camera-space 3D points.

    Args:
        depth: ``[B, H, W]`` depth values.
        intri: ``[B, 3, 3]`` pinhole intrinsic matrices.

    Returns:
        ``[B, H, W, 3]`` tensor of (x, y, z) points in the camera frame,
        using the pinhole model x = (u - cx) * z / fx, y = (v - cy) * z / fy.
    """
    num, height, width = depth.shape
    focal_x = intri[:, 0, 0].view(num, 1, 1)
    focal_y = intri[:, 1, 1].view(num, 1, 1)
    center_x = intri[:, 0, 2].view(num, 1, 1)
    center_y = intri[:, 1, 2].view(num, 1, 1)

    # Pixel-coordinate grids shaped for broadcasting against [B, H, W].
    row_idx = torch.arange(height, device=depth.device, dtype=torch.float32)
    col_idx = torch.arange(width, device=depth.device, dtype=torch.float32)
    row_idx = row_idx.view(1, height, 1)
    col_idx = col_idx.view(1, 1, width)

    cam_x = (col_idx - center_x) * depth / focal_x
    cam_y = (row_idx - center_y) * depth / focal_y
    return torch.stack([cam_x, cam_y, depth], dim=-1)
|
longstream/utils/hub.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class HFSpec:
    """Specification of a checkpoint file hosted on the Hugging Face Hub."""

    # Hub repository identifier, e.g. "user/model".
    repo_id: str
    # File inside the repository to download.
    filename: str
    # Git revision (branch, tag, or commit hash); None means the default branch.
    revision: Optional[str] = None
    # Local directory where the downloaded file is stored.
    local_dir: str = "checkpoints"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _is_nonempty_str(x) -> bool:
|
| 15 |
+
return isinstance(x, str) and len(x) > 0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def resolve_checkpoint_path(
    checkpoint: Optional[str], hf: Optional[dict]
) -> Optional[str]:
    """Resolve a checkpoint path, downloading from the HF Hub if necessary.

    Args:
        checkpoint: Explicit local path; returned as-is when non-empty.
        hf: Optional dict with "repo_id" and "filename" (and optionally
            "revision" and "local_dir") describing a Hub-hosted checkpoint.

    Returns:
        A local file path, or None when nothing usable was provided.

    Raises:
        RuntimeError: If a Hub download is requested but huggingface_hub
            is not installed.
    """
    # An explicit local path always wins over the Hub spec.
    if _is_nonempty_str(checkpoint):
        return checkpoint
    if not isinstance(hf, dict):
        return None

    repo_id = hf.get("repo_id")
    filename = hf.get("filename")
    if not (_is_nonempty_str(repo_id) and _is_nonempty_str(filename)):
        return None

    revision = hf.get("revision", None)
    local_dir = hf.get("local_dir", "checkpoints")

    # Imported lazily so the package works without huggingface_hub when a
    # local checkpoint path is supplied.
    try:
        from huggingface_hub import hf_hub_download
    except Exception as e:
        raise RuntimeError("huggingface_hub is required for auto-download") from e

    os.makedirs(local_dir, exist_ok=True)
    return hf_hub_download(
        repo_id=repo_id, filename=filename, revision=revision, local_dir=local_dir
    )
|
longstream/utils/sky_mask.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import shutil
|
| 6 |
+
import urllib.request
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import onnxruntime
|
| 10 |
+
except Exception:
|
| 11 |
+
onnxruntime = None
|
| 12 |
+
|
| 13 |
+
SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
|
| 14 |
+
SKYSEG_THRESHOLD = 0.5
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def run_skyseg(session, input_size, image):
    """Run the skyseg ONNX model on one BGR image.

    Args:
        session: onnxruntime.InferenceSession loaded with skyseg.onnx.
        input_size: (width, height) expected by the network, in cv2.resize
            ``dsize`` order.
        image: BGR uint8 image as returned by ``cv2.imread``.

    Returns:
        2D float map of per-pixel scores at the network's input resolution.
    """
    # cv2.resize never modifies its input, so the previous deepcopy of the
    # image was pure overhead.
    resized = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet normalisation.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    # Add the batch dim from the actual array shape rather than
    # reshape(-1, 3, input_size[0], input_size[1]), which silently swapped
    # height/width for non-square input sizes (dsize is (w, h)).
    x = x[np.newaxis].astype("float32")
    input_name = session.get_inputs()[0].name
    result_map = session.run(None, {input_name: x})[0]
    return result_map[0, 0]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _normalize_skyseg_output(result_map):
|
| 33 |
+
result_map = np.asarray(result_map, dtype=np.float32)
|
| 34 |
+
if result_map.size == 0:
|
| 35 |
+
return result_map
|
| 36 |
+
finite = np.isfinite(result_map)
|
| 37 |
+
if not np.any(finite):
|
| 38 |
+
return np.zeros_like(result_map, dtype=np.float32)
|
| 39 |
+
result_map = np.nan_to_num(result_map, nan=0.0, posinf=1.0, neginf=0.0)
|
| 40 |
+
max_value = float(result_map.max())
|
| 41 |
+
min_value = float(result_map.min())
|
| 42 |
+
if min_value >= 0.0 and max_value > 1.5:
|
| 43 |
+
result_map = result_map / 255.0
|
| 44 |
+
return np.clip(result_map, 0.0, 1.0)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def sky_mask_filename(image_path):
    """Build a cache filename for an image's sky mask.

    The immediate parent directory name is prefixed so identically named
    frames from different sequences do not collide in the mask cache.
    """
    directory, name = os.path.split(image_path)
    parent = os.path.basename(directory)
    return f"{parent}__{name}" if parent else name
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def segment_sky(image_path, session, mask_filename=None):
    """Segment sky in a single image with the skyseg ONNX model.

    Returns a uint8 mask at the original image resolution, set to 255 where
    the skyseg score falls below SKYSEG_THRESHOLD (presumably the non-sky
    region — confirm against the model's output convention), or None when
    the image cannot be read. If *mask_filename* is given, the mask is also
    written to disk, creating the parent directory on demand.
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    score_map = run_skyseg(session, [320, 320], image)
    # Resize the network-resolution scores back to the source image size.
    score_map = cv2.resize(score_map, (image.shape[1], image.shape[0]))
    score_map = _normalize_skyseg_output(score_map)
    mask = np.where(score_map < SKYSEG_THRESHOLD, 255, 0).astype(np.uint8)
    if mask_filename is not None:
        os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
        cv2.imwrite(mask_filename, mask)
    return mask
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_sky_mask(image_paths, model_path: str, target_dir: "str | None" = None):
    """Compute sky masks for a list of images with the skyseg ONNX model.

    Downloads skyseg.onnx to *model_path* if it is missing. When
    *target_dir* is given, masks are cached there and reloaded on later
    calls instead of being recomputed.

    Args:
        image_paths: Paths of images to process.
        model_path: Local path of the skyseg ONNX weights.
        target_dir: Optional directory used as a per-image mask cache.

    Returns:
        A list of masks (one per image path; entries may be None when an
        image cannot be read), or None when onnxruntime is unavailable or
        the model cannot be obtained.
    """
    # Sky masking is best-effort: without onnxruntime we silently skip it.
    if onnxruntime is None:
        return None
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
        try:
            print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
            with urllib.request.urlopen(SKYSEG_URL) as src, open(
                model_path, "wb"
            ) as dst:
                shutil.copyfileobj(src, dst)
        except Exception as exc:
            # Download failure is non-fatal for the pipeline; report and skip.
            print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
            return None
    if not os.path.exists(model_path):
        return None
    session = onnxruntime.InferenceSession(model_path)
    masks = []
    for image_path in image_paths:
        mask_filepath = None
        if target_dir is not None:
            name = sky_mask_filename(image_path)
            mask_filepath = os.path.join(target_dir, name)
            if os.path.exists(mask_filepath):
                # Cache hit: reuse the previously written mask.
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
            else:
                sky_mask = segment_sky(image_path, session, mask_filepath)
        else:
            sky_mask = segment_sky(image_path, session, None)
        masks.append(sky_mask)
    return masks
|
longstream/utils/vendor/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
longstream/utils/vendor/croco/LICENSE
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
|
| 11 |
+
|
| 12 |
+
***************************
|
| 13 |
+
|
| 14 |
+
NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
|
| 15 |
+
|
| 16 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
| 17 |
+
|
| 18 |
+
https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 19 |
+
|
| 20 |
+
This software in this file incorporates parts of the following software available here:
|
| 21 |
+
|
| 22 |
+
Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
|
| 23 |
+
available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
|
| 24 |
+
|
| 25 |
+
MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 26 |
+
available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
|
| 27 |
+
|
| 28 |
+
DeiT: https://github.com/facebookresearch/deit
|
| 29 |
+
available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
|
| 33 |
+
|
| 34 |
+
https://github.com/facebookresearch/mae/blob/main/LICENSE
|
| 35 |
+
|
| 36 |
+
Attribution-NonCommercial 4.0 International
|
| 37 |
+
|
| 38 |
+
***************************
|
| 39 |
+
|
| 40 |
+
NOTICE WITH RESPECT TO THE FILE: models/blocks.py
|
| 41 |
+
|
| 42 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
| 43 |
+
|
| 44 |
+
https://github.com/rwightman/pytorch-image-models
|
| 45 |
+
|
| 46 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
|
| 47 |
+
|
| 48 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
|
| 49 |
+
|
| 50 |
+
Apache License
|
| 51 |
+
Version 2.0, January 2004
|
| 52 |
+
http://www.apache.org/licenses/
|
longstream/utils/vendor/croco/NOTICE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CroCo
|
| 2 |
+
Copyright 2022-present NAVER Corp.
|
| 3 |
+
|
| 4 |
+
This project contains subcomponents with separate copyright notices and license terms.
|
| 5 |
+
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
|
| 6 |
+
|
| 7 |
+
====
|
| 8 |
+
|
| 9 |
+
facebookresearch/mae
|
| 10 |
+
https://github.com/facebookresearch/mae
|
| 11 |
+
|
| 12 |
+
Attribution-NonCommercial 4.0 International
|
| 13 |
+
|
| 14 |
+
====
|
| 15 |
+
|
| 16 |
+
rwightman/pytorch-image-models
|
| 17 |
+
https://github.com/rwightman/pytorch-image-models
|
| 18 |
+
|
| 19 |
+
Apache License
|
| 20 |
+
Version 2.0, January 2004
|
| 21 |
+
http://www.apache.org/licenses/
|
longstream/utils/vendor/croco/README.MD
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
|
| 2 |
+
|
| 3 |
+
[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
|
| 4 |
+
|
| 5 |
+
This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
```bibtex
|
| 10 |
+
@inproceedings{croco,
|
| 11 |
+
title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
|
| 12 |
+
author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
|
| 13 |
+
booktitle={{NeurIPS}},
|
| 14 |
+
year={2022}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
@inproceedings{croco_v2,
|
| 18 |
+
title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
|
| 19 |
+
author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
|
| 20 |
+
booktitle={ICCV},
|
| 21 |
+
year={2023}
|
| 22 |
+
}
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## License
|
| 26 |
+
|
| 27 |
+
The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
|
| 28 |
+
Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
|
| 29 |
+
Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
|
| 30 |
+
|
| 31 |
+
## Preparation
|
| 32 |
+
|
| 33 |
+
1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
conda create -n croco python=3.7 cmake=3.14.0
|
| 37 |
+
conda activate croco
|
| 38 |
+
conda install habitat-sim headless -c conda-forge -c aihabitat
|
| 39 |
+
conda install pytorch torchvision -c pytorch
|
| 40 |
+
conda install notebook ipykernel matplotlib
|
| 41 |
+
conda install ipywidgets widgetsnbextension
|
| 42 |
+
conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
2. Compile cuda kernels for RoPE
|
| 47 |
+
|
| 48 |
+
CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
|
| 49 |
+
```bash
|
| 50 |
+
cd models/curope/
|
| 51 |
+
python setup.py build_ext --inplace
|
| 52 |
+
cd ../../
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
|
| 56 |
+
You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.
|
| 57 |
+
|
| 58 |
+
In case you cannot compile the cuda kernels, we also provide a slower pure-pytorch implementation, which will be loaded automatically as a fallback.
|
| 59 |
+
|
| 60 |
+
3. Download pre-trained model
|
| 61 |
+
|
| 62 |
+
We provide several pre-trained models:
|
| 63 |
+
|
| 64 |
+
| modelname | pre-training data | pos. embed. | Encoder | Decoder |
|
| 65 |
+
|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
|
| 66 |
+
| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
|
| 67 |
+
| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
|
| 68 |
+
| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
|
| 69 |
+
| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
|
| 70 |
+
|
| 71 |
+
To download a specific model, i.e., the first one (`CroCo.pth`)
|
| 72 |
+
```bash
|
| 73 |
+
mkdir -p pretrained_models/
|
| 74 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Reconstruction example
|
| 78 |
+
|
| 79 |
+
Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
|
| 80 |
+
```bash
|
| 81 |
+
python demo.py
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
|
| 85 |
+
|
| 86 |
+
First download the test scene from Habitat:
|
| 87 |
+
```bash
|
| 88 |
+
python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Then, run the Notebook demo `interactive_demo.ipynb`.
|
| 92 |
+
|
| 93 |
+
In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.
|
| 94 |
+

|
| 95 |
+
|
| 96 |
+
## Pre-training
|
| 97 |
+
|
| 98 |
+
### CroCo
|
| 99 |
+
|
| 100 |
+
To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
|
| 101 |
+
```
|
| 102 |
+
torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Our CroCo pre-training was launched on a single server with 4 GPUs.
|
| 106 |
+
It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
|
| 107 |
+
Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to verify that it remains valid in our case.
|
| 108 |
+
The first run can take a few minutes to start, to parse all available pre-training pairs.
|
| 109 |
+
|
| 110 |
+
### CroCo v2
|
| 111 |
+
|
| 112 |
+
For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
|
| 113 |
+
Then, run the following command for the largest model (ViT-L encoder, Base decoder):
|
| 114 |
+
```
|
| 115 |
+
torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
|
| 119 |
+
The largest model should take around 12 days on A100.
|
| 120 |
+
Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to verify that it remains valid in our case.
|
| 121 |
+
|
| 122 |
+
## Stereo matching and Optical flow downstream tasks
|
| 123 |
+
|
| 124 |
+
For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
|
longstream/utils/vendor/croco/assets/arch.jpg
ADDED
|
longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "9bca0f41",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": []
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"cell_type": "code",
|
| 11 |
+
"execution_count": null,
|
| 12 |
+
"id": "80653ef7",
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": []
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"id": "4f033862",
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"source": [
|
| 22 |
+
"First download the model(s) of your choice by running\n",
|
| 23 |
+
"```\n",
|
| 24 |
+
"bash stereoflow/download_model.sh crocostereo.pth\n",
|
| 25 |
+
"bash stereoflow/download_model.sh crocoflow.pth\n",
|
| 26 |
+
"```"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": null,
|
| 32 |
+
"id": "1fb2e392",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": [
|
| 36 |
+
"import torch\n",
|
| 37 |
+
"use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
|
| 38 |
+
"device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
|
| 39 |
+
"import matplotlib.pylab as plt"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "e0e25d77",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"from stereoflow.test import _load_model_and_criterion\n",
|
| 50 |
+
"from stereoflow.engine import tiled_pred\n",
|
| 51 |
+
"from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
|
| 52 |
+
"from stereoflow.datasets_flow import flowToColor\n",
|
| 53 |
+
"tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "markdown",
|
| 58 |
+
"id": "86a921f5",
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"source": []
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": null,
|
| 65 |
+
"id": "64e483cb",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
|
| 70 |
+
"image2 = np.asarray(Image.open('<path_to_right_image>'))"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "code",
|
| 75 |
+
"execution_count": null,
|
| 76 |
+
"id": "f0d04303",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"id": "47dc14b5",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
|
| 91 |
+
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
|
| 92 |
+
"with torch.inference_mode():\n",
|
| 93 |
+
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
|
| 94 |
+
"pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": null,
|
| 100 |
+
"id": "583b9f16",
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"plt.imshow(vis_disparity(pred))\n",
|
| 105 |
+
"plt.axis('off')"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"id": "d2df5d70",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"source": []
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "code",
|
| 116 |
+
"execution_count": null,
|
| 117 |
+
"id": "9ee257a7",
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [],
|
| 120 |
+
"source": [
|
| 121 |
+
"image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
|
| 122 |
+
"image2 = np.asarray(Image.open('<path_to_second_image>'))"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "code",
|
| 127 |
+
"execution_count": null,
|
| 128 |
+
"id": "d5edccf0",
|
| 129 |
+
"metadata": {},
|
| 130 |
+
"outputs": [],
|
| 131 |
+
"source": [
|
| 132 |
+
"model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": null,
|
| 138 |
+
"id": "b19692c3",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [],
|
| 141 |
+
"source": [
|
| 142 |
+
"im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
|
| 143 |
+
"im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
|
| 144 |
+
"with torch.inference_mode():\n",
|
| 145 |
+
" pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
|
| 146 |
+
"pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": null,
|
| 152 |
+
"id": "26f79db3",
|
| 153 |
+
"metadata": {},
|
| 154 |
+
"outputs": [],
|
| 155 |
+
"source": [
|
| 156 |
+
"plt.imshow(flowToColor(pred))\n",
|
| 157 |
+
"plt.axis('off')"
|
| 158 |
+
]
|
| 159 |
+
}
|
| 160 |
+
],
|
| 161 |
+
"metadata": {
|
| 162 |
+
"kernelspec": {
|
| 163 |
+
"display_name": "Python 3 (ipykernel)",
|
| 164 |
+
"language": "python",
|
| 165 |
+
"name": "python3"
|
| 166 |
+
},
|
| 167 |
+
"language_info": {
|
| 168 |
+
"codemirror_mode": {
|
| 169 |
+
"name": "ipython",
|
| 170 |
+
"version": 3
|
| 171 |
+
},
|
| 172 |
+
"file_extension": ".py",
|
| 173 |
+
"mimetype": "text/x-python",
|
| 174 |
+
"name": "python",
|
| 175 |
+
"nbconvert_exporter": "python",
|
| 176 |
+
"pygments_lexer": "ipython3",
|
| 177 |
+
"version": "3.9.7"
|
| 178 |
+
}
|
| 179 |
+
},
|
| 180 |
+
"nbformat": 4,
|
| 181 |
+
"nbformat_minor": 5
|
| 182 |
+
}
|
longstream/utils/vendor/croco/datasets/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
longstream/utils/vendor/croco/datasets/crops/README.MD
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of crops from the real datasets
|
| 2 |
+
|
| 3 |
+
The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
|
| 4 |
+
|
| 5 |
+
### Download the metadata of the crops to generate
|
| 6 |
+
|
| 7 |
+
First, download the metadata and put them in `./data/`:
|
| 8 |
+
```
|
| 9 |
+
mkdir -p data
|
| 10 |
+
cd data/
|
| 11 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
|
| 12 |
+
unzip crop_metadata.zip
|
| 13 |
+
rm crop_metadata.zip
|
| 14 |
+
cd ..
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### Prepare the original datasets
|
| 18 |
+
|
| 19 |
+
Second, download the original datasets in `./data/original_datasets/`.
|
| 20 |
+
```
|
| 21 |
+
mkdir -p data/original_datasets
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
##### ARKitScenes
|
| 25 |
+
|
| 26 |
+
Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
|
| 27 |
+
The resulting file structure should be like:
|
| 28 |
+
```
|
| 29 |
+
./data/original_datasets/ARKitScenes/
|
| 30 |
+
└───Training
|
| 31 |
+
└───40753679
|
| 32 |
+
│ │ ultrawide
|
| 33 |
+
│ │ ...
|
| 34 |
+
└───40753686
|
| 35 |
+
│
|
| 36 |
+
...
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
##### MegaDepth
|
| 40 |
+
|
| 41 |
+
Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
|
| 42 |
+
The resulting file structure should be like:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
./data/original_datasets/MegaDepth/
|
| 46 |
+
└───0000
|
| 47 |
+
│ └───images
|
| 48 |
+
│ │ │ 1000557903_87fa96b8a4_o.jpg
|
| 49 |
+
│ │ └ ...
|
| 50 |
+
│ └─── ...
|
| 51 |
+
└───0001
|
| 52 |
+
│ │
|
| 53 |
+
│ └ ...
|
| 54 |
+
└─── ...
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
##### 3DStreetView
|
| 58 |
+
|
| 59 |
+
Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
|
| 60 |
+
The resulting file structure should be like:
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
./data/original_datasets/3DStreetView/
|
| 64 |
+
└───dataset_aligned
|
| 65 |
+
│ └───0002
|
| 66 |
+
│ │ │ 0000002_0000001_0000002_0000001.jpg
|
| 67 |
+
│ │ └ ...
|
| 68 |
+
│ └─── ...
|
| 69 |
+
└───dataset_unaligned
|
| 70 |
+
│ └───0003
|
| 71 |
+
│ │ │ 0000003_0000001_0000002_0000001.jpg
|
| 72 |
+
│ │ └ ...
|
| 73 |
+
│ └─── ...
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
##### IndoorVL
|
| 77 |
+
|
| 78 |
+
Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
pip install kapture
|
| 82 |
+
mkdir -p ./data/original_datasets/IndoorVL
|
| 83 |
+
cd ./data/original_datasets/IndoorVL
|
| 84 |
+
kapture_download_dataset.py update
|
| 85 |
+
kapture_download_dataset.py install "HyundaiDepartmentStore_*"
|
| 86 |
+
kapture_download_dataset.py install "GangnamStation_*"
|
| 87 |
+
cd -
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Extract the crops
|
| 91 |
+
|
| 92 |
+
Now, extract the crops for each of the dataset:
|
| 93 |
+
```
|
| 94 |
+
for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
|
| 95 |
+
do
|
| 96 |
+
python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
|
| 97 |
+
done
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
##### Note for IndoorVL
|
| 101 |
+
|
| 102 |
+
Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
|
| 103 |
+
To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
|
| 104 |
+
The impact on the performance is negligible.
|
longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import functools
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from multiprocessing import Pool
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def arg_parser():
    """Build the command-line parser for the crop-extraction script.

    Returns:
        argparse.ArgumentParser configured with all options used by main().
    """
    p = argparse.ArgumentParser("Generate cropped image pairs from image crop list")

    p.add_argument("--crops", type=str, required=True, help="crop file")
    p.add_argument("--root-dir", type=str, required=True, help="root directory")
    p.add_argument(
        "--output-dir", type=str, required=True, help="output directory"
    )
    p.add_argument("--imsize", type=int, default=256, help="size of the crops")
    p.add_argument(
        "--nthread", type=int, required=True, help="number of simultaneous threads"
    )
    p.add_argument(
        "--max-subdir-levels",
        type=int,
        default=5,
        help="maximum number of subdirectories",
    )
    p.add_argument(
        "--ideal-number-pairs-in-dir",
        type=int,
        default=500,
        help="number of pairs stored in a dir",
    )
    return p
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main(args):
    """Extract every crop listed in ``args.crops`` and write a listing file.

    Reads the crop metadata, assigns output paths, then processes each
    image pair (optionally in parallel) and records the generated pair
    paths in ``<output_dir>/listing.txt``.

    Args:
        args: parsed CLI namespace from arg_parser().
    """
    listing_path = os.path.join(args.output_dir, "listing.txt")

    print(f"Loading list of crops ... ({args.nthread} threads)")
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
    # Choose enough directory levels so each directory holds roughly
    # args.ideal_number_pairs_in_dir entries, capped at max_subdir_levels.
    num_levels = min(
        math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)),
        args.max_subdir_levels,
    )
    num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops  # free the parsed metadata before forking worker processes

    os.makedirs(args.output_dir, exist_ok=True)
    # BUGFIX: the worker pool was previously created without ever being
    # closed/joined, leaking worker processes. Keep a handle and clean up.
    pool = Pool(args.nthread) if args.nthread > 1 else None
    mmap = pool.imap_unordered if pool is not None else map
    call = functools.partial(save_image_crops, args)

    print(f"Generating cropped images to {args.output_dir} ...")
    try:
        with open(listing_path, "w") as listing:
            listing.write("# pair_path\n")
            for results in tqdm(mmap(call, jobs), total=len(jobs)):
                for path in results:
                    listing.write(f"{path}\n")
    finally:
        if pool is not None:
            pool.close()
            pool.join()
    print("Finished writing listing to", listing_path)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_crop_file(path):
    """Parse a crop metadata file.

    The file alternates pair-header lines ("img1, img2, rotation") with
    crop lines of 8 integers ("l1, r1, t1, b1, l2, r2, t2, b2") that attach
    to the most recent pair. Lines starting with '#' are comments.

    Args:
        path: path to the crop metadata text file.

    Returns:
        (pairs, num_crops_to_generate) where pairs is a list of
        (img1, img2, rotation, [(rect1, rect2), ...]) tuples and each rect
        is a PIL-style (left, top, right, bottom) box.
    """
    # BUGFIX: the file handle was previously opened without being closed.
    with open(path) as f:
        data = f.read().splitlines()
    pairs = []
    num_crops_to_generate = 0
    for line in tqdm(data):
        if line.startswith("#"):
            continue
        line = line.split(", ")
        if len(line) < 8:
            # Pair header: two image paths and an in-plane rotation (degrees).
            img1, img2, rotation = line
            pairs.append((img1, img2, int(rotation), []))
        else:
            # Crop line: two rectangles, one per image of the current pair.
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Assign a hierarchical output path to every crop and bundle one job per image pair."""
    jobs = []
    # powers[level] = number of crops spanned by one directory at that level,
    # from coarsest (num_pairs_in_dir ** (num_levels - 1)) down to 1.
    powers = [num_pairs_in_dir ** level for level in reversed(range(num_levels))]

    def get_path(idx):
        # Decompose the global crop index into per-level directory indices,
        # rendered as '/'-joined lowercase hex, e.g. "a/3f/1c2".
        idx_array = []
        d = idx
        for level in range(num_levels - 1):
            idx_array.append(idx // powers[level])
            idx = idx % powers[level]
        # NOTE(review): the leaf component is the *full* original index `d`,
        # not the remainder `idx` — this keeps leaf filenames globally unique,
        # but confirm it matches the intended directory layout.
        idx_array.append(d)
        return "/".join(map(lambda x: hex(x)[2:], idx_array))

    idx = 0  # running global crop index across all pairs
    for pair_data in tqdm(pairs):
        img1, img2, rotation, crops = pair_data
        # Small rotations are treated as no rotation (crops stay axis-aligned).
        if -60 <= rotation and rotation <= 60:
            rotation = 0
        paths = [get_path(idx + k) for k in range(len(crops))]
        idx += len(crops)
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_image(path):
    """Load *path* as an RGB PIL image.

    Any failure (missing or corrupt file) is reported on stdout and
    re-raised as a bare OSError so callers can skip the whole pair.
    """
    try:
        img = Image.open(path)
        return img.convert("RGB")
    except Exception as err:
        print("skipping", path, err)
        raise OSError()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def save_image_crops(args, data):
    """Worker: extract and save all crops for one image pair.

    Args:
        args: parsed CLI namespace (uses root_dir, output_dir, imsize).
        data: ((img1_path, img2_path), rotation, crops, paths) tuple as
            produced by prepare_jobs.

    Returns:
        List of relative pair paths successfully written; empty if either
        source image could not be loaded.
    """

    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [
            load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
        ]
    except OSError as e:
        # Unreadable source image: skip the whole pair (already logged by load_image).
        return []

    def area(sz):
        # Pixel area of a (width, height) size tuple.
        return sz[0] * sz[1]

    tgt_size = (args.imsize, args.imsize)

    def prepare_crop(img, rect, rot=0):
        # Crop the rectangle, resize to tgt_size, then rotate by the
        # nearest multiple of 90 degrees of the requested rotation.

        img = img.crop(rect)

        # LANCZOS for strong downscaling (>4x area reduction), BICUBIC otherwise.
        interp = (
            Image.Resampling.LANCZOS
            if area(img.size) > 4 * area(tgt_size)
            else Image.Resampling.BICUBIC
        )
        img = img.resize(tgt_size, resample=interp)

        rot90 = (round(rot / 90) % 4) * 90
        if rot90 == 90:
            img = img.transpose(Image.Transpose.ROTATE_90)
        elif rot90 == 180:
            img = img.transpose(Image.Transpose.ROTATE_180)
        elif rot90 == 270:
            img = img.transpose(Image.Transpose.ROTATE_270)
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        crop1 = prepare_crop(img1, rect1)
        # Only the second image of the pair receives the rotation.
        crop2 = prepare_crop(img2, rect2, rot)

        fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
        fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Refuse to silently overwrite output from a previous run.
        assert not os.path.isfile(fullpath1), fullpath1
        assert not os.path.isfile(fullpath2), fullpath2
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)

    return results
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
    # Entry point: parse command-line arguments and run the extraction.
    main(arg_parser().parse_args())
|
longstream/utils/vendor/croco/datasets/habitat_sim/README.MD
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of synthetic image pairs using Habitat-Sim
|
| 2 |
+
|
| 3 |
+
These instructions allow to generate pre-training pairs from the Habitat simulator.
|
| 4 |
+
As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent.
|
| 5 |
+
|
| 6 |
+
### Download Habitat-Sim scenes
|
| 7 |
+
Download Habitat-Sim scenes:
|
| 8 |
+
- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
|
| 9 |
+
- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
|
| 10 |
+
- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
|
| 11 |
+
```
|
| 12 |
+
./data/
|
| 13 |
+
└──habitat-sim-data/
|
| 14 |
+
└──scene_datasets/
|
| 15 |
+
├──hm3d/
|
| 16 |
+
├──gibson/
|
| 17 |
+
├──habitat-test-scenes/
|
| 18 |
+
├──replica_cad_baked_lighting/
|
| 19 |
+
├──replica_cad/
|
| 20 |
+
├──ReplicaDataset/
|
| 21 |
+
└──scannet/
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### Image pairs generation
|
| 25 |
+
We provide metadata to generate reproducible images pairs for pretraining and validation.
|
| 26 |
+
Experiments described in the paper used similar data, but whose generation was not reproducible at the time.
|
| 27 |
+
|
| 28 |
+
Specifications:
|
| 29 |
+
- 256x256 resolution images, with 60 degrees field of view .
|
| 30 |
+
- Up to 1000 image pairs per scene.
|
| 31 |
+
- Number of scenes considered/number of images pairs per dataset:
|
| 32 |
+
- Scannet: 1097 scenes / 985 209 pairs
|
| 33 |
+
- HM3D:
|
| 34 |
+
- hm3d/train: 800 / 800k pairs
|
| 35 |
+
- hm3d/val: 100 scenes / 100k pairs
|
| 36 |
+
- hm3d/minival: 10 scenes / 10k pairs
|
| 37 |
+
- habitat-test-scenes: 3 scenes / 3k pairs
|
| 38 |
+
- replica_cad_baked_lighting: 13 scenes / 13k pairs
|
| 39 |
+
|
| 40 |
+
- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes.
|
| 41 |
+
|
| 42 |
+
Download metadata and extract it:
|
| 43 |
+
```bash
|
| 44 |
+
mkdir -p data/habitat_release_metadata/
|
| 45 |
+
cd data/habitat_release_metadata/
|
| 46 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
|
| 47 |
+
tar -xvf multiview_habitat_metadata.tar.gz
|
| 48 |
+
cd ../..
|
| 49 |
+
# Location of the metadata
|
| 50 |
+
METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Generate image pairs from metadata:
|
| 54 |
+
- The following command will print a list of commandlines to generate image pairs for each scene:
|
| 55 |
+
```bash
|
| 56 |
+
# Target output directory
|
| 57 |
+
PAIRS_DATASET_DIR="./data/habitat_release/"
|
| 58 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
|
| 59 |
+
```
|
| 60 |
+
- One can launch multiple of such commands in parallel e.g. using GNU Parallel:
|
| 61 |
+
```bash
|
| 62 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Metadata generation
|
| 66 |
+
|
| 67 |
+
Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
|
| 68 |
+
```bash
|
| 69 |
+
# Print commandlines to generate image pairs from the different scenes available.
|
| 70 |
+
PAIRS_DATASET_DIR=MY_CUSTOM_PATH
|
| 71 |
+
python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
|
| 72 |
+
|
| 73 |
+
# Once a dataset is generated, pack metadata files for reproducibility.
|
| 74 |
+
METADATA_DIR=MY_CUSTOM_PATH
|
| 75 |
+
python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
|
| 76 |
+
```
|
longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
|
| 3 |
+
"""
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import quaternion
|
| 11 |
+
from datasets.habitat_sim.multiview_habitat_sim_generator import (
|
| 12 |
+
MultiviewHabitatSimGenerator,
|
| 13 |
+
)
|
| 14 |
+
from datasets.habitat_sim.paths import SCENES_DATASET
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def generate_multiview_images_from_metadata(
    metadata_filename,
    output_dir,
    overload_params=None,
    scene_datasets_paths=None,
    exist_ok=False,
):
    """
    Re-render the image pairs recorded in a metadata file, for reproducibility.

    Parameters
    ----------
    metadata_filename : str
        Path to a ``metadata.json`` produced by a previous generation run.
    output_dir : str
        Directory where images, depth maps, camera parameters and a copy of
        the metadata are written.
    overload_params : dict, optional
        Entries overriding the values loaded from the metadata file.
    scene_datasets_paths : dict, optional
        Mapping from dataset label (path prefix in the metadata) to the local
        dataset root, used to relocate scene-related paths.
    exist_ok : bool
        Forwarded to ``os.makedirs`` for ``output_dir``.
    """
    # Avoid the shared mutable-default-argument pitfall.
    if overload_params is None:
        overload_params = {}

    if scene_datasets_paths is not None:
        # Sort labels by decreasing length so the most specific prefix wins.
        scene_datasets_paths = dict(
            sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True)
        )

    with open(metadata_filename, "r") as f:
        input_metadata = json.load(f)

    metadata = dict()
    for key, value in input_metadata.items():
        # Remap scene-related paths stored in the metadata onto local roots.
        if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
            if scene_datasets_paths is not None:
                for dataset_label, dataset_path in scene_datasets_paths.items():
                    if value.startswith(dataset_label):
                        value = os.path.normpath(
                            os.path.join(
                                dataset_path, os.path.relpath(value, dataset_label)
                            )
                        )
                        break
        metadata[key] = value

    # Caller-provided overrides take precedence over the loaded metadata.
    metadata.update(overload_params)

    # Everything except the recorded viewpoints and output options is a
    # constructor argument for the simulator generator.
    generation_entries = {
        key: value
        for key, value in metadata.items()
        if key not in ("multiviews", "output_dir", "generate_depth")
    }
    generate_depth = metadata["generate_depth"]

    os.makedirs(output_dir, exist_ok=exist_ok)

    generator = MultiviewHabitatSimGenerator(**generation_entries)
    try:
        for idx_label, data in tqdm(metadata["multiviews"].items()):
            positions = data["positions"]
            orientations = data["orientations"]
            for oidx in range(len(positions)):
                # Re-render the recorded viewpoint (position + quaternion pose).
                observation = generator.render_viewpoint(
                    positions[oidx], quaternion.from_float_array(orientations[oidx])
                )
                observation_label = f"{oidx + 1}"  # 1-based viewpoint label

                # RGB image (alpha channel dropped).
                img = PIL.Image.fromarray(observation["color"][:, :, :3])
                filename = os.path.join(
                    output_dir, f"{idx_label}_{observation_label}.jpeg"
                )
                img.save(filename)

                if generate_depth:
                    # Depth is stored as half-precision OpenEXR.
                    filename = os.path.join(
                        output_dir, f"{idx_label}_{observation_label}_depth.exr"
                    )
                    cv2.imwrite(
                        filename,
                        observation["depth"],
                        [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
                    )

                camera_params = {
                    key: observation[key].tolist()
                    for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")
                }
                filename = os.path.join(
                    output_dir, f"{idx_label}_{observation_label}_camera_params.json"
                )
                with open(filename, "w") as f:
                    json.dump(camera_params, f)

        # Store the (possibly path-remapped) metadata alongside the outputs.
        with open(os.path.join(output_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)
    finally:
        # Release the simulator even if rendering fails midway.
        generator.close()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
    # Minimal CLI: re-render one scene's image pairs from its metadata file.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--metadata_filename", required=True)
    arg_parser.add_argument("--output_dir", required=True)
    cli_args = arg_parser.parse_args()

    generate_multiview_images_from_metadata(
        metadata_filename=cli_args.metadata_filename,
        output_dir=cli_args.output_dir,
        scene_datasets_paths=SCENES_DATASET,
        overload_params=dict(),
        exist_ok=True,
    )
|
longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Script generating commandlines to generate image pairs from metadata files.
"""
import argparse
import glob
import os
import shlex

from tqdm import tqdm

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument(
        "--prefix",
        default="",
        help="Commandline prefix, useful e.g. to setup environment.",
    )
    args = parser.parse_args()

    # Recursively find every metadata file below the input directory.
    input_metadata_filenames = glob.iglob(
        f"{args.input_dir}/**/metadata.json", recursive=True
    )

    for metadata_filename in tqdm(input_metadata_filenames):
        # Mirror the relative scene layout of the input under the output dir.
        output_dir = os.path.join(
            args.output_dir,
            os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
        )

        # Skip scenes that were already fully generated.
        if os.path.exists(os.path.join(output_dir, "metadata.json")):
            continue
        # Quote paths so the emitted commandline survives spaces and shell
        # metacharacters (shlex.quote leaves plain paths unchanged).
        commandline = (
            f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py "
            f"--metadata_filename={shlex.quote(metadata_filename)} "
            f"--output_dir={shlex.quote(output_dir)}"
        )
        print(commandline)
|