diff --git a/README.md b/README.md index 142ae4315ea290ece56856c562f492163cadfc9d..9d8f81f4b591a04b5ccafc8c1702ae8547ec30f4 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,40 @@ --- -title: LongStream -emoji: 📊 -colorFrom: red -colorTo: purple +title: LongStream Demo sdk: gradio -sdk_version: 6.9.0 +sdk_version: 5.44.0 app_file: app.py -pinned: false -license: mit -short_description: Demo of LongStream +python_version: "3.10" +startup_duration_timeout: 1h --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# LongStream Demo + +This repository is the Hugging Face Space package for LongStream. + +Project page: `https://3dagentworld.github.io/longstream/` + +## Space Settings + +Set these variables in the Space settings before the first run: + +- `LONGSTREAM_HF_REPO=NicolasCC/LongStream` +- `LONGSTREAM_HF_FILE=50_longstream.pt` +- `LONGSTREAM_HF_LOCAL_DIR=checkpoints` + +Optional: + +- `LONGSTREAM_HF_REVISION=v0.1.0` +- `HF_TOKEN=` if the model repo is private + +## Entrypoints + +- `app.py`: stable demo + + +## Included Files + +- `demo_gradio.py` +- `demo_gradio_interactive.py` +- `longstream/` +- `configs/longstream_infer.yaml` + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..134a2c2fb5c46a5d9d2e942d1910cb91267cbf88 --- /dev/null +++ b/app.py @@ -0,0 +1,5 @@ +from demo_gradio import main + + +if __name__ == "__main__": + main() diff --git a/configs/longstream_infer.yaml b/configs/longstream_infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..186790b8db023ea2218e2f1814497cccff90dc6f --- /dev/null +++ b/configs/longstream_infer.yaml @@ -0,0 +1,84 @@ +device: cuda + +model: + checkpoint: checkpoints/50_longstream.pt + strict_load: false + hf: + repo_id: null + filename: null + revision: null + local_dir: checkpoints + longstream_cfg: + img_size: 518 + patch_size: 14 + embed_dim: 1024 + window_size: 48 + use_role_embedding: false + enable_scale_token: true + disable_keyframe_distinction: true + use_segment_mask: false + enable_camera_head: false + freeze: none + use_rel_pose_head: true + rel_pose_head_cfg: + enabled: true + keyframe_mode: fixed + keyframe_stride: 8 + reference_source: pred + detach_reference: false + trunk_depth: 4 + pose_mode: SE3 + num_heads: 16 + mlp_ratio: 4 + init_values: 0.01 + trans_act: linear + quat_act: linear + use_pair_cross_attn: false + xattn_temperature: 1.0 + use_precat: false + use_kf_role_embed: false + kf_role_embed_init_std: 0.02 + fl_act: relu + use_global_scale: false + reinit_camera_head: false + +inference: + mode: batch_refresh + streaming_mode: causal + window_size: 48 + keyframe_mode: fixed + keyframe_stride: 8 + refresh: 4 + rel_pose_head_cfg: + num_iterations: 4 + +data: + format: generalizable + data_roots_file: data_roots.txt + camera: null + img_path: "path/to/your/image/directory" + stride: 1 + max_frames: null + size: 518 + crop: false + patch_size: 14 + +output: + root: outputs + save_videos: true + save_points: true + save_frame_points: true + save_depth: true + save_images: true + mask_sky: true + max_full_pointcloud_points: 2000000 + max_frame_pointcloud_points: 200000 + skyseg_path: skyseg.onnx + +evaluation: + align_scale: true + depth_rel_delta_threshold: 1.25 + point_f1_threshold: 0.25 + point_eval_max_points: 100000 + point_eval_voxel_size: null + point_eval_oversample_factor: 4 diff --git a/demo_gradio.py b/demo_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..08271f3b7c02f6dd5a8edbb6d0458269c5dff877 --- /dev/null +++ b/demo_gradio.py @@ -0,0 +1,332 @@ +import os + +import gradio as gr + +from longstream.demo import BRANCH_OPTIONS, create_demo_session, load_metadata +from longstream.demo.backend import load_frame_previews +from longstream.demo.export import export_glb +from longstream.demo.viewer import build_interactive_figure + + +DEFAULT_KEYFRAME_STRIDE = 8 +DEFAULT_REFRESH = 3 +DEFAULT_WINDOW_SIZE = 48 +DEFAULT_CHECKPOINT = os.getenv("LONGSTREAM_CHECKPOINT", "checkpoints/50_longstream.pt") + + +def _run_stable_demo( + image_dir, + uploaded_files, + uploaded_video, + checkpoint, + device, + mode, + streaming_mode, + refresh, + window_size, + compute_sky, + branch_label, + show_cameras, + mask_sky, + camera_scale, + point_size, + opacity, + preview_max_points, + glb_max_points, +): + if not image_dir and not uploaded_files and not uploaded_video: + raise gr.Error("Provide an image folder, upload images, or upload a video.") + session_dir = create_demo_session( + image_dir=image_dir or "", + uploaded_files=uploaded_files, + uploaded_video=uploaded_video, + checkpoint=checkpoint, + device=device, + mode=mode, + streaming_mode=streaming_mode, + keyframe_stride=DEFAULT_KEYFRAME_STRIDE, + refresh=int(refresh), + window_size=int(window_size), + compute_sky=bool(compute_sky), + ) + fig = build_interactive_figure( + session_dir=session_dir, + branch=branch_label, + display_mode="All Frames", + frame_index=0, + point_size=float(point_size), + opacity=float(opacity), + preview_max_points=int(preview_max_points), + show_cameras=bool(show_cameras), + camera_scale=float(camera_scale), + mask_sky=bool(mask_sky), + ) + glb_path = export_glb( + session_dir=session_dir, + branch=branch_label, + display_mode="All Frames", + frame_index=0, + mask_sky=bool(mask_sky), + show_cameras=bool(show_cameras), + camera_scale=float(camera_scale), + max_points=int(glb_max_points), + ) + rgb, depth, frame_label = load_frame_previews(session_dir, 0) + meta = load_metadata(session_dir) + slider = gr.update( + minimum=0, + maximum=max(meta["num_frames"] - 1, 0), + value=0, + step=1, + interactive=meta["num_frames"] > 1, + ) + sky_msg = "" + if meta.get("has_sky_masks"): + removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0 + sky_msg = f" | sky_removed={removed:.1f}%" + status = f"Ready: {meta['num_frames']} frames | branch={branch_label}{sky_msg}" + return ( + fig, + glb_path, + session_dir, + rgb, + depth, + frame_label, + slider, + status, + ) + + +def _update_stable_scene( + session_dir, + branch_label, + show_cameras, + mask_sky, + camera_scale, + point_size, + opacity, + preview_max_points, + glb_max_points, +): + if not session_dir or not os.path.isdir(session_dir): + return None, None, "Run reconstruction first." + fig = build_interactive_figure( + session_dir=session_dir, + branch=branch_label, + display_mode="All Frames", + frame_index=0, + point_size=float(point_size), + opacity=float(opacity), + preview_max_points=int(preview_max_points), + show_cameras=bool(show_cameras), + camera_scale=float(camera_scale), + mask_sky=bool(mask_sky), + ) + glb_path = export_glb( + session_dir=session_dir, + branch=branch_label, + display_mode="All Frames", + frame_index=0, + mask_sky=bool(mask_sky), + show_cameras=bool(show_cameras), + camera_scale=float(camera_scale), + max_points=int(glb_max_points), + ) + meta = load_metadata(session_dir) + sky_msg = "" + if meta.get("has_sky_masks"): + removed = float(meta.get("sky_removed_ratio") or 0.0) * 100.0 + sky_msg = f" | sky_removed={removed:.1f}%" + return fig, glb_path, f"Updated preview: {branch_label}{sky_msg}" + + +def _update_frame_preview(session_dir, frame_index): + if not session_dir or not os.path.isdir(session_dir): + return None, None, "" + rgb, depth, label = load_frame_previews(session_dir, int(frame_index)) + return rgb, depth, label + + +def main(): + with gr.Blocks(title="LongStream Demo") as demo: + session_dir = gr.Textbox(visible=False) + + gr.Markdown("# LongStream Demo") + + with gr.Row(): + image_dir = gr.Textbox( + label="Image Folder", placeholder="/path/to/sequence" + ) + uploaded_files = gr.File( + label="Upload Images", file_count="multiple", file_types=["image"] + ) + uploaded_video = gr.File( + label="Upload Video", file_count="single", file_types=["video"] + ) + + with gr.Row(): + checkpoint = gr.Textbox(label="Checkpoint", value=DEFAULT_CHECKPOINT) + device = gr.Dropdown(label="Device", choices=["cuda", "cpu"], value="cuda") + + with gr.Accordion("Inference", open=False): + with gr.Row(): + mode = gr.Dropdown( + label="Mode", + choices=["streaming_refresh", "batch_refresh"], + value="batch_refresh", + ) + streaming_mode = gr.Dropdown( + label="Streaming Mode", choices=["causal", "window"], value="causal" + ) + with gr.Row(): + refresh = gr.Slider( + label="Refresh", minimum=2, maximum=9, step=1, value=DEFAULT_REFRESH + ) + window_size = gr.Slider( + label="Window Size", + minimum=1, + maximum=64, + step=1, + value=DEFAULT_WINDOW_SIZE, + ) + compute_sky = gr.Checkbox(label="Compute Sky Masks", value=True) + + with gr.Accordion("GLB Settings", open=True): + with gr.Row(): + branch_label = gr.Dropdown( + label="Point Cloud Branch", + choices=BRANCH_OPTIONS, + value="Point Head + Pose", + ) + show_cameras = gr.Checkbox(label="Show Cameras", value=True) + mask_sky = gr.Checkbox(label="Mask Sky", value=True) + with gr.Row(): + point_size = gr.Slider( + label="Point Size", + minimum=0.05, + maximum=2.0, + step=0.05, + value=0.3, + ) + opacity = gr.Slider( + label="Opacity", + minimum=0.1, + maximum=1.0, + step=0.05, + value=0.75, + ) + preview_max_points = gr.Slider( + label="Preview Max Points", + minimum=5000, + maximum=1000000, + step=10000, + value=100000, + ) + with gr.Row(): + camera_scale = gr.Slider( + label="Camera Scale", + minimum=0.001, + maximum=0.05, + step=0.001, + value=0.01, + ) + glb_max_points = gr.Slider( + label="GLB Max Points", + minimum=20000, + maximum=1000000, + step=10000, + value=400000, + ) + + run_btn = gr.Button("Run Stable Demo", variant="primary") + status = gr.Markdown("Provide input images, then run reconstruction.") + + plot = gr.Plot(label="Scene Preview") + + glb_file = gr.File(label="Download GLB") + + with gr.Row(): + frame_slider = gr.Slider( + label="Preview Frame", + minimum=0, + maximum=0, + step=1, + value=0, + interactive=False, + ) + frame_label = gr.Textbox(label="Frame") + with gr.Row(): + rgb_preview = gr.Image(label="RGB", type="numpy") + depth_preview = gr.Image(label="Depth Plasma", type="numpy") + + run_btn.click( + _run_stable_demo, + inputs=[ + image_dir, + uploaded_files, + uploaded_video, + checkpoint, + device, + mode, + streaming_mode, + refresh, + window_size, + compute_sky, + branch_label, + show_cameras, + mask_sky, + camera_scale, + point_size, + opacity, + preview_max_points, + glb_max_points, + ], + outputs=[ + plot, + glb_file, + session_dir, + rgb_preview, + depth_preview, + frame_label, + frame_slider, + status, + ], + ) + + for component in [ + branch_label, + show_cameras, + mask_sky, + camera_scale, + point_size, + opacity, + preview_max_points, + glb_max_points, + ]: + component.change( + _update_stable_scene, + inputs=[ + session_dir, + branch_label, + show_cameras, + mask_sky, + camera_scale, + point_size, + opacity, + preview_max_points, + glb_max_points, + ], + outputs=[plot, glb_file, status], + ) + + frame_slider.change( + _update_frame_preview, + inputs=[session_dir, frame_slider], + outputs=[rgb_preview, depth_preview, frame_label], + ) + + demo.launch() + + +if __name__ == "__main__": + main() diff --git a/longstream/.DS_Store b/longstream/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..16e6e3d50700299484c42a80cfa8578bdaa8e908 Binary files /dev/null and b/longstream/.DS_Store differ diff --git a/longstream/__init__.py b/longstream/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a2c5b3bb437bff74e283b62c894075e8c15331 --- /dev/null +++ b/longstream/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git a/longstream/core/__init__.py b/longstream/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/longstream/core/cli.py b/longstream/core/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..04a82d6e05dad679cd58c2a0637222dca1eba334 --- /dev/null +++ b/longstream/core/cli.py @@ -0,0 +1,213 @@ +import argparse +import os +import sys + +import yaml + + +def default_config_path() -> str: + return os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "configs", + "longstream_infer.yaml", + ) + + +def add_runtime_arguments(parser): + parser.add_argument( + "--config", + default=default_config_path(), + help="Path to longstream config yaml.", + ) + parser.add_argument( + "--dataset", + default=None, + help="Optional dataset hint. Generic format works without it.", + ) + parser.add_argument("--img-path", default=None) + parser.add_argument( + "--seq-list", + default=None, + help="Comma-separated sequence names. Default: auto-detect all sequences.", + ) + parser.add_argument("--format", default=None, help="generalizable") + parser.add_argument("--data-roots-file", default=None) + parser.add_argument("--camera", default=None) + parser.add_argument("--output-root", default=None) + parser.add_argument("--device", default=None) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--hf-repo", default=None) + parser.add_argument("--hf-file", default=None) + parser.add_argument( + "--mode", default=None, help="batch_refresh | streaming_refresh" + ) + parser.add_argument("--streaming-mode", default=None, help="causal | window") + parser.add_argument("--window-size", type=int, default=None) + parser.add_argument("--keyframe-stride", type=int, default=None) + parser.add_argument( + "--refresh", + type=int, + default=None, + help="Number of keyframes per refresh span, inclusive of both ends and including the segment start keyframe.", + ) + parser.add_argument( + "--keyframes-per-batch", + dest="keyframes_per_batch_legacy", + type=int, + default=None, + help=argparse.SUPPRESS, + ) + parser.add_argument("--max-frames", type=int, default=None) + parser.add_argument("--depth-rel-delta-threshold", type=float, default=None) + parser.add_argument("--point-f1-threshold", type=float, default=None) + parser.add_argument("--eval-max-points", type=int, default=None) + parser.add_argument("--eval-voxel-size", type=float, default=None) + parser.add_argument("--max-full-pointcloud-points", type=int, default=None) + parser.add_argument("--max-frame-pointcloud-points", type=int, default=None) + parser.add_argument("--save-frame-points", action="store_true") + parser.add_argument("--no-save-frame-points", action="store_true") + parser.add_argument("--no-align-scale", action="store_true") + parser.add_argument("--mask-sky", action="store_true") + parser.add_argument("--no-mask-sky", action="store_true") + return parser + + +def parse_runtime_args(parser): + argv = [arg for arg in sys.argv[1:] if arg.strip()] + return parser.parse_args(argv) + + +def load_config_with_overrides(args): + with open(args.config, "r") as f: + cfg = yaml.safe_load(f) or {} + cfg.setdefault("model", {}) + + if args.device is not None: + cfg["device"] = args.device + + if args.output_root is not None: + cfg.setdefault("output", {}) + cfg["output"]["root"] = args.output_root + + if args.dataset is not None: + cfg.setdefault("data", {}) + cfg["data"]["dataset"] = args.dataset + + if args.img_path is not None: + cfg.setdefault("data", {}) + cfg["data"]["img_path"] = args.img_path + + if args.seq_list is not None: + seqs = [s.strip() for s in args.seq_list.split(",") if s.strip()] + cfg.setdefault("data", {}) + cfg["data"]["seq_list"] = seqs + + if args.format is not None: + cfg.setdefault("data", {}) + cfg["data"]["format"] = args.format + + if args.data_roots_file is not None: + cfg.setdefault("data", {}) + cfg["data"]["data_roots_file"] = args.data_roots_file + + if args.camera is not None: + cfg.setdefault("data", {}) + cfg["data"]["camera"] = args.camera + + if args.max_frames is not None: + cfg.setdefault("data", {}) + cfg["data"]["max_frames"] = args.max_frames + + if args.checkpoint is not None: + cfg.setdefault("model", {}) + cfg["model"]["checkpoint"] = args.checkpoint + + if args.hf_repo is not None or args.hf_file is not None: + cfg.setdefault("model", {}) + cfg["model"].setdefault("hf", {}) + if args.hf_repo is not None: + cfg["model"]["hf"]["repo_id"] = args.hf_repo + if args.hf_file is not None: + cfg["model"]["hf"]["filename"] = args.hf_file + if cfg["model"].get("checkpoint") is None: + cfg["model"]["checkpoint"] = None + + if args.mode is not None: + cfg.setdefault("inference", {}) + cfg["inference"]["mode"] = args.mode + + if args.streaming_mode is not None: + cfg.setdefault("inference", {}) + cfg["inference"]["streaming_mode"] = args.streaming_mode + + if args.window_size is not None: + cfg.setdefault("inference", {}) + cfg["inference"]["window_size"] = args.window_size + cfg["model"].setdefault("longstream_cfg", {}) + cfg["model"]["longstream_cfg"]["window_size"] = args.window_size + + if args.keyframe_stride is not None: + cfg.setdefault("inference", {}) + cfg["inference"]["keyframe_stride"] = args.keyframe_stride + cfg["model"].setdefault("longstream_cfg", {}) + cfg["model"]["longstream_cfg"].setdefault("rel_pose_head_cfg", {}) + cfg["model"]["longstream_cfg"]["rel_pose_head_cfg"][ + "keyframe_stride" + ] = args.keyframe_stride + + refresh = args.refresh + if refresh is None and args.keyframes_per_batch_legacy is not None: + refresh = args.keyframes_per_batch_legacy + 1 + if refresh is not None: + cfg.setdefault("inference", {}) + cfg["inference"]["refresh"] = refresh + + if args.depth_rel_delta_threshold is not None: + cfg.setdefault("evaluation", {}) + cfg["evaluation"]["depth_rel_delta_threshold"] = args.depth_rel_delta_threshold + + if args.point_f1_threshold is not None: + cfg.setdefault("evaluation", {}) + cfg["evaluation"]["point_f1_threshold"] = args.point_f1_threshold + + if args.eval_max_points is not None: + cfg.setdefault("evaluation", {}) + cfg["evaluation"]["point_eval_max_points"] = args.eval_max_points + + if args.eval_voxel_size is not None: + cfg.setdefault("evaluation", {}) + cfg["evaluation"]["point_eval_voxel_size"] = args.eval_voxel_size + + if args.max_full_pointcloud_points is not None: + cfg.setdefault("output", {}) + cfg["output"]["max_full_pointcloud_points"] = args.max_full_pointcloud_points + + if args.max_frame_pointcloud_points is not None: + cfg.setdefault("output", {}) + cfg["output"]["max_frame_pointcloud_points"] = args.max_frame_pointcloud_points + + if args.save_frame_points: + cfg.setdefault("output", {}) + cfg["output"]["save_frame_points"] = True + if args.no_save_frame_points: + cfg.setdefault("output", {}) + cfg["output"]["save_frame_points"] = False + + if args.no_align_scale: + cfg.setdefault("evaluation", {}) + cfg["evaluation"]["align_scale"] = False + + if args.mask_sky: + cfg.setdefault("output", {}) + cfg["output"]["mask_sky"] = True + if args.no_mask_sky: + cfg.setdefault("output", {}) + cfg["output"]["mask_sky"] = False + + infer_cfg = cfg.setdefault("inference", {}) + if "refresh" not in infer_cfg and "keyframes_per_batch" in infer_cfg: + infer_cfg["refresh"] = int(infer_cfg["keyframes_per_batch"]) + 1 + + cfg.setdefault("data", {}) + cfg["data"]["format"] = "generalizable" + return cfg diff --git a/longstream/core/infer.py b/longstream/core/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a8a3e48fe00a6635437663d3f51b487accda13 --- /dev/null +++ b/longstream/core/infer.py @@ -0,0 +1,451 @@ +import argparse +import os +import yaml +import cv2 +import numpy as np +import torch +from PIL import Image + +from longstream.core.model import LongStreamModel +from longstream.data.dataloader import LongStreamDataLoader +from longstream.streaming.keyframe_selector import KeyframeSelector +from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh +from longstream.utils.vendor.models.components.utils.pose_enc import ( + pose_encoding_to_extri_intri, +) +from longstream.utils.camera import compose_abs_from_rel +from longstream.utils.depth import colorize_depth, unproject_depth_to_points +from longstream.utils.sky_mask import compute_sky_mask +from longstream.io.save_points import save_pointcloud +from longstream.io.save_poses_txt import save_w2c_txt, save_intri_txt, save_rel_pose_txt +from longstream.io.save_images import save_image_sequence, save_video + + +def _to_uint8_rgb(images): + imgs = images.detach().cpu().numpy() + imgs = np.clip(imgs, 0.0, 1.0) + imgs = (imgs * 255.0).astype(np.uint8) + return imgs + + +def _ensure_dir(path): + os.makedirs(path, exist_ok=True) + + +def _apply_sky_mask(depth, mask): + if mask is None: + return depth + m = (mask > 0).astype(np.float32) + return depth * m + + +def _camera_points_to_world(points, extri): + pts = np.asarray(points, dtype=np.float64).reshape(-1, 3) + R = np.asarray(extri[:3, :3], dtype=np.float64) + t = np.asarray(extri[:3, 3], dtype=np.float64) + world = (R.T @ (pts.T - t[:, None])).T + return world.astype(np.float32, copy=False) + + +def _mask_points_and_colors(points, colors, mask): + pts = points.reshape(-1, 3) + cols = None if colors is None else colors.reshape(-1, 3) + if mask is None: + return pts, cols + valid = mask.reshape(-1) > 0 + pts = pts[valid] + if cols is not None: + cols = cols[valid] + return pts, cols + + +def _resize_long_edge(arr, long_edge_size, interpolation): + h, w = arr.shape[:2] + scale = float(long_edge_size) / float(max(h, w)) + new_w = int(round(w * scale)) + new_h = int(round(h * scale)) + return cv2.resize(arr, (new_w, new_h), interpolation=interpolation) + + +def _prepare_mask_for_model( + mask, size, crop, patch_size, target_shape, square_ok=False +): + if mask is None: + return None + long_edge = ( + round(size * max(mask.shape[1] / mask.shape[0], mask.shape[0] / mask.shape[1])) + if size == 224 + else size + ) + mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST) + + h, w = mask.shape[:2] + cx, cy = w // 2, h // 2 + if size == 224: + half = min(cx, cy) + target_w = 2 * half + target_h = 2 * half + if crop: + mask = mask[cy - half : cy + half, cx - half : cx + half] + else: + mask = cv2.resize( + mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST + ) + else: + halfw = ((2 * cx) // patch_size) * (patch_size // 2) + halfh = ((2 * cy) // patch_size) * (patch_size // 2) + if not square_ok and w == h: + halfh = int(3 * halfw / 4) + target_w = 2 * halfw + target_h = 2 * halfh + if crop: + mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw] + else: + mask = cv2.resize( + mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST + ) + + if mask.shape[:2] != tuple(target_shape): + mask = cv2.resize( + mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST + ) + return mask + + +def _save_full_pointcloud(path, point_chunks, color_chunks, max_points=None, seed=0): + if not point_chunks: + return + points = np.concatenate(point_chunks, axis=0) + colors = None + if color_chunks and len(color_chunks) == len(point_chunks): + colors = np.concatenate(color_chunks, axis=0) + if max_points is not None and len(points) > max_points: + rng = np.random.default_rng(seed) + keep = rng.choice(len(points), size=max_points, replace=False) + points = points[keep] + if colors is not None: + colors = colors[keep] + np.save(os.path.splitext(path)[0] + ".npy", points.astype(np.float32, copy=False)) + save_pointcloud(path, points, colors=colors, max_points=None, seed=seed) + + +def run_inference_cfg(cfg: dict): + device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu") + device_type = torch.device(device).type + model_cfg = cfg.get("model", {}) + data_cfg = cfg.get("data", {}) + infer_cfg = cfg.get("inference", {}) + output_cfg = cfg.get("output", {}) + + print(f"[longstream] device={device}", flush=True) + model = LongStreamModel(model_cfg).to(device) + model.eval() + print("[longstream] model ready", flush=True) + + loader = LongStreamDataLoader(data_cfg) + + keyframe_stride = int(infer_cfg.get("keyframe_stride", 8)) + keyframe_mode = infer_cfg.get("keyframe_mode", "fixed") + refresh = int( + infer_cfg.get("refresh", int(infer_cfg.get("keyframes_per_batch", 3)) + 1) + ) + if refresh < 2: + raise ValueError( + "refresh must be >= 2 because it counts both keyframe endpoints" + ) + mode = infer_cfg.get("mode", "streaming_refresh") + if mode == "streaming": + mode = "streaming_refresh" + streaming_mode = infer_cfg.get("streaming_mode", "causal") + window_size = int(infer_cfg.get("window_size", 5)) + + selector = KeyframeSelector( + min_interval=keyframe_stride, + max_interval=keyframe_stride, + force_first=True, + mode="random" if keyframe_mode == "random" else "fixed", + ) + + out_root = output_cfg.get("root", "outputs") + _ensure_dir(out_root) + save_videos = bool(output_cfg.get("save_videos", True)) + save_points = bool(output_cfg.get("save_points", True)) + save_frame_points = bool(output_cfg.get("save_frame_points", True)) + save_depth = bool(output_cfg.get("save_depth", True)) + save_images = bool(output_cfg.get("save_images", True)) + mask_sky = bool(output_cfg.get("mask_sky", True)) + max_full_pointcloud_points = output_cfg.get("max_full_pointcloud_points", None) + if max_full_pointcloud_points is not None: + max_full_pointcloud_points = int(max_full_pointcloud_points) + max_frame_pointcloud_points = output_cfg.get("max_frame_pointcloud_points", None) + if max_frame_pointcloud_points is not None: + max_frame_pointcloud_points = int(max_frame_pointcloud_points) + skyseg_path = output_cfg.get( + "skyseg_path", + os.path.join(os.path.dirname(__file__), "..", "..", "skyseg.onnx"), + ) + + with torch.no_grad(): + for seq in loader: + images = seq.images + B, S, C, H, W = images.shape + print( + f"[longstream] sequence {seq.name}: inference start ({S} frames)", + flush=True, + ) + + is_keyframe, keyframe_indices = selector.select_keyframes( + S, B, images.device + ) + + rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4}) + + if mode == "batch_refresh": + outputs = run_batch_refresh( + model, + images, + is_keyframe, + keyframe_indices, + streaming_mode, + keyframe_stride, + refresh, + rel_pose_cfg, + ) + elif mode == "streaming_refresh": + outputs = run_streaming_refresh( + model, + images, + is_keyframe, + keyframe_indices, + streaming_mode, + window_size, + refresh, + rel_pose_cfg, + ) + else: + raise ValueError(f"Unsupported inference mode: {mode}") + print(f"[longstream] sequence {seq.name}: inference done", flush=True) + if device_type == "cuda": + torch.cuda.empty_cache() + + seq_dir = os.path.join(out_root, seq.name) + _ensure_dir(seq_dir) + + frame_ids = list(range(S)) + rgb = _to_uint8_rgb(images[0].permute(0, 2, 3, 1)) + + if "rel_pose_enc" in outputs: + rel_pose_enc = outputs["rel_pose_enc"][0] + abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0]) + extri, intri = pose_encoding_to_extri_intri( + abs_pose_enc[None], image_size_hw=(H, W) + ) + extri_np = extri[0].detach().cpu().numpy() + intri_np = intri[0].detach().cpu().numpy() + + pose_dir = os.path.join(seq_dir, "poses") + _ensure_dir(pose_dir) + save_w2c_txt( + os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids + ) + save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids) + save_rel_pose_txt( + os.path.join(pose_dir, "rel_pose.txt"), rel_pose_enc, frame_ids + ) + elif "pose_enc" in outputs: + pose_enc = outputs["pose_enc"][0] + extri, intri = pose_encoding_to_extri_intri( + pose_enc[None], image_size_hw=(H, W) + ) + extri_np = extri[0].detach().cpu().numpy() + intri_np = intri[0].detach().cpu().numpy() + + pose_dir = os.path.join(seq_dir, "poses") + _ensure_dir(pose_dir) + save_w2c_txt( + os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids + ) + save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids) + + if save_images: + print(f"[longstream] sequence {seq.name}: saving rgb", flush=True) + rgb_dir = os.path.join(seq_dir, "images", "rgb") + save_image_sequence(rgb_dir, list(rgb)) + if save_videos: + save_video( + os.path.join(seq_dir, "images", "rgb.mp4"), + os.path.join(rgb_dir, "frame_*.png"), + ) + + sky_masks = None + if mask_sky: + raw_sky_masks = compute_sky_mask( + seq.image_paths, skyseg_path, os.path.join(seq_dir, "sky_masks") + ) + if raw_sky_masks is not None: + sky_masks = [ + _prepare_mask_for_model( + mask, + size=int(data_cfg.get("size", 518)), + crop=bool(data_cfg.get("crop", False)), + patch_size=int(data_cfg.get("patch_size", 14)), + target_shape=(H, W), + ) + for mask in raw_sky_masks + ] + + if save_depth and "depth" in outputs: + print(f"[longstream] sequence {seq.name}: saving depth", flush=True) + depth = outputs["depth"][0, :, :, :, 0].detach().cpu().numpy() + depth_dir = os.path.join(seq_dir, "depth", "dpt") + _ensure_dir(depth_dir) + color_dir = os.path.join(seq_dir, "depth", "dpt_plasma") + _ensure_dir(color_dir) + + color_frames = [] + for i in range(S): + d = depth[i] + if sky_masks is not None and sky_masks[i] is not None: + d = _apply_sky_mask(d, sky_masks[i]) + np.save(os.path.join(depth_dir, f"frame_{i:06d}.npy"), d) + colored = colorize_depth(d, cmap="plasma") + Image.fromarray(colored).save( + os.path.join(color_dir, f"frame_{i:06d}.png") + ) + color_frames.append(colored) + if save_videos: + save_video( + os.path.join(seq_dir, "depth", "dpt_plasma.mp4"), + os.path.join(color_dir, "frame_*.png"), + ) + + if save_points: + print( + f"[longstream] sequence {seq.name}: saving point clouds", flush=True + ) + if "world_points" in outputs: + if "rel_pose_enc" in outputs: + abs_pose_enc = compose_abs_from_rel( + outputs["rel_pose_enc"][0], keyframe_indices[0] + ) + extri, intri = pose_encoding_to_extri_intri( + abs_pose_enc[None], image_size_hw=(H, W) + ) + else: + extri, intri = pose_encoding_to_extri_intri( + outputs["pose_enc"][0][None], image_size_hw=(H, W) + ) + extri = extri[0] + intri = intri[0] + + pts_dir = os.path.join(seq_dir, "points", "point_head") + _ensure_dir(pts_dir) + pts = outputs["world_points"][0].detach().cpu().numpy() + full_pts = [] + full_cols = [] + for i in range(S): + pts_world = _camera_points_to_world( + pts[i], extri[i].detach().cpu().numpy() + ) + pts_world = pts_world.reshape(pts[i].shape) + pts_i, cols_i = _mask_points_and_colors( + pts_world, + rgb[i], + None if sky_masks is None else sky_masks[i], + ) + if save_frame_points: + save_pointcloud( + os.path.join(pts_dir, f"frame_{i:06d}.ply"), + pts_i, + colors=cols_i, + max_points=max_frame_pointcloud_points, + seed=i, + ) + if len(pts_i): + full_pts.append(pts_i) + full_cols.append(cols_i) + _save_full_pointcloud( + os.path.join(seq_dir, "points", "point_head_full.ply"), + full_pts, + full_cols, + max_points=max_full_pointcloud_points, + seed=0, + ) + + if "depth" in outputs and ( + "rel_pose_enc" in outputs or "pose_enc" in outputs + ): + depth = outputs["depth"][0, :, :, :, 0] + if "rel_pose_enc" in outputs: + abs_pose_enc = compose_abs_from_rel( + outputs["rel_pose_enc"][0], keyframe_indices[0] + ) + extri, intri = pose_encoding_to_extri_intri( + abs_pose_enc[None], image_size_hw=(H, W) + ) + else: + extri, intri = pose_encoding_to_extri_intri( + outputs["pose_enc"][0][None], image_size_hw=(H, W) + ) + + extri = extri[0] + intri = intri[0] + dpt_pts_dir = os.path.join(seq_dir, "points", "dpt_unproj") + _ensure_dir(dpt_pts_dir) + full_pts = [] + full_cols = [] + + for i in range(S): + d = depth[i] + pts_cam = unproject_depth_to_points(d[None], intri[i : i + 1])[ + 0 + ] + R = extri[i, :3, :3] + t = extri[i, :3, 3] + pts_world = ( + R.t() @ (pts_cam.reshape(-1, 3).t() - t[:, None]) + ).t() + pts_world = pts_world.cpu().numpy().reshape(-1, 3) + pts_i, cols_i = _mask_points_and_colors( + pts_world, + rgb[i], + None if sky_masks is None else sky_masks[i], + ) + if save_frame_points: + save_pointcloud( + os.path.join(dpt_pts_dir, f"frame_{i:06d}.ply"), + pts_i, + colors=cols_i, + max_points=max_frame_pointcloud_points, + seed=i, + ) + if len(pts_i): + full_pts.append(pts_i) + full_cols.append(cols_i) + _save_full_pointcloud( + os.path.join(seq_dir, "points", "dpt_unproj_full.ply"), + full_pts, + full_cols, + max_points=max_full_pointcloud_points, + seed=1, + ) + del outputs + if device_type == "cuda": + torch.cuda.empty_cache() + + +def run_inference(config_path: str): + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + run_inference_cfg(cfg) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True) + args = parser.parse_args() + run_inference(args.config) + + +if __name__ == "__main__": + main() diff --git a/longstream/core/model.py b/longstream/core/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3732dade46a913db99485f1956690139b82623a5 --- /dev/null +++ b/longstream/core/model.py @@ -0,0 +1,69 @@ +import os +import torch +from typing import Dict, Any + +from longstream.models.longstream import LongStream +from longstream.utils.hub import resolve_checkpoint_path + + +class LongStreamModel(torch.nn.Module): + def __init__(self, cfg: Dict[str, Any] | None): + super().__init__() + cfg = cfg or {} + + ckpt_path = resolve_checkpoint_path( + cfg.get("checkpoint", None), cfg.get("hf", None) + ) + + stream_cfg = dict(cfg.get("longstream_cfg", {}) or {}) + rel_pose_cfg = stream_cfg.pop( + "rel_pose_head_cfg", cfg.get("rel_pose_head_cfg", None) + ) + use_rel_pose_head = bool(stream_cfg.pop("use_rel_pose_head", False)) + if use_rel_pose_head and rel_pose_cfg is not None: + stream_cfg["rel_pose_head_cfg"] = rel_pose_cfg + self.longstream = LongStream(**stream_cfg) + + if ckpt_path: + self.load_checkpoint(ckpt_path, strict=bool(cfg.get("strict_load", True))) + + def load_checkpoint(self, ckpt_path: str, strict: bool = True): + if not os.path.exists(ckpt_path): + raise FileNotFoundError(ckpt_path) + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if isinstance(ckpt, dict): + if "model" in ckpt and isinstance(ckpt["model"], dict): + state = ckpt["model"] + elif "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict): + state = ckpt["state_dict"] + else: + state = ckpt + else: + raise TypeError("Unsupported checkpoint format") + + if state: + first_key = next(iter(state.keys())) + if first_key.startswith("sampler.longstream."): + state = {k.replace("sampler.", "", 1): v for k, v in state.items()} + + missing, unexpected = self.load_state_dict(state, strict=False) + if missing or unexpected: + msg = f"checkpoint mismatch: missing={len(missing)} unexpected={len(unexpected)}" + if strict: + raise RuntimeError(msg) + print(msg) + + def forward(self, *args, **kwargs): + return self.longstream(*args, **kwargs) + + @property + def aggregator(self): + return self.longstream.aggregator + + @property + def camera_head(self): + return getattr(self.longstream, "camera_head", None) + + @property + def rel_pose_head(self): + return getattr(self.longstream, "rel_pose_head", None) diff --git a/longstream/data/__init__.py b/longstream/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..645d71c48f923b1529621ddcac36ac21c7c4ceb8 --- /dev/null +++ b/longstream/data/__init__.py @@ -0,0 +1,3 @@ +from .dataloader import LongStreamDataLoader, LongStreamSequence, LongStreamSequenceInfo + +__all__ = ["LongStreamDataLoader", "LongStreamSequence", "LongStreamSequenceInfo"] diff --git a/longstream/data/dataloader.py b/longstream/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe41b75d425b60ae9e93021ae6c624e4dbfdab6 --- /dev/null +++ b/longstream/data/dataloader.py @@ -0,0 +1,422 @@ +import os +import glob +from dataclasses import dataclass +from typing import List, Dict, Any, Iterator, Optional, Tuple + +import torch + +from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval + +dataset_metadata: Dict[str, Dict[str, Any]] = { + "davis": { + "img_path": "data/davis/DAVIS/JPEGImages/480p", + "mask_path": "data/davis/DAVIS/masked_images/480p", + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: None, + "traj_format": None, + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: os.path.join(mask_path, seq), + "skip_condition": None, + "process_func": None, + }, + "kitti": { + "img_path": "data/kitti/sequences", + "anno_path": "data/kitti/poses", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "image_2"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + anno_path, f"{seq}.txt" + ) + if os.path.exists(os.path.join(anno_path, f"{seq}.txt")) + else None, + "traj_format": "kitti", + "seq_list": ["00", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10"], + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "bonn": { + "img_path": "data/bonn/rgbd_bonn_dataset", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "rgb_110" + ), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, f"rgbd_bonn_{seq}", "groundtruth_110.txt" + ), + "traj_format": "tum", + "seq_list": ["balloon2", "crowd2", "crowd3", "person_tracking2", "synchronous"], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "nyu": { + "img_path": "data/nyu-v2/val/nyu_images", + "mask_path": None, + "process_func": None, + }, + "scannet": { + "img_path": "data/scannetv2", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "color_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "pose_90.txt" + ), + "traj_format": "replica", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "tum": { + "img_path": "data/tum", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq, "rgb_90"), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, seq, "groundtruth_90.txt" + ), + "traj_format": "tum", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "sintel": { + "img_path": "data/sintel/training/final", + "anno_path": "data/sintel/training/camdata_left", + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join(img_path, seq), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join(anno_path, seq), + "traj_format": None, + "seq_list": [ + "alley_2", + "ambush_4", + "ambush_5", + "ambush_6", + "cave_2", + "cave_4", + "market_2", + "market_5", + "market_6", + "shaman_3", + "sleeping_1", + "sleeping_2", + "temple_2", + "temple_3", + ], + "full_seq": False, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, + "waymo": { + "img_path": "/horizon-bucket/saturn_v_4dlabel/004_vision/01_users/tao02.xie/datasets/scatt3r_evaluation/waymo_open_dataset_v1_4_3", + "anno_path": None, + "mask_path": None, + "dir_path_func": lambda img_path, seq: os.path.join( + img_path, + seq.split("_cam")[0] if "_cam" in seq else seq, + "images", + seq.split("_cam")[1] if "_cam" in seq else "00", + ), + "gt_traj_func": lambda img_path, anno_path, seq: os.path.join( + img_path, + seq.split("_cam")[0] if "_cam" in seq else seq, + "cameras", + seq.split("_cam")[1] if "_cam" in seq else "00", + "extri.yml", + ), + "traj_format": "waymo", + "seq_list": None, + "full_seq": True, + "mask_path_seq_func": lambda mask_path, seq: None, + "skip_condition": None, + "process_func": None, + }, +} + + +@dataclass +class LongStreamSequenceInfo: + name: str + scene_root: str + image_dir: str + image_paths: List[str] + camera: Optional[str] + + +class LongStreamSequence: + def __init__( + self, + name: str, + images: torch.Tensor, + image_paths: List[str], + scene_root: Optional[str] = None, + image_dir: Optional[str] = None, + camera: Optional[str] = None, + ): + self.name = name + self.images = images + self.image_paths = image_paths + self.scene_root = scene_root + self.image_dir = image_dir + self.camera = camera + + +def _read_list_file(path: str) -> List[str]: + with open(path, "r") as f: + lines = [] + for line in f.readlines(): + line = line.strip() + if not line: + continue + if line.startswith("#"): + continue + lines.append(line) + return lines + + +def _is_generalizable_scene_root(path: str) -> bool: + return os.path.isdir(os.path.join(path, "images")) + + +def _direct_image_files(dir_path: str) -> List[str]: + filelist = sorted(glob.glob(os.path.join(dir_path, "*.png"))) + if not filelist: + filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpg"))) + if not filelist: + filelist = sorted(glob.glob(os.path.join(dir_path, "*.jpeg"))) + return filelist + + +class LongStreamDataLoader: + def __init__(self, cfg: Dict[str, Any]): + self.cfg = cfg + self.dataset = cfg.get("dataset", None) + meta = dataset_metadata.get(self.dataset, {}) + self.img_path = cfg.get("img_path", meta.get("img_path")) + self.mask_path = cfg.get("mask_path", meta.get("mask_path")) + self.dir_path_func = meta.get("dir_path_func", lambda p, s: os.path.join(p, s)) + self.mask_path_seq_func = meta.get("mask_path_seq_func", lambda p, s: None) + self.full_seq = bool(cfg.get("full_seq", meta.get("full_seq", True))) + self.seq_list = cfg.get("seq_list", None) + self.stride = int(cfg.get("stride", 1)) + self.max_frames = cfg.get("max_frames", None) + self.size = int(cfg.get("size", 518)) + self.crop = bool(cfg.get("crop", False)) + self.patch_size = int(cfg.get("patch_size", 14)) + self.format = cfg.get("format", "auto") + self.data_roots_file = cfg.get("data_roots_file", None) + self.split = cfg.get("split", None) + self.camera = cfg.get("camera", None) + + def _infer_format(self) -> str: + if self.format in ["relpose", "generalizable"]: + return self.format + if self.img_path is None: + return "relpose" + if _is_generalizable_scene_root(self.img_path): + return "generalizable" + default_list = self.data_roots_file or "data_roots.txt" + if os.path.exists(os.path.join(self.img_path, default_list)): + return "generalizable" + return "relpose" + + def _resolve_seq_list_generalizable(self) -> List[str]: + if self.seq_list is not None: + return list(self.seq_list) + if self.img_path is None or not os.path.isdir(self.img_path): + return [] + + if _is_generalizable_scene_root(self.img_path): + return [self.img_path] + + candidates = [] + if isinstance(self.data_roots_file, str) and self.data_roots_file: + candidates.append(self.data_roots_file) + if isinstance(self.split, str) and self.split: + split_name = self.split.lower() + if split_name in ["val", "valid", "validate"]: + split_name = "validate" + candidates.append(f"{split_name}_data_roots.txt") + candidates.append("data_roots.txt") + candidates.append("train_data_roots.txt") + candidates.append("validate_data_roots.txt") + + for fname in candidates: + path = os.path.join(self.img_path, fname) + if os.path.exists(path): + return _read_list_file(path) + + img_dirs = sorted( + glob.glob(os.path.join(self.img_path, "**", "images"), recursive=True) + ) + scene_roots = [os.path.dirname(p) for p in img_dirs] + + rels = [] + for p in scene_roots: + try: + rels.append(os.path.relpath(p, self.img_path)) + except ValueError: + rels.append(p) + return sorted(set(rels)) + + def _resolve_seq_list_relpose(self) -> List[str]: + if self.seq_list is not None: + return list(self.seq_list) + meta = dataset_metadata.get(self.dataset, {}) + if self.full_seq: + if self.img_path is None or not os.path.isdir(self.img_path): + return [] + seqs = [ + s + for s in os.listdir(self.img_path) + if os.path.isdir(os.path.join(self.img_path, s)) + ] + return sorted(seqs) + seqs = meta.get("seq_list", []) or [] + return list(seqs) + + def _resolve_seq_list(self) -> List[str]: + fmt = self._infer_format() + if fmt == "generalizable": + return self._resolve_seq_list_generalizable() + return self._resolve_seq_list_relpose() + + def _resolve_scene_root(self, seq_entry: str) -> Tuple[str, str]: + if os.path.isabs(seq_entry) or os.path.sep in seq_entry: + scene_root = seq_entry + name = os.path.basename(os.path.normpath(seq_entry)) + else: + scene_root = os.path.join(self.img_path, seq_entry) + name = seq_entry + return name, scene_root + + def _resolve_image_dir_generalizable(self, scene_root: str) -> Optional[str]: + images_root = os.path.join(scene_root, "images") + if not os.path.isdir(images_root): + return None + + if isinstance(self.camera, str) and self.camera: + cam_dir = os.path.join(images_root, self.camera) + if os.path.isdir(cam_dir): + return cam_dir + + if _direct_image_files(images_root): + return images_root + + cams = [ + d + for d in os.listdir(images_root) + if os.path.isdir(os.path.join(images_root, d)) + ] + if not cams: + return None + cams = sorted(cams) + + frame_dirs = [] + for name in cams: + child_dir = os.path.join(images_root, name) + child_images = _direct_image_files(child_dir) + if child_images: + frame_dirs.append((name, len(child_images))) + + if ( + len(cams) > 10 + and len(frame_dirs) == len(cams) + and max(count for _, count in frame_dirs) == 1 + ): + return images_root + + return os.path.join(images_root, cams[0]) + + def _camera_from_image_dir(self, image_dir: str) -> Optional[str]: + parent = os.path.basename(os.path.dirname(image_dir)) + if parent != "images": + return None + return os.path.basename(image_dir) + + def _collect_filelist(self, dir_path: str) -> List[str]: + filelist = _direct_image_files(dir_path) + if not filelist: + nested = [] + child_dirs = sorted( + d for d in glob.glob(os.path.join(dir_path, "*")) if os.path.isdir(d) + ) + for child_dir in child_dirs: + child_images = _direct_image_files(child_dir) + if child_images: + nested.append(child_images[0]) + filelist = nested + if self.stride > 1: + filelist = filelist[:: self.stride] + if self.max_frames is not None: + filelist = filelist[: self.max_frames] + return filelist + + def _load_images(self, filelist: List[str]) -> torch.Tensor: + views = load_images_for_eval( + filelist, + size=self.size, + verbose=False, + crop=self.crop, + patch_size=self.patch_size, + ) + imgs = torch.cat([view["img"] for view in views], dim=0) + images = imgs.unsqueeze(0) + images = (images + 1.0) / 2.0 + return images + + def iter_sequence_infos(self) -> Iterator[LongStreamSequenceInfo]: + fmt = self._infer_format() + seqs = self._resolve_seq_list() + for seq_entry in seqs: + if fmt == "generalizable": + seq, scene_root = self._resolve_scene_root(seq_entry) + dir_path = self._resolve_image_dir_generalizable(scene_root) + if dir_path is None or not os.path.isdir(dir_path): + continue + camera = self._camera_from_image_dir(dir_path) + else: + seq = seq_entry + scene_root = os.path.join(self.img_path, seq) + dir_path = self.dir_path_func(self.img_path, seq) + if not os.path.isdir(dir_path): + continue + camera = None + + filelist = self._collect_filelist(dir_path) + if not filelist: + continue + yield LongStreamSequenceInfo( + name=seq, + scene_root=scene_root, + image_dir=dir_path, + image_paths=filelist, + camera=camera, + ) + + def __iter__(self) -> Iterator[LongStreamSequence]: + for info in self.iter_sequence_infos(): + print( + f"[longstream] loading sequence {info.name}: {len(info.image_paths)} frames", + flush=True, + ) + images = self._load_images(info.image_paths) + print( + f"[longstream] loaded sequence {info.name}: {tuple(images.shape)}", + flush=True, + ) + yield LongStreamSequence( + info.name, + images, + info.image_paths, + scene_root=info.scene_root, + image_dir=info.image_dir, + camera=info.camera, + ) diff --git a/longstream/demo/__init__.py b/longstream/demo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c086061055c3555c9d028d95352520825ec40b6c --- /dev/null +++ b/longstream/demo/__init__.py @@ -0,0 +1,11 @@ +from .backend import create_demo_session, load_frame_previews +from .common import BRANCH_OPTIONS, DISPLAY_MODE_OPTIONS, branch_key, load_metadata + +__all__ = [ + "BRANCH_OPTIONS", + "DISPLAY_MODE_OPTIONS", + "branch_key", + "create_demo_session", + "load_frame_previews", + "load_metadata", +] diff --git a/longstream/demo/backend.py b/longstream/demo/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..b201cafde08372ec1d2575dc6d986ad3ed804f02 --- /dev/null +++ b/longstream/demo/backend.py @@ -0,0 +1,495 @@ +import json +import os +import re +import shutil +import tempfile +from datetime import datetime +from typing import Iterable, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import yaml + +from longstream.core.cli import default_config_path +from longstream.core.model import LongStreamModel +from longstream.streaming.keyframe_selector import KeyframeSelector +from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh +from longstream.utils.camera import compose_abs_from_rel +from longstream.utils.depth import colorize_depth +from longstream.utils.hub import resolve_checkpoint_path +from longstream.utils.sky_mask import compute_sky_mask +from longstream.utils.vendor.dust3r.utils.image import load_images_for_eval +from longstream.utils.vendor.models.components.utils.pose_enc import ( + pose_encoding_to_extri_intri, +) + +from .common import load_metadata, session_file + +_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp") +_MODEL_CACHE = {} + + +def _resolve_file_path(item) -> str: + if item is None: + return "" + if isinstance(item, str): + return item + if isinstance(item, dict) and "name" in item: + return item["name"] + if hasattr(item, "name"): + return item.name + return str(item) + + +def _natural_sort_key(path: str): + name = os.path.basename(path) + stem, _ = os.path.splitext(name) + parts = re.split(r"(\d+)", stem) + key = [] + for part in parts: + if not part: + continue + if part.isdigit(): + key.append((0, int(part))) + else: + key.append((1, part.lower())) + return key, name.lower() + + +def _sorted_image_paths(image_dir: str) -> List[str]: + files = [] + for name in os.listdir(image_dir): + if name.lower().endswith(_IMAGE_EXTS): + files.append(os.path.join(image_dir, name)) + return sorted(files, key=_natural_sort_key) + + +def _session_root() -> str: + root = os.path.join(tempfile.gettempdir(), "longstream_demo_sessions") + os.makedirs(root, exist_ok=True) + return root + + +def _new_session_dir() -> str: + stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S_%f") + return tempfile.mkdtemp(prefix=f"longstream_{stamp}_", dir=_session_root()) + + +def _copy_uploaded_images(uploaded_files: Iterable, session_dir: str) -> List[str]: + input_dir = os.path.join(session_dir, "input_images") + os.makedirs(input_dir, exist_ok=True) + copied = [] + sources = sorted( + (_resolve_file_path(x) for x in uploaded_files if x), + key=_natural_sort_key, + ) + for src in sources: + if not src or not os.path.isfile(src): + continue + dst = os.path.join(input_dir, os.path.basename(src)) + shutil.copy2(src, dst) + copied.append(dst) + return copied + + +def _extract_uploaded_video(uploaded_video, session_dir: str) -> List[str]: + src = _resolve_file_path(uploaded_video) + if not src: + return [] + if not os.path.isfile(src): + raise FileNotFoundError(src) + + input_dir = os.path.join(session_dir, "input_images") + os.makedirs(input_dir, exist_ok=True) + cap = cv2.VideoCapture(src) + if not cap.isOpened(): + raise ValueError(f"unable to open video: {src}") + + image_paths = [] + frame_id = 0 + while True: + ok, frame = cap.read() + if not ok: + break + dst = os.path.join(input_dir, f"{frame_id:06d}.png") + if not cv2.imwrite(dst, frame): + cap.release() + raise ValueError(f"failed to write extracted frame: {dst}") + image_paths.append(dst) + frame_id += 1 + cap.release() + + if not image_paths: + raise ValueError(f"no frames extracted from video: {src}") + return image_paths + + +def _resize_long_edge(arr, long_edge_size, interpolation): + h, w = arr.shape[:2] + scale = float(long_edge_size) / float(max(h, w)) + new_w = int(round(w * scale)) + new_h = int(round(h * scale)) + return cv2.resize(arr, (new_w, new_h), interpolation=interpolation) + + +def _prepare_mask_for_model( + mask, size, crop, patch_size, target_shape, square_ok=False +): + if mask is None: + return None + h0, w0 = mask.shape[:2] + long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size + mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST) + + h, w = mask.shape[:2] + cx, cy = w // 2, h // 2 + if size == 224: + half = min(cx, cy) + if crop: + mask = mask[cy - half : cy + half, cx - half : cx + half] + else: + mask = cv2.resize( + mask, (2 * half, 2 * half), interpolation=cv2.INTER_NEAREST + ) + else: + halfw = ((2 * cx) // patch_size) * (patch_size // 2) + halfh = ((2 * cy) // patch_size) * (patch_size // 2) + if not square_ok and w == h: + halfh = int(3 * halfw / 4) + if crop: + mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw] + else: + mask = cv2.resize( + mask, (2 * halfw, 2 * halfh), interpolation=cv2.INTER_NEAREST + ) + + if mask.shape[:2] != tuple(target_shape): + mask = cv2.resize( + mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST + ) + return mask.astype(np.uint8, copy=False) + + +def _load_base_config(config_path: Optional[str] = None) -> dict: + path = config_path or default_config_path() + with open(path, "r") as f: + return yaml.safe_load(f) or {} + + +def _resolve_demo_checkpoint(checkpoint: str) -> str: + local_candidates = [] + for candidate in [checkpoint, os.getenv("LONGSTREAM_CHECKPOINT", "")]: + if isinstance(candidate, str) and candidate: + local_candidates.append(candidate) + + for candidate in local_candidates: + if os.path.exists(candidate): + return os.path.abspath(candidate) + + hf_cfg = { + "repo_id": os.getenv("LONGSTREAM_HF_REPO"), + "filename": os.getenv("LONGSTREAM_HF_FILE"), + "revision": os.getenv("LONGSTREAM_HF_REVISION"), + "local_dir": os.getenv("LONGSTREAM_HF_LOCAL_DIR", "checkpoints"), + } + resolved = resolve_checkpoint_path(None, hf_cfg) + if resolved and os.path.exists(resolved): + return os.path.abspath(resolved) + + if hf_cfg["repo_id"] and hf_cfg["filename"]: + raise FileNotFoundError( + "checkpoint not found locally and Hugging Face resolution failed: " + f"repo_id={hf_cfg['repo_id']} filename={hf_cfg['filename']}" + ) + + searched = ", ".join(local_candidates) if local_candidates else "" + raise FileNotFoundError( + "checkpoint not found. " + f"searched local paths: {searched}. " + "You can also set LONGSTREAM_HF_REPO and LONGSTREAM_HF_FILE." + ) + + +def _model_device(device: str) -> str: + if device == "cuda" and not torch.cuda.is_available(): + return "cpu" + return device + + +def _cache_key(checkpoint: str, device: str, model_cfg: dict) -> Tuple[str, str, str]: + rel_cfg = json.dumps(model_cfg.get("longstream_cfg", {}), sort_keys=True) + return checkpoint, device, rel_cfg + + +def get_or_load_model(checkpoint: str, device: str, model_cfg: dict) -> LongStreamModel: + device = _model_device(device) + cfg = json.loads(json.dumps(model_cfg)) + cfg["checkpoint"] = checkpoint + key = _cache_key(checkpoint, device, cfg) + model = _MODEL_CACHE.get(key) + if model is None: + model = LongStreamModel(cfg).to(device) + model.eval() + _MODEL_CACHE.clear() + _MODEL_CACHE[key] = model + return model + + +def _load_images( + image_paths: List[str], size: int, crop: bool, patch_size: int +) -> torch.Tensor: + views = load_images_for_eval( + image_paths, size=size, verbose=False, crop=crop, patch_size=patch_size + ) + imgs = torch.cat([view["img"] for view in views], dim=0) + images = (imgs.unsqueeze(0) + 1.0) / 2.0 + return images + + +def _select_keyframes(images: torch.Tensor, keyframe_stride: int, keyframe_mode: str): + selector = KeyframeSelector( + min_interval=keyframe_stride, + max_interval=keyframe_stride, + force_first=True, + mode="random" if keyframe_mode == "random" else "fixed", + ) + return selector.select_keyframes(images.shape[1], images.shape[0], images.device) + + +def _run_model(images: torch.Tensor, model: LongStreamModel, infer_cfg: dict): + keyframe_stride = int(infer_cfg.get("keyframe_stride", 8)) + keyframe_mode = infer_cfg.get("keyframe_mode", "fixed") + refresh = int(infer_cfg.get("refresh", 4)) + mode = infer_cfg.get("mode", "streaming_refresh") + streaming_mode = infer_cfg.get("streaming_mode", "causal") + window_size = int(infer_cfg.get("window_size", 48)) + rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4}) + + is_keyframe, keyframe_indices = _select_keyframes( + images, keyframe_stride, keyframe_mode + ) + if mode == "batch_refresh": + outputs = run_batch_refresh( + model, + images, + is_keyframe, + keyframe_indices, + streaming_mode, + keyframe_stride, + refresh, + rel_pose_cfg, + ) + elif mode == "streaming_refresh": + outputs = run_streaming_refresh( + model, + images, + is_keyframe, + keyframe_indices, + streaming_mode, + window_size, + refresh, + rel_pose_cfg, + ) + else: + raise ValueError(f"Unsupported demo inference mode: {mode}") + return outputs, keyframe_indices + + +def _compute_pose_outputs( + outputs: dict, keyframe_indices: torch.Tensor, image_hw: Tuple[int, int] +): + if "rel_pose_enc" in outputs: + rel_pose_enc = outputs["rel_pose_enc"][0] + abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0]) + extri, intri = pose_encoding_to_extri_intri( + abs_pose_enc[None], image_size_hw=image_hw + ) + return ( + rel_pose_enc.detach().cpu().numpy(), + extri[0].detach().cpu().numpy(), + intri[0].detach().cpu().numpy(), + ) + if "pose_enc" in outputs: + pose_enc = outputs["pose_enc"][0] + extri, intri = pose_encoding_to_extri_intri( + pose_enc[None], image_size_hw=image_hw + ) + return None, extri[0].detach().cpu().numpy(), intri[0].detach().cpu().numpy() + raise RuntimeError("Model outputs contain neither rel_pose_enc nor pose_enc") + + +def _compute_sky_masks( + image_paths: List[str], + target_shape: Tuple[int, int], + data_cfg: dict, + skyseg_path: str, + session_dir: str, +): + raw_masks = compute_sky_mask( + image_paths, skyseg_path, os.path.join(session_dir, "sky_masks_raw") + ) + if raw_masks is None: + return None + masks = [] + for mask in raw_masks: + masks.append( + _prepare_mask_for_model( + mask, + size=int(data_cfg.get("size", 518)), + crop=bool(data_cfg.get("crop", False)), + patch_size=int(data_cfg.get("patch_size", 14)), + target_shape=target_shape, + ) + ) + return np.stack(masks, axis=0) + + +def create_demo_session( + image_dir: str, + uploaded_files, + uploaded_video, + checkpoint: str, + device: str, + mode: str, + streaming_mode: str, + keyframe_stride: int, + refresh: int, + window_size: int, + compute_sky: bool, + config_path: Optional[str] = None, +) -> str: + checkpoint = _resolve_demo_checkpoint(checkpoint) + + session_dir = _new_session_dir() + base_cfg = _load_base_config(config_path) + data_cfg = dict(base_cfg.get("data", {})) + model_cfg = dict(base_cfg.get("model", {})) + infer_cfg = dict(base_cfg.get("inference", {})) + + if image_dir: + image_dir = os.path.abspath(image_dir) + if not os.path.isdir(image_dir): + raise FileNotFoundError(f"image_dir not found: {image_dir}") + image_paths = _sorted_image_paths(image_dir) + input_root = image_dir + elif uploaded_video: + image_paths = _extract_uploaded_video(uploaded_video, session_dir) + input_root = _resolve_file_path(uploaded_video) + else: + image_paths = _copy_uploaded_images(uploaded_files or [], session_dir) + input_root = os.path.dirname(image_paths[0]) if image_paths else "" + + if not image_paths: + raise ValueError("No input images found") + + data_cfg["size"] = int(data_cfg.get("size", 518)) + data_cfg["crop"] = bool(data_cfg.get("crop", False)) + data_cfg["patch_size"] = int(data_cfg.get("patch_size", 14)) + + device = _model_device(device) + model = get_or_load_model(checkpoint, device, model_cfg) + + images = _load_images( + image_paths, data_cfg["size"], data_cfg["crop"], data_cfg["patch_size"] + ) + infer_cfg.update( + { + "mode": mode, + "streaming_mode": streaming_mode, + "keyframe_stride": int(keyframe_stride), + "refresh": int(refresh), + "window_size": int(window_size), + } + ) + + with torch.no_grad(): + outputs, keyframe_indices = _run_model(images, model, infer_cfg) + h, w = images.shape[-2:] + rel_pose_enc, extri, intri = _compute_pose_outputs( + outputs, keyframe_indices, (h, w) + ) + point_head = ( + outputs["world_points"][0] + .detach() + .cpu() + .numpy() + .astype(np.float32, copy=False) + ) + depth = ( + outputs["depth"][0, :, :, :, 0] + .detach() + .cpu() + .numpy() + .astype(np.float32, copy=False) + ) + + if device == "cuda": + torch.cuda.empty_cache() + + images_uint8 = np.clip( + images[0].permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255 + ).astype(np.uint8) + sky_masks = None + if compute_sky: + skyseg_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "skyseg.onnx" + ) + sky_masks = _compute_sky_masks( + image_paths, (h, w), data_cfg, skyseg_path, session_dir + ) + + np.save(session_file(session_dir, "images.npy"), images_uint8) + np.save(session_file(session_dir, "depth.npy"), depth) + np.save(session_file(session_dir, "point_head.npy"), point_head) + np.save(session_file(session_dir, "w2c.npy"), extri) + np.save(session_file(session_dir, "intri.npy"), intri) + if rel_pose_enc is not None: + np.save( + session_file(session_dir, "rel_pose_enc.npy"), + rel_pose_enc.astype(np.float32, copy=False), + ) + if sky_masks is not None: + np.save( + session_file(session_dir, "sky_masks.npy"), + sky_masks.astype(np.uint8, copy=False), + ) + + sky_removed_ratio = None + if sky_masks is not None: + sky_removed_ratio = float(1.0 - (sky_masks > 0).mean()) + + metadata = { + "session_dir": session_dir, + "created_at": datetime.utcnow().isoformat() + "Z", + "checkpoint": os.path.abspath(checkpoint), + "device": device, + "mode": mode, + "streaming_mode": streaming_mode, + "keyframe_stride": int(keyframe_stride), + "refresh": int(refresh), + "window_size": int(window_size), + "num_frames": int(images_uint8.shape[0]), + "height": int(images_uint8.shape[1]), + "width": int(images_uint8.shape[2]), + "input_root": input_root, + "image_paths": image_paths, + "has_sky_masks": bool(sky_masks is not None), + "sky_removed_ratio": sky_removed_ratio, + } + with open(session_file(session_dir, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) + + del outputs + return session_dir + + +def load_frame_previews(session_dir: str, frame_index: int): + meta = load_metadata(session_dir) + frame_index = int(np.clip(frame_index, 0, meta["num_frames"] - 1)) + images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r") + depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r") + rgb = np.array(images[frame_index]) + depth_color = colorize_depth(np.array(depth[frame_index]), cmap="plasma") + label = f"Frame {frame_index + 1}/{meta['num_frames']}" + return rgb, depth_color, label diff --git a/longstream/demo/common.py b/longstream/demo/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e8de93b227e292431b080b0729dfdaba499a37ae --- /dev/null +++ b/longstream/demo/common.py @@ -0,0 +1,84 @@ +import json +import os +from typing import List + +import numpy as np + +BRANCH_OPTIONS = [ + "Point Head + Pose", + "Depth Projection + Pose", +] +BRANCH_TO_KEY = { + "Point Head + Pose": "point_head", + "Depth Projection + Pose": "depth_projection", +} +DISPLAY_MODE_OPTIONS = [ + "Current Frame", + "Accumulate to Frame", + "All Frames", +] + + +def branch_key(label: str) -> str: + return BRANCH_TO_KEY.get(label, "point_head") + + +def session_file(session_dir: str, name: str) -> str: + return os.path.join(session_dir, name) + + +def load_metadata(session_dir: str) -> dict: + with open(session_file(session_dir, "metadata.json"), "r") as f: + return json.load(f) + + +def selected_frame_indices( + num_frames: int, frame_index: int, display_mode: str +) -> List[int]: + if num_frames <= 0: + return [] + frame_index = int(np.clip(frame_index, 0, num_frames - 1)) + if display_mode == "Current Frame": + return [frame_index] + if display_mode == "Accumulate to Frame": + return list(range(frame_index + 1)) + return list(range(num_frames)) + + +def as_4x4(w2c): + w2c = np.asarray(w2c, dtype=np.float64) + if w2c.shape == (4, 4): + return w2c + out = np.eye(4, dtype=np.float64) + out[:3, :4] = w2c + return out + + +_VIEW_ROT = np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, -1.0, 0.0], + ], + dtype=np.float64, +) + + +def world_to_view(points): + points = np.asarray(points, dtype=np.float64) + return points @ _VIEW_ROT.T + + +def camera_center_from_w2c(w2c): + c2w = np.linalg.inv(as_4x4(w2c)) + return c2w[:3, 3] + + +def c2w_in_view_space(w2c, origin_shift=None): + c2w = np.linalg.inv(as_4x4(w2c)) + out = np.eye(4, dtype=np.float64) + out[:3, :3] = _VIEW_ROT @ c2w[:3, :3] + out[:3, 3] = world_to_view(c2w[:3, 3][None])[0] + if origin_shift is not None: + out[:3, 3] -= np.asarray(origin_shift, dtype=np.float64) + return out diff --git a/longstream/demo/export.py b/longstream/demo/export.py new file mode 100644 index 0000000000000000000000000000000000000000..e91213509bd7f99fa1e02afb92138bbc8b5e1c74 --- /dev/null +++ b/longstream/demo/export.py @@ -0,0 +1,85 @@ +import os + +import numpy as np + +from .geometry import camera_geometry, collect_points + +_CAMERA_COLORS = np.array( + [ + [239, 68, 68, 255], + [14, 165, 233, 255], + [34, 197, 94, 255], + [245, 158, 11, 255], + ], + dtype=np.uint8, +) + + +def _camera_mesh(center, corners, color): + import trimesh + + vertices = np.vstack([center[None], corners]).astype(np.float32) + faces = np.array( + [ + [0, 1, 2], + [0, 2, 3], + [0, 3, 4], + [0, 4, 1], + [1, 2, 3], + [1, 3, 4], + ], + dtype=np.int64, + ) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + mesh.visual.face_colors = np.tile(color[None], (faces.shape[0], 1)) + return mesh + + +def export_glb( + session_dir: str, + branch: str, + display_mode: str, + frame_index: int, + mask_sky: bool, + show_cameras: bool, + camera_scale: float, + max_points: int, +) -> str: + import trimesh + + points, colors, _ = collect_points( + session_dir=session_dir, + branch=branch, + display_mode=display_mode, + frame_index=frame_index, + mask_sky=mask_sky, + max_points=max_points, + seed=13, + ) + if len(points) == 0: + raise ValueError("No valid points to export") + + scene = trimesh.Scene() + scene.add_geometry(trimesh.PointCloud(vertices=points, colors=colors)) + + if show_cameras: + _, frustums, _ = camera_geometry( + session_dir=session_dir, + display_mode=display_mode, + frame_index=frame_index, + camera_scale_ratio=camera_scale, + points_hint=points, + ) + for idx, (center, corners) in enumerate(frustums): + scene.add_geometry( + _camera_mesh(center, corners, _CAMERA_COLORS[idx % len(_CAMERA_COLORS)]) + ) + + export_dir = os.path.join(session_dir, "exports") + os.makedirs(export_dir, exist_ok=True) + branch_slug = branch.lower().replace(" + ", "_").replace(" ", "_") + mode_slug = display_mode.replace(" ", "_").lower() + filename = f"{branch_slug}_{mode_slug}_{frame_index:04d}_sky{int(mask_sky)}_cam{int(show_cameras)}.glb" + path = os.path.join(export_dir, filename) + scene.export(path) + return path diff --git a/longstream/demo/geometry.py b/longstream/demo/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f9da1184c0323a06109e2fd1a571e140e8275c82 --- /dev/null +++ b/longstream/demo/geometry.py @@ -0,0 +1,211 @@ +import os +from typing import List, Optional, Tuple + +import numpy as np + +from .common import ( + branch_key, + c2w_in_view_space, + load_metadata, + selected_frame_indices, + session_file, + world_to_view, +) + + +def _origin_shift(w2c_all) -> np.ndarray: + first = c2w_in_view_space(w2c_all[0]) + return first[:3, 3].copy() + + +def _sample_flat_indices( + valid_indices: np.ndarray, budget: Optional[int], rng: np.random.Generator +) -> np.ndarray: + if budget is None or budget <= 0 or valid_indices.size <= budget: + return valid_indices + keep = rng.choice(valid_indices.size, size=int(budget), replace=False) + return valid_indices[keep] + + +def _depth_points_from_flat(depth, intri, w2c, flat_indices): + h, w = depth.shape + ys = flat_indices // w + xs = flat_indices % w + z = depth.reshape(-1)[flat_indices].astype(np.float64) + fx = float(intri[0, 0]) + fy = float(intri[1, 1]) + cx = float(intri[0, 2]) + cy = float(intri[1, 2]) + x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12) + y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12) + pts_cam = np.stack([x, y, z], axis=1) + R = w2c[:3, :3].astype(np.float64) + t = w2c[:3, 3].astype(np.float64) + return (R.T @ (pts_cam.T - t[:, None])).T.astype(np.float32, copy=False) + + +def _camera_points_to_world(points, w2c): + pts = np.asarray(points, dtype=np.float64).reshape(-1, 3) + R = w2c[:3, :3].astype(np.float64) + t = w2c[:3, 3].astype(np.float64) + return (R.T @ (pts.T - t[:, None])).T.astype(np.float32, copy=False) + + +def collect_points( + session_dir: str, + branch: str, + display_mode: str, + frame_index: int, + mask_sky: bool, + max_points: Optional[int], + seed: int = 0, +): + branch = branch_key(branch) + meta = load_metadata(session_dir) + frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode) + if not frame_ids: + return ( + np.empty((0, 3), dtype=np.float32), + np.empty((0, 3), dtype=np.uint8), + np.zeros(3, dtype=np.float64), + ) + + images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r") + w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r") + origin_shift = _origin_shift(w2c) + sky = None + if mask_sky and os.path.exists(session_file(session_dir, "sky_masks.npy")): + sky = np.load(session_file(session_dir, "sky_masks.npy"), mmap_mode="r") + + if branch == "point_head": + point_head = np.load(session_file(session_dir, "point_head.npy"), mmap_mode="r") + source = point_head + depth = None + intri = None + else: + source = None + depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r") + intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r") + + per_frame_budget = None + if max_points is not None and max_points > 0: + per_frame_budget = max(int(max_points) // max(len(frame_ids), 1), 1) + + rng = np.random.default_rng(seed) + points = [] + colors = [] + for idx in frame_ids: + rgb_flat = images[idx].reshape(-1, 3) + if branch == "point_head": + pts_map = source[idx] + valid = np.isfinite(pts_map).all(axis=-1).reshape(-1) + if sky is not None: + valid &= sky[idx].reshape(-1) > 0 + flat = np.flatnonzero(valid) + if flat.size == 0: + continue + flat = _sample_flat_indices(flat, per_frame_budget, rng) + pts_cam = pts_map.reshape(-1, 3)[flat] + pts_world = _camera_points_to_world(pts_cam, w2c[idx]) + else: + depth_i = depth[idx] + valid = (np.isfinite(depth_i) & (depth_i > 0)).reshape(-1) + if sky is not None: + valid &= sky[idx].reshape(-1) > 0 + flat = np.flatnonzero(valid) + if flat.size == 0: + continue + flat = _sample_flat_indices(flat, per_frame_budget, rng) + pts_world = _depth_points_from_flat(depth_i, intri[idx], w2c[idx], flat) + + pts_view = world_to_view(pts_world) - origin_shift[None] + points.append(pts_view.astype(np.float32, copy=False)) + colors.append(rgb_flat[flat].astype(np.uint8, copy=False)) + + if not points: + return ( + np.empty((0, 3), dtype=np.float32), + np.empty((0, 3), dtype=np.uint8), + origin_shift, + ) + return np.concatenate(points, axis=0), np.concatenate(colors, axis=0), origin_shift + + +def _frustum_corners_camera(intri, image_hw, depth_scale): + h, w = image_hw + fx = float(intri[0, 0]) + fy = float(intri[1, 1]) + cx = float(intri[0, 2]) + cy = float(intri[1, 2]) + corners = np.array( + [ + [ + (0.0 - cx) * depth_scale / max(fx, 1e-12), + (0.0 - cy) * depth_scale / max(fy, 1e-12), + depth_scale, + ], + [ + ((w - 1.0) - cx) * depth_scale / max(fx, 1e-12), + (0.0 - cy) * depth_scale / max(fy, 1e-12), + depth_scale, + ], + [ + ((w - 1.0) - cx) * depth_scale / max(fx, 1e-12), + ((h - 1.0) - cy) * depth_scale / max(fy, 1e-12), + depth_scale, + ], + [ + (0.0 - cx) * depth_scale / max(fx, 1e-12), + ((h - 1.0) - cy) * depth_scale / max(fy, 1e-12), + depth_scale, + ], + ], + dtype=np.float64, + ) + return corners + + +def camera_geometry( + session_dir: str, + display_mode: str, + frame_index: int, + camera_scale_ratio: float, + points_hint=None, +): + meta = load_metadata(session_dir) + frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode) + w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r") + intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r") + origin_shift = _origin_shift(w2c) + + center_points = np.array( + [c2w_in_view_space(w2c[idx], origin_shift)[:3, 3] for idx in frame_ids], + dtype=np.float64, + ) + center_extent = 1.0 + if len(center_points) > 1: + center_extent = float( + np.linalg.norm(center_points.max(axis=0) - center_points.min(axis=0)) + ) + + point_extent = 0.0 + if points_hint is not None and len(points_hint) > 0: + lo = np.percentile(points_hint, 5, axis=0) + hi = np.percentile(points_hint, 95, axis=0) + point_extent = float(np.linalg.norm(hi - lo)) + + extent = max(center_extent, point_extent, 1.0) + depth_scale = extent * float(camera_scale_ratio) + + centers = [] + frustums = [] + for idx in frame_ids: + c2w_view = c2w_in_view_space(w2c[idx], origin_shift) + center = c2w_view[:3, 3] + corners_cam = _frustum_corners_camera( + intri[idx], (meta["height"], meta["width"]), depth_scale + ) + corners_world = (c2w_view[:3, :3] @ corners_cam.T).T + center[None] + centers.append(center) + frustums.append((center, corners_world)) + return np.asarray(centers, dtype=np.float64), frustums, origin_shift diff --git a/longstream/demo/viewer.py b/longstream/demo/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5fe990639ba12552b6e2fcb3f86917262cadfb --- /dev/null +++ b/longstream/demo/viewer.py @@ -0,0 +1,134 @@ +import numpy as np +import plotly.graph_objects as go + +from longstream.demo.backend import load_frame_previews + +from .common import load_metadata +from .geometry import camera_geometry, collect_points + + +def _empty_figure(message: str): + fig = go.Figure() + fig.add_annotation( + text=message, x=0.5, y=0.5, xref="paper", yref="paper", showarrow=False + ) + fig.update_layout( + template="plotly_white", + margin=dict(l=0, r=0, t=40, b=0), + scene=dict(aspectmode="data"), + ) + return fig + + +def _camera_lines(frustums): + xs, ys, zs = [], [], [] + for center, corners in frustums: + order = [(0, 1), (1, 2), (2, 3), (3, 0)] + for a, b in order: + xs.extend([corners[a, 0], corners[b, 0], None]) + ys.extend([corners[a, 1], corners[b, 1], None]) + zs.extend([corners[a, 2], corners[b, 2], None]) + for corner in corners: + xs.extend([center[0], corner[0], None]) + ys.extend([center[1], corner[1], None]) + zs.extend([center[2], corner[2], None]) + return xs, ys, zs + + +def build_interactive_figure( + session_dir: str, + branch: str, + display_mode: str, + frame_index: int, + point_size: float, + opacity: float, + preview_max_points: int, + show_cameras: bool, + camera_scale: float, + mask_sky: bool, +): + meta = load_metadata(session_dir) + points, colors, _ = collect_points( + session_dir=session_dir, + branch=branch, + display_mode=display_mode, + frame_index=frame_index, + mask_sky=mask_sky, + max_points=preview_max_points, + seed=frame_index, + ) + if len(points) == 0: + return _empty_figure("No valid points for the current selection") + + fig = go.Figure() + fig.add_trace( + go.Scatter3d( + x=points[:, 0], + y=points[:, 1], + z=points[:, 2], + mode="markers", + marker=dict( + size=float(point_size), + color=[f"rgb({r},{g},{b})" for r, g, b in colors], + opacity=float(opacity), + ), + hoverinfo="skip", + name="points", + ) + ) + + if show_cameras: + centers, frustums, _ = camera_geometry( + session_dir=session_dir, + display_mode=display_mode, + frame_index=frame_index, + camera_scale_ratio=camera_scale, + points_hint=points, + ) + if len(centers) > 0: + fig.add_trace( + go.Scatter3d( + x=centers[:, 0], + y=centers[:, 1], + z=centers[:, 2], + mode="lines", + line=dict(color="#16a34a", width=2), + name="trajectory", + hoverinfo="skip", + ) + ) + xs, ys, zs = _camera_lines(frustums) + fig.add_trace( + go.Scatter3d( + x=xs, + y=ys, + z=zs, + mode="lines", + line=dict(color="#22c55e", width=1.5), + name="cameras", + hoverinfo="skip", + ) + ) + + fig.update_layout( + template="plotly_white", + margin=dict(l=0, r=0, t=40, b=0), + scene=dict( + aspectmode="data", + xaxis_title="x_right", + yaxis_title="z_forward", + zaxis_title="y_up", + bgcolor="#f8fafc", + camera=dict( + up=dict(x=0.0, y=0.0, z=1.0), + eye=dict(x=-1.0, y=-1.8, z=0.9), + ), + ), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0.0), + ) + return fig + + +def build_frame_outputs(session_dir: str, frame_index: int): + rgb, depth, label = load_frame_previews(session_dir, frame_index) + return rgb, depth, label diff --git a/longstream/eval/__init__.py b/longstream/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d696c987dba6be912c83d1c64c39a3c1a329b6c --- /dev/null +++ b/longstream/eval/__init__.py @@ -0,0 +1,3 @@ +from .evaluate import evaluate_predictions_cfg + +__all__ = ["evaluate_predictions_cfg"] diff --git a/longstream/eval/evaluate.py b/longstream/eval/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..e807d99d374c3cf66d8049ea4c483394527a9355 --- /dev/null +++ b/longstream/eval/evaluate.py @@ -0,0 +1,551 @@ +import json +import os + +import cv2 +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +from longstream.data import LongStreamDataLoader +from longstream.eval.io import ( + frame_stems, + read_depth, + read_opencv_camera_yml, + read_pointcloud_xyz, + read_pred_w2c_txt, +) +from longstream.eval.metrics import ate_rmse, chamfer_and_f1, transform_points +from longstream.utils.sky_mask import sky_mask_filename + + +def _ensure_dir(path): + os.makedirs(path, exist_ok=True) + + +def _sequence_output_dir(output_root, seq_name): + return os.path.join(output_root, seq_name) + + +def _sequence_metrics_path(output_root, seq_name): + return os.path.join(output_root, "metrics", f"{seq_name}.json") + + +def _sequence_plot_path(output_root, seq_name): + return os.path.join(output_root, "plots", f"{seq_name}_traj_3d.png") + + +def _world_xyz_to_plot_xyz(xyz): + xyz = np.asarray(xyz, dtype=np.float64) + return np.stack([xyz[:, 0], xyz[:, 2], -xyz[:, 1]], axis=-1) + + +def _set_equal_3d_axes(ax, xyz): + mins = xyz.min(axis=0) + maxs = xyz.max(axis=0) + center = 0.5 * (mins + maxs) + radius = 0.5 * np.max(np.maximum(maxs - mins, 1e-6)) + ax.set_xlim(center[0] - radius, center[0] + radius) + ax.set_ylim(center[1] - radius, center[1] + radius) + ax.set_zlim(center[2] - radius, center[2] + radius) + + +def _load_gt_pose_data(seq_info): + if seq_info.camera is not None: + cam_dir = os.path.join(seq_info.scene_root, "cameras", seq_info.camera) + extri_path = os.path.join(cam_dir, "extri.yml") + intri_path = os.path.join(cam_dir, "intri.yml") + if os.path.exists(extri_path): + extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path) + return extri, intri, image_sizes + + extri_path = os.path.join(seq_info.scene_root, "extri.yml") + intri_path = os.path.join(seq_info.scene_root, "intri.yml") + if not os.path.exists(extri_path): + return None, None, None + extri, intri, image_sizes = read_opencv_camera_yml(extri_path, intri_path) + return extri, intri, image_sizes + + +def _resolve_gt_depth_root(seq_info): + if seq_info.camera is not None: + camera_depth_root = os.path.join(seq_info.scene_root, "depths", seq_info.camera) + if os.path.isdir(camera_depth_root): + return camera_depth_root + depth_root = os.path.join(seq_info.scene_root, "depths") + if os.path.isdir(depth_root): + return depth_root + return None + + +def _resolve_gt_depth_path(seq_info, depth_root, image_path, stem): + rel_path = os.path.relpath(image_path, seq_info.image_dir) + rel_stem = os.path.splitext(rel_path)[0] + file_stem = os.path.splitext(os.path.basename(image_path))[0] + candidates = [ + os.path.join(depth_root, f"{stem}.exr"), + os.path.join(depth_root, rel_stem + ".exr"), + os.path.join(depth_root, stem, f"{file_stem}.exr"), + ] + for candidate in candidates: + if os.path.exists(candidate): + return candidate + return None + + +def _resize_long_edge(arr, long_edge_size, interpolation): + h, w = arr.shape[:2] + scale = float(long_edge_size) / float(max(h, w)) + new_w = int(round(w * scale)) + new_h = int(round(h * scale)) + return cv2.resize(arr, (new_w, new_h), interpolation=interpolation) + + +def _prepare_map_for_eval( + arr, size, crop, patch_size, target_shape, interpolation, square_ok=False +): + h0, w0 = arr.shape[:2] + long_edge = round(size * max(w0 / h0, h0 / w0)) if size == 224 else size + arr = _resize_long_edge(arr, long_edge, interpolation) + + h, w = arr.shape[:2] + cx, cy = w // 2, h // 2 + + if size == 224: + half = min(cx, cy) + target_w = 2 * half + target_h = 2 * half + if crop: + arr = arr[cy - half : cy + half, cx - half : cx + half] + else: + arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation) + else: + halfw = ((2 * cx) // patch_size) * (patch_size // 2) + halfh = ((2 * cy) // patch_size) * (patch_size // 2) + if not square_ok and w == h: + halfh = int(3 * halfw / 4) + target_w = 2 * halfw + target_h = 2 * halfh + if crop: + arr = arr[cy - halfh : cy + halfh, cx - halfw : cx + halfw] + else: + arr = cv2.resize(arr, (target_w, target_h), interpolation=interpolation) + + if arr.shape[:2] != tuple(target_shape): + arr = cv2.resize( + arr, (target_shape[1], target_shape[0]), interpolation=interpolation + ) + return arr + + +def _sky_mask_path(seq_dir, image_path): + return os.path.join(seq_dir, "sky_masks", sky_mask_filename(image_path)) + + +def _sample_frame_points(points, max_points, rng): + if max_points is None or len(points) <= max_points: + return points + keep = rng.choice(len(points), size=max_points, replace=False) + return points[keep] + + +def _depth_to_world_points(depth, intri, extri, valid_mask): + ys, xs = np.nonzero(valid_mask) + if ys.size == 0: + return np.empty((0, 3), dtype=np.float32) + + z = depth[ys, xs].astype(np.float64) + fx = float(intri[0, 0]) + fy = float(intri[1, 1]) + cx = float(intri[0, 2]) + cy = float(intri[1, 2]) + + x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12) + y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12) + pts_cam = np.stack([x, y, z], axis=1) + + R = extri[:3, :3] + t = extri[:3, 3] + pts_world = (R.T @ (pts_cam.T - t[:, None])).T + return pts_world.astype(np.float32, copy=False) + + +def _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg): + if not gt_extri or not gt_intri: + return None + + gt_dir = _resolve_gt_depth_root(seq_info) + if gt_dir is None: + return None + + eval_max_points = int(eval_cfg.get("point_eval_max_points", 100000)) + oversample_factor = int(eval_cfg.get("point_eval_oversample_factor", 4)) + per_frame_budget = max( + (eval_max_points * oversample_factor) // max(len(seq_info.image_paths), 1), 1 + ) + rng = np.random.default_rng(0) + chunks = [] + + for image_path, stem in zip( + seq_info.image_paths, frame_stems(seq_info.image_paths) + ): + depth_path = _resolve_gt_depth_path(seq_info, gt_dir, image_path, stem) + if depth_path is None or stem not in gt_extri or stem not in gt_intri: + continue + + depth = read_depth(depth_path) + valid = np.isfinite(depth) & (depth > 0) + if not np.any(valid): + continue + + sky_path = _sky_mask_path(seq_dir, image_path) + if os.path.exists(sky_path): + sky_mask = cv2.imread(sky_path, cv2.IMREAD_GRAYSCALE) + if sky_mask is not None: + if sky_mask.shape[:2] != depth.shape[:2]: + sky_mask = cv2.resize( + sky_mask, + (depth.shape[1], depth.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + valid &= sky_mask > 0 + if not np.any(valid): + continue + + pts_world = _depth_to_world_points(depth, gt_intri[stem], gt_extri[stem], valid) + if len(pts_world) == 0: + continue + chunks.append(_sample_frame_points(pts_world, per_frame_budget, rng)) + + if not chunks: + return None + return np.concatenate(chunks, axis=0) + + +def _evaluate_pointclouds(seq_info, seq_dir, eval_cfg, pose_align, gt_cloud): + if pose_align is None or gt_cloud is None: + return None + + scale, R, t = pose_align + point_paths = { + "point_head": [ + os.path.join(seq_dir, "points", "point_head_full.npy"), + os.path.join(seq_dir, "points", "point_head_full.npz"), + os.path.join(seq_dir, "points", "point_head_full.ply"), + ], + "dpt_unproj": [ + os.path.join(seq_dir, "points", "dpt_unproj_full.npy"), + os.path.join(seq_dir, "points", "dpt_unproj_full.npz"), + os.path.join(seq_dir, "points", "dpt_unproj_full.ply"), + ], + } + threshold = float(eval_cfg.get("point_f1_threshold", 0.25)) + max_points = int(eval_cfg.get("point_eval_max_points", 100000)) + voxel_size = eval_cfg.get("point_eval_voxel_size", None) + voxel_size = None if voxel_size in (None, "", "null") else float(voxel_size) + + metrics_by_branch = {} + for branch, candidates in point_paths.items(): + path = next( + (candidate for candidate in candidates if os.path.exists(candidate)), None + ) + if path is None: + continue + pred_cloud = read_pointcloud_xyz(path) + pred_cloud = transform_points(pred_cloud, scale, R, t) + metrics = chamfer_and_f1( + pred_cloud, + gt_cloud, + threshold=threshold, + max_points=max_points, + voxel_size=voxel_size, + seed=0 if branch == "point_head" else 1, + ) + if metrics is not None: + metrics_by_branch[branch] = metrics + return metrics_by_branch or None + + +def _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg): + pred_dir = os.path.join(seq_dir, "depth", "dpt") + gt_dir = _resolve_gt_depth_root(seq_info) + if not os.path.isdir(pred_dir) or gt_dir is None: + return None + + size = int(data_cfg.get("size", 518)) + crop = bool(data_cfg.get("crop", False)) + patch_size = int(data_cfg.get("patch_size", 14)) + rel_delta_threshold = float(eval_cfg.get("depth_rel_delta_threshold", 1.25)) + + abs_rel_sum = 0.0 + rel_delta_hits = 0 + valid_pixels = 0 + evaluated_frames = 0 + + stems = frame_stems(seq_info.image_paths) + for frame_id, stem in enumerate(stems): + pred_path = os.path.join(pred_dir, f"frame_{frame_id:06d}.npy") + gt_path = _resolve_gt_depth_path( + seq_info, gt_dir, seq_info.image_paths[frame_id], stem + ) + if not os.path.exists(pred_path) or gt_path is None: + continue + + pred = np.load(pred_path).astype(np.float32) + gt = read_depth(gt_path) + gt = _prepare_map_for_eval( + gt, + size=size, + crop=crop, + patch_size=patch_size, + target_shape=pred.shape, + interpolation=cv2.INTER_NEAREST, + ) + + valid = np.isfinite(gt) & (gt > 0) + if not np.any(valid): + continue + + sky_mask_path = _sky_mask_path(seq_dir, seq_info.image_paths[frame_id]) + if os.path.exists(sky_mask_path): + sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_GRAYSCALE) + if sky_mask is not None: + sky_mask = _prepare_map_for_eval( + sky_mask, + size=size, + crop=crop, + patch_size=patch_size, + target_shape=pred.shape, + interpolation=cv2.INTER_NEAREST, + ) + valid &= sky_mask > 0 + + valid &= np.isfinite(pred) + if not np.any(valid): + continue + + pred_valid = pred[valid].astype(np.float64) + gt_valid = gt[valid].astype(np.float64) + pred_safe = np.clip(pred_valid, 1e-6, None) + gt_safe = np.clip(gt_valid, 1e-6, None) + + abs_rel_sum += np.sum(np.abs(pred_valid - gt_valid) / gt_safe) + rel_ratio = np.maximum(gt_safe / pred_safe, pred_safe / gt_safe) + rel_delta_hits += int(np.sum(rel_ratio < rel_delta_threshold)) + valid_pixels += int(gt_valid.size) + evaluated_frames += 1 + + if valid_pixels == 0: + return None + + return { + "abs_rel": float(abs_rel_sum / valid_pixels), + "rel_delta": float(rel_delta_hits / valid_pixels), + "rel_delta_threshold": rel_delta_threshold, + "num_valid_pixels": int(valid_pixels), + "num_frames": int(evaluated_frames), + } + + +def _extract_pose_pairs(seq_info, pred_pose_path, gt_extri): + frame_ids, pred_w2c = read_pred_w2c_txt(pred_pose_path) + if not pred_w2c: + return None + + stems = frame_stems(seq_info.image_paths) + pred_xyz = [] + gt_xyz = [] + + for frame_id, pred_mat in zip(frame_ids, pred_w2c): + if frame_id < 0 or frame_id >= len(stems): + continue + stem = stems[frame_id] + if stem not in gt_extri: + continue + pred_c2w = np.linalg.inv(pred_mat) + gt_c2w = np.linalg.inv(gt_extri[stem]) + pred_xyz.append(pred_c2w[:3, 3]) + gt_xyz.append(gt_c2w[:3, 3]) + + if len(pred_xyz) < 3: + return None + return np.asarray(pred_xyz, dtype=np.float64), np.asarray(gt_xyz, dtype=np.float64) + + +def _save_traj_plot_3d(path, pred_xyz, gt_xyz): + _ensure_dir(os.path.dirname(path)) + pred_plot = _world_xyz_to_plot_xyz(pred_xyz) + gt_plot = _world_xyz_to_plot_xyz(gt_xyz) + origin = gt_plot[:1] + pred_plot = pred_plot - origin + gt_plot = gt_plot - origin + all_plot = np.concatenate([pred_plot, gt_plot], axis=0) + + fig = plt.figure(figsize=(7, 6)) + ax = fig.add_subplot(111, projection="3d") + ax.plot( + gt_plot[:, 0], + gt_plot[:, 1], + gt_plot[:, 2], + label="gt", + linewidth=2.0, + color="#1f77b4", + ) + ax.plot( + pred_plot[:, 0], + pred_plot[:, 1], + pred_plot[:, 2], + label="pred", + linewidth=2.0, + color="#d62728", + ) + _set_equal_3d_axes(ax, all_plot) + ax.view_init(elev=24, azim=-118) + ax.set_xlabel("x_right") + ax.set_ylabel("z_forward") + ax.set_zlabel("y_up") + ax.legend(loc="best") + ax.set_title("Trajectory 3D (Sim3-aligned view)") + fig.tight_layout() + fig.savefig(path, dpi=180) + plt.close(fig) + + +def evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg): + seq_dir = _sequence_output_dir(output_root, seq_info.name) + result = { + "sequence": seq_info.name, + "output_dir": seq_dir, + "has_gt": False, + "has_gt_pose": False, + "has_gt_depth": False, + } + + gt_extri, gt_intri, _ = _load_gt_pose_data(seq_info) + pose_align = None + if gt_extri: + result["has_gt"] = True + result["has_gt_pose"] = True + + pred_pose_path = os.path.join(seq_dir, "poses", "abs_pose.txt") + pairs = _extract_pose_pairs(seq_info, pred_pose_path, gt_extri) + if pairs is not None: + pred_xyz, gt_xyz = pairs + pose_metrics = ate_rmse( + pred_xyz, gt_xyz, align_scale=bool(eval_cfg.get("align_scale", True)) + ) + sim3_scale = float(pose_metrics.get("sim3_scale", 1.0)) + pred_xyz_aligned = transform_points( + pred_xyz, + sim3_scale, + np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64), + np.asarray(pose_metrics["sim3_translation"], dtype=np.float64), + ) + pose_align = ( + sim3_scale, + np.asarray(pose_metrics["sim3_rotation"], dtype=np.float64), + np.asarray(pose_metrics["sim3_translation"], dtype=np.float64), + ) + plot_path = _sequence_plot_path(output_root, seq_info.name) + _save_traj_plot_3d(plot_path, pred_xyz_aligned, gt_xyz) + pose_metrics.pop("sim3_scale", None) + pose_metrics["traj_3d_plot"] = plot_path + result["pose"] = pose_metrics + + video_dpt_metrics = _evaluate_video_dpt(seq_info, seq_dir, eval_cfg, data_cfg) + if video_dpt_metrics is not None: + result["has_gt"] = True + result["has_gt_depth"] = True + result["video_dpt"] = video_dpt_metrics + + gt_cloud = _load_gt_pointcloud(seq_info, seq_dir, gt_extri, gt_intri, eval_cfg) + pointcloud_metrics = _evaluate_pointclouds( + seq_info, seq_dir, eval_cfg, pose_align, gt_cloud + ) + if pointcloud_metrics is not None: + result["has_gt"] = True + result["has_gt_depth"] = True + result["pointcloud"] = pointcloud_metrics + + if not result["has_gt"]: + result["skipped"] = "missing_gt" + + return result + + +def _mean_metric(sequence_results, group_name, metric_name): + values = [] + for item in sequence_results: + group = item + for key in group_name.split("."): + if not isinstance(group, dict): + group = None + break + group = group.get(key) + if not isinstance(group, dict): + continue + if metric_name in group: + values.append(float(group[metric_name])) + if not values: + return None + return float(np.mean(values)) + + +def evaluate_predictions_cfg(cfg): + data_cfg = dict(cfg.get("data", {})) + data_cfg["format"] = "generalizable" + output_cfg = cfg.get("output", {}) + eval_cfg = cfg.get("evaluation", {}) + output_root = output_cfg.get("root", "outputs") + _ensure_dir(output_root) + + loader = LongStreamDataLoader(data_cfg) + sequence_results = [] + for seq_info in loader.iter_sequence_infos(): + print(f"[longstream] eval {seq_info.name}: start", flush=True) + metrics = evaluate_sequence(seq_info, output_root, eval_cfg, data_cfg) + sequence_results.append(metrics) + metrics_path = _sequence_metrics_path(output_root, seq_info.name) + _ensure_dir(os.path.dirname(metrics_path)) + with open(metrics_path, "w") as f: + json.dump(metrics, f, indent=2) + print(f"[longstream] eval {seq_info.name}: wrote {metrics_path}", flush=True) + + summary = { + "num_sequences": len(sequence_results), + "num_sequences_with_gt": sum(1 for x in sequence_results if x.get("has_gt")), + "num_sequences_with_pose_gt": sum( + 1 for x in sequence_results if x.get("has_gt_pose") + ), + "num_sequences_with_depth_gt": sum( + 1 for x in sequence_results if x.get("has_gt_depth") + ), + "ate_mean": _mean_metric(sequence_results, "pose", "ate_mean"), + "ate_rmse_mean": _mean_metric(sequence_results, "pose", "ate_rmse"), + "video_dpt_abs_rel_mean": _mean_metric( + sequence_results, "video_dpt", "abs_rel" + ), + "video_dpt_rel_delta_mean": _mean_metric( + sequence_results, "video_dpt", "rel_delta" + ), + "point_head_cd_mean": _mean_metric( + sequence_results, "pointcloud.point_head", "cd" + ), + "point_head_f1_mean": _mean_metric( + sequence_results, "pointcloud.point_head", "f1" + ), + "dpt_unproj_cd_mean": _mean_metric( + sequence_results, "pointcloud.dpt_unproj", "cd" + ), + "dpt_unproj_f1_mean": _mean_metric( + sequence_results, "pointcloud.dpt_unproj", "f1" + ), + "sequences": sequence_results, + } + + summary_path = os.path.join(output_root, "summary.json") + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + print(f"[longstream] eval: wrote {summary_path}", flush=True) + return summary diff --git a/longstream/eval/io.py b/longstream/eval/io.py new file mode 100644 index 0000000000000000000000000000000000000000..be84fb38cc928e26a1f52829261d346f732c2567 --- /dev/null +++ b/longstream/eval/io.py @@ -0,0 +1,156 @@ +import os + +os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1") + +import cv2 +import numpy as np + + +def frame_stems(image_paths): + stems = [os.path.splitext(os.path.basename(p))[0] for p in image_paths] + if len(set(stems)) == len(stems): + return stems + parents = [os.path.basename(os.path.dirname(p)) for p in image_paths] + if len(set(parents)) == len(parents): + return parents + return stems + + +def read_pred_w2c_txt(path): + frames = [] + poses = [] + if not os.path.exists(path): + return frames, poses + with open(path, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + vals = [float(x) for x in line.split()] + if len(vals) != 13: + continue + frame = int(vals[0]) + mat = np.eye(4, dtype=np.float64) + mat[:3, :3] = np.asarray(vals[1:10], dtype=np.float64).reshape(3, 3) + mat[:3, 3] = np.asarray(vals[10:13], dtype=np.float64) + frames.append(frame) + poses.append(mat) + return frames, poses + + +def read_opencv_camera_yml(extri_path, intri_path=None): + if not os.path.exists(extri_path): + return {}, {}, {} + + fs_extri = cv2.FileStorage(extri_path, cv2.FILE_STORAGE_READ) + names_node = fs_extri.getNode("names") + names = [] + for i in range(names_node.size()): + names.append(names_node.at(i).string()) + + extri = {} + for name in names: + rot = fs_extri.getNode(f"Rot_{name}").mat() + t = fs_extri.getNode(f"T_{name}").mat() + if rot is None or t is None: + continue + mat = np.eye(4, dtype=np.float64) + mat[:3, :3] = np.asarray(rot, dtype=np.float64) + mat[:3, 3] = np.asarray(t, dtype=np.float64).reshape(3) + extri[name] = mat + fs_extri.release() + + intri = {} + image_sizes = {} + if intri_path is not None and os.path.exists(intri_path): + fs_intri = cv2.FileStorage(intri_path, cv2.FILE_STORAGE_READ) + for name in names: + K = fs_intri.getNode(f"K_{name}").mat() + if K is None: + continue + intri[name] = np.asarray(K, dtype=np.float64) + h_node = fs_intri.getNode(f"H_{name}") + w_node = fs_intri.getNode(f"W_{name}") + if not h_node.empty() and not w_node.empty(): + image_sizes[name] = (int(h_node.real()), int(w_node.real())) + fs_intri.release() + + return extri, intri, image_sizes + + +def read_depth(path): + depth = cv2.imread(path, cv2.IMREAD_ANYDEPTH) + if depth is None: + raise FileNotFoundError(path) + return depth.astype(np.float32) + + +def read_ply_xyz(path): + if not os.path.exists(path): + raise FileNotFoundError(path) + + header = [] + with open(path, "rb") as f: + while True: + line = f.readline() + if not line: + raise ValueError(f"Invalid PLY header: {path}") + text = line.decode("ascii").strip() + header.append(text) + if text == "end_header": + break + + if "format binary_little_endian 1.0" not in header: + raise ValueError(f"Unsupported PLY format: {path}") + + vertex_count = None + property_specs = [] + in_vertex_block = False + for line in header: + if line.startswith("element vertex "): + vertex_count = int(line.split()[-1]) + in_vertex_block = True + continue + if line.startswith("element ") and not line.startswith("element vertex "): + in_vertex_block = False + if in_vertex_block and line.startswith("property "): + _, dtype_name, prop_name = line.split() + property_specs.append((dtype_name, prop_name)) + + if vertex_count is None: + raise ValueError(f"Missing vertex count in PLY: {path}") + + dtype_map = { + "float": " 0: + last_layer_features = aggregated_tokens_list[-1] + + scale_token_idx = patch_start_idx - 1 + scale_token_output_features = last_layer_features[ + :, :, scale_token_idx, : + ] + + scale_token_output_features = scale_token_output_features.mean(dim=1) + + scale_logits = self.scale_head(scale_token_output_features).squeeze(-1) + + predicted_scale_factor = torch.exp(scale_logits) + + predictions["predicted_scale_factor"] = predicted_scale_factor + predictions["scale_token_features"] = scale_token_output_features + + if self.enable_camera_head and self.camera_head is not None: + if camera_head_kv_cache_list is not None: + pose_enc_list, camera_head_kv_cache_list = self.camera_head( + aggregated_tokens_list, + mode=mode, + kv_cache_list=camera_head_kv_cache_list, + ) + else: + pose_enc_list = self.camera_head(aggregated_tokens_list, mode=mode) + + final_pose_enc = pose_enc_list[-1] + if self.enable_scale_token and predicted_scale_factor is not None: + scale = predicted_scale_factor.view(-1, 1, 1) + + scaled_t = final_pose_enc[..., :3] * scale + scaled_pose_enc = torch.cat([scaled_t, final_pose_enc[..., 3:]], dim=-1) + predictions["pose_enc"] = scaled_pose_enc + else: + predictions["pose_enc"] = final_pose_enc + + if self.training: + + if self.enable_scale_token and predicted_scale_factor is not None: + scale = predicted_scale_factor.view(-1, 1, 1) + scaled_pose_enc_list = [] + for pose_enc in pose_enc_list: + + scaled_t = pose_enc[..., :3] * scale + scaled_pose_enc = torch.cat( + [scaled_t, pose_enc[..., 3:]], dim=-1 + ) + scaled_pose_enc_list.append(scaled_pose_enc) + predictions["pose_enc_list"] = scaled_pose_enc_list + else: + predictions["pose_enc_list"] = pose_enc_list + + if self.rel_pose_head is not None and rel_pose_inputs is not None: + + rel_kwargs = dict( + aggregated_tokens_list=aggregated_tokens_list, + keyframe_indices=rel_pose_inputs.get("keyframe_indices"), + is_keyframe=rel_pose_inputs.get("is_keyframe", is_keyframe), + num_iterations=rel_pose_inputs.get("num_iterations", 4), + mode=mode, + kv_cache_list=rel_pose_inputs.get("kv_cache_list"), + ) + + rel_kwargs = {k: v for k, v in rel_kwargs.items() if v is not None} + + rel_result = self.rel_pose_head(**rel_kwargs) + + if isinstance(rel_result, dict): + + pose_enc = rel_result["pose_enc"] + if pose_enc.dtype != torch.float32: + pose_enc = pose_enc.float() + + if self.enable_scale_token and predicted_scale_factor is not None: + scale = predicted_scale_factor.view(-1, 1, 1) + + scaled_t = pose_enc[..., :3] * scale + scaled_rel_pose_enc = torch.cat( + [scaled_t, pose_enc[..., 3:]], dim=-1 + ) + predictions["rel_pose_enc"] = scaled_rel_pose_enc + + if "pose_enc_list" in rel_result: + scaled_pose_enc_list = [] + for iter_pose in rel_result["pose_enc_list"]: + scaled_t = iter_pose[..., :3] * scale + scaled_iter_pose = torch.cat( + [scaled_t, iter_pose[..., 3:]], dim=-1 + ) + scaled_pose_enc_list.append(scaled_iter_pose) + predictions["rel_pose_enc_list"] = scaled_pose_enc_list + else: + predictions["rel_pose_enc"] = pose_enc + + if "pose_enc_list" in rel_result: + predictions["rel_pose_enc_list"] = rel_result["pose_enc_list"] + + predictions["is_keyframe"] = rel_result.get("is_keyframe") + predictions["keyframe_indices"] = rel_result.get("keyframe_indices") + + if "global_scale" in rel_result: + predictions["global_scale"] = rel_result["global_scale"] + + if "kv_cache_list" in rel_result: + predictions["rel_pose_kv_cache_list"] = rel_result["kv_cache_list"] + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + + if self.enable_scale_token and predicted_scale_factor is not None: + scale = predicted_scale_factor.view(-1, 1, 1, 1, 1) + predictions["world_points"] = pts3d * scale + else: + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx + ) + + if self.enable_scale_token and predicted_scale_factor is not None: + scale = predicted_scale_factor.view(-1, 1, 1, 1, 1) + predictions["depth"] = depth * scale + else: + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if aggregator_kv_cache_list is not None: + predictions["aggregator_kv_cache_list"] = aggregator_kv_cache_list + + if camera_head_kv_cache_list is not None: + predictions["camera_head_kv_cache_list"] = camera_head_kv_cache_list + + if not self.training: + predictions["images"] = images + + return predictions diff --git a/longstream/streaming/__init__.py b/longstream/streaming/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/longstream/streaming/keyframe_selector.py b/longstream/streaming/keyframe_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..a9327306702a48994ef591859122ba1dbef3a8e9 --- /dev/null +++ b/longstream/streaming/keyframe_selector.py @@ -0,0 +1,80 @@ +import random +import torch +from typing import Optional, Tuple + + +class KeyframeSelector: + def __init__( + self, + min_interval: int = 8, + max_interval: int = 8, + force_first: bool = True, + motion_threshold: Optional[float] = None, + mode: str = "fixed", + ): + self.min_interval = int(min_interval) + self.max_interval = int(max_interval) + self.force_first = bool(force_first) + self.motion_threshold = motion_threshold + self.mode = mode + + def select_keyframes( + self, + sequence_length: int, + batch_size: int = 1, + device: Optional[torch.device] = None, + poses: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or torch.device("cpu") + is_keyframe = torch.zeros( + batch_size, sequence_length, dtype=torch.bool, device=device + ) + keyframe_indices = torch.zeros( + batch_size, sequence_length, dtype=torch.long, device=device + ) + + for b in range(batch_size): + last_keyframe_idx = 0 + next_keyframe_target = None + + if self.force_first or sequence_length == 1: + is_keyframe[b, 0] = True + keyframe_indices[b, 0] = 0 + if self.mode == "random": + interval = random.randint(self.min_interval, self.max_interval) + next_keyframe_target = interval + + for s in range(1, sequence_length): + keyframe_indices[b, s] = last_keyframe_idx + frames_since_last = s - last_keyframe_idx + + if self.mode == "random" and next_keyframe_target is not None: + if s >= next_keyframe_target: + is_keyframe[b, s] = True + last_keyframe_idx = s + interval = random.randint(self.min_interval, self.max_interval) + next_keyframe_target = s + interval + elif frames_since_last >= self.max_interval: + is_keyframe[b, s] = True + last_keyframe_idx = s + if self.mode == "random": + interval = random.randint(self.min_interval, self.max_interval) + next_keyframe_target = s + interval + elif ( + frames_since_last >= self.min_interval + and poses is not None + and self.motion_threshold is not None + ): + motion = torch.norm( + poses[b, s, :3] - poses[b, last_keyframe_idx, :3] + ).item() + if motion > self.motion_threshold: + is_keyframe[b, s] = True + last_keyframe_idx = s + if self.mode == "random": + interval = random.randint( + self.min_interval, self.max_interval + ) + next_keyframe_target = s + interval + + return is_keyframe, keyframe_indices diff --git a/longstream/streaming/refresh.py b/longstream/streaming/refresh.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ef51ed5da2c5ba025ef9238a505cb135592079 --- /dev/null +++ b/longstream/streaming/refresh.py @@ -0,0 +1,217 @@ +import torch +from typing import Dict, Any, List + +from longstream.streaming.stream_session import StreamSession + +_SEQUENCE_OUTPUT_KEYS = { + "pose_enc", + "rel_pose_enc", + "world_points", + "world_points_conf", + "depth", + "depth_conf", +} +_SCALAR_OUTPUT_KEYS = { + "predicted_scale_factor", + "global_scale", +} + + +def _refresh_intervals(refresh: int) -> int: + refresh = int(refresh) + if refresh < 2: + raise ValueError("refresh must be >= 2") + return refresh - 1 + + +def _model_device(model) -> torch.device: + return next(model.parameters()).device + + +def _move_scalar_to_cpu(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu() + return value + + +def _append_batch_output( + stitched_tensors: Dict[str, List[torch.Tensor]], + stitched_scalars: Dict[str, Any], + output: Dict[str, Any], + actual_frames: int, + slice_start: int, +) -> None: + for key in _SEQUENCE_OUTPUT_KEYS: + value = output.get(key) + if not isinstance(value, torch.Tensor): + continue + if value.ndim < 2 or value.shape[1] != actual_frames: + continue + stitched_tensors.setdefault(key, []).append( + value[:, slice_start:].detach().cpu() + ) + + for key in _SCALAR_OUTPUT_KEYS: + if key in output: + stitched_scalars[key] = _move_scalar_to_cpu(output[key]) + + +def _finalize_stitched_batches( + stitched_tensors: Dict[str, List[torch.Tensor]], + stitched_scalars: Dict[str, Any], +) -> Dict[str, Any]: + stitched_output: Dict[str, Any] = {} + for key, chunks in stitched_tensors.items(): + if not chunks: + continue + stitched_output[key] = ( + chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1) + ) + stitched_output.update(stitched_scalars) + return stitched_output + + +def run_batch_refresh( + model, + images, + is_keyframe, + keyframe_indices, + mode: str, + keyframe_stride: int, + refresh: int, + rel_pose_cfg, +): + B, S = images.shape[:2] + device = _model_device(model) + refresh_intervals = _refresh_intervals(refresh) + frames_per_batch = refresh_intervals * keyframe_stride + 1 + step_frames = refresh_intervals * keyframe_stride + + stitched_tensors: Dict[str, List[torch.Tensor]] = {} + stitched_scalars: Dict[str, Any] = {} + num_batches = (S + step_frames - 1) // step_frames + for batch_idx in range(num_batches): + start_frame = batch_idx * step_frames + end_frame = min(start_frame + frames_per_batch, S) + batch_images = images[:, start_frame:end_frame].to(device, non_blocking=True) + batch_is_keyframe = ( + is_keyframe[:, start_frame:end_frame].clone() + if is_keyframe is not None + else None + ) + batch_keyframe_indices = ( + keyframe_indices[:, start_frame:end_frame].clone() + if keyframe_indices is not None + else None + ) + + if batch_idx > 0 and batch_is_keyframe is not None: + batch_is_keyframe[:, 0] = True + if batch_keyframe_indices is not None: + batch_keyframe_indices[:, 0] = start_frame + + if batch_keyframe_indices is not None: + batch_keyframe_indices = batch_keyframe_indices - start_frame + batch_keyframe_indices = torch.clamp( + batch_keyframe_indices, 0, end_frame - start_frame - 1 + ) + + batch_rel_pose_inputs = None + if rel_pose_cfg is not None and batch_is_keyframe is not None: + batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True) + if batch_keyframe_indices is not None: + batch_keyframe_indices = batch_keyframe_indices.to( + device, non_blocking=True + ) + batch_rel_pose_inputs = { + "is_keyframe": batch_is_keyframe, + "keyframe_indices": batch_keyframe_indices, + "num_iterations": rel_pose_cfg.get("num_iterations", 4), + } + elif batch_is_keyframe is not None: + batch_is_keyframe = batch_is_keyframe.to(device, non_blocking=True) + + batch_output = model( + images=batch_images, + mode=mode, + rel_pose_inputs=batch_rel_pose_inputs, + is_keyframe=batch_is_keyframe, + ) + + _append_batch_output( + stitched_tensors, + stitched_scalars, + batch_output, + actual_frames=end_frame - start_frame, + slice_start=0 if batch_idx == 0 else 1, + ) + del batch_output + del batch_images + del batch_is_keyframe + del batch_keyframe_indices + + return _finalize_stitched_batches(stitched_tensors, stitched_scalars) + + +def run_streaming_refresh( + model, + images, + is_keyframe, + keyframe_indices, + mode: str, + window_size: int, + refresh: int, + rel_pose_cfg, +): + B, S = images.shape[:2] + device = _model_device(model) + refresh_intervals = _refresh_intervals(refresh) + session = StreamSession(model, mode=mode, window_size=window_size) + keyframe_count = 0 + segment_start = 0 + for s in range(S): + frame_images = images[:, s : s + 1].to(device, non_blocking=True) + is_keyframe_s = ( + is_keyframe[:, s : s + 1].to(device, non_blocking=True) + if is_keyframe is not None + else None + ) + if keyframe_indices is not None: + keyframe_indices_s = keyframe_indices[:, s : s + 1].clone() - segment_start + keyframe_indices_s = torch.clamp(keyframe_indices_s, min=0) + keyframe_indices_s = keyframe_indices_s.to(device, non_blocking=True) + else: + keyframe_indices_s = None + session.forward_stream( + frame_images, + is_keyframe=is_keyframe_s, + keyframe_indices=keyframe_indices_s, + record=True, + ) + if is_keyframe_s is None or not bool(is_keyframe_s.item()) or s <= 0: + del frame_images + if is_keyframe_s is not None: + del is_keyframe_s + if keyframe_indices_s is not None: + del keyframe_indices_s + continue + keyframe_count += 1 + if keyframe_count % refresh_intervals == 0: + session.clear_cache_only() + segment_start = s + if keyframe_indices_s is not None: + keyframe_indices_self = torch.zeros_like(keyframe_indices_s) + else: + keyframe_indices_self = None + session.forward_stream( + frame_images, + is_keyframe=is_keyframe_s, + keyframe_indices=keyframe_indices_self, + record=False, + ) + del frame_images + if is_keyframe_s is not None: + del is_keyframe_s + if keyframe_indices_s is not None: + del keyframe_indices_s + return session.get_all_predictions() diff --git a/longstream/streaming/stream_session.py b/longstream/streaming/stream_session.py new file mode 100644 index 0000000000000000000000000000000000000000..2952518d626f7dfbba2d4184c8b7aa807738c228 --- /dev/null +++ b/longstream/streaming/stream_session.py @@ -0,0 +1,294 @@ +import torch + + +class StreamSession: + def __init__( + self, + model, + mode: str, + window_size: int = 5, + keep_first_frame_anchor: bool = True, + ): + self.model = model + self.core_model = getattr(model, "longstream", model) + self.mode = mode + self.window_size = window_size + self.keep_first_frame_anchor = keep_first_frame_anchor + + if self.mode not in ["causal", "window"]: + raise ValueError(f"Unsupported attention mode: {self.mode}") + + self.aggregator_kv_cache_depth = self.core_model.aggregator.depth + self.use_camera_head = self.core_model.camera_head is not None + if self.use_camera_head: + self.camera_head_kv_cache_depth = self.core_model.camera_head.trunk_depth + self.camera_head_iterations = 4 + else: + self.camera_head_kv_cache_depth = 0 + self.camera_head_iterations = 0 + + self.use_rel_pose_head = ( + hasattr(self.core_model, "rel_pose_head") + and self.core_model.rel_pose_head is not None + ) + if self.use_rel_pose_head: + self.rel_pose_head_trunk_depth = self.core_model.rel_pose_head.trunk_depth + self.rel_pose_head_iterations = 4 + + self.clear() + + def _clear_predictions(self): + self.sequence_predictions = {} + self.scalar_predictions = {} + + def _update_predictions(self, predictions): + sequence_keys = [ + "pose_enc", + "rel_pose_enc", + "world_points", + "world_points_conf", + "depth", + "depth_conf", + ] + scalar_keys = ["predicted_scale_factor", "global_scale"] + + for k in sequence_keys: + if k in predictions: + self.sequence_predictions.setdefault(k, []).append( + predictions[k].detach().cpu() + ) + + for k in scalar_keys: + if k in predictions: + value = predictions[k] + self.scalar_predictions[k] = ( + value.detach().cpu() if isinstance(value, torch.Tensor) else value + ) + + def _clear_cache(self): + self.aggregator_kv_cache_list = [ + [None, None] for _ in range(self.aggregator_kv_cache_depth) + ] + if self.use_camera_head: + self.camera_head_kv_cache_list = [ + [[None, None] for _ in range(self.camera_head_kv_cache_depth)] + for _ in range(self.camera_head_iterations) + ] + else: + self.camera_head_kv_cache_list = None + if self.use_rel_pose_head: + self.rel_pose_kv_cache_list = [ + [[None, None] for _ in range(self.rel_pose_head_trunk_depth)] + for _ in range(self.rel_pose_head_iterations) + ] + else: + self.rel_pose_kv_cache_list = None + + def _update_cache( + self, aggregator_kv_cache_list, camera_head_kv_cache_list, frame_hw + ): + if self.mode == "causal": + self.aggregator_kv_cache_list = aggregator_kv_cache_list + if self.use_camera_head: + self.camera_head_kv_cache_list = camera_head_kv_cache_list + return + + if self.mode == "window": + h, w = frame_hw + P = ( + h + * w + // self.core_model.aggregator.patch_size + // self.core_model.aggregator.patch_size + + self.core_model.aggregator.patch_start_idx + ) + + for k in range(2): + for i in range(self.aggregator_kv_cache_depth): + cache_size = aggregator_kv_cache_list[i][k].size(2) + if self.keep_first_frame_anchor: + if cache_size <= P: + self.aggregator_kv_cache_list[i][ + k + ] = aggregator_kv_cache_list[i][k].contiguous() + elif cache_size <= self.window_size * P: + self.aggregator_kv_cache_list[i][ + k + ] = aggregator_kv_cache_list[i][k].contiguous() + else: + anchor = aggregator_kv_cache_list[i][k][:, :, :P] + recent_start = cache_size - (self.window_size - 1) * P + recent = aggregator_kv_cache_list[i][k][:, :, recent_start:] + self.aggregator_kv_cache_list[i][k] = torch.cat( + [anchor, recent], dim=2 + ).contiguous() + else: + start_idx = max(0, cache_size - self.window_size * P) + self.aggregator_kv_cache_list[i][k] = aggregator_kv_cache_list[ + i + ][k][:, :, start_idx:].contiguous() + + if camera_head_kv_cache_list is not None: + for k in range(2): + for i in range(self.camera_head_iterations): + for j in range(self.camera_head_kv_cache_depth): + cache_size = camera_head_kv_cache_list[i][j][k].size(2) + if self.keep_first_frame_anchor: + if cache_size <= 1: + self.camera_head_kv_cache_list[i][j][ + k + ] = camera_head_kv_cache_list[i][j][k].contiguous() + elif cache_size <= self.window_size: + self.camera_head_kv_cache_list[i][j][ + k + ] = camera_head_kv_cache_list[i][j][k].contiguous() + else: + anchor = camera_head_kv_cache_list[i][j][k][ + :, :, :1 + ] + recent_start = cache_size - (self.window_size - 1) + recent = camera_head_kv_cache_list[i][j][k][ + :, :, recent_start: + ] + self.camera_head_kv_cache_list[i][j][k] = torch.cat( + [anchor, recent], dim=2 + ).contiguous() + else: + start_idx = max(0, cache_size - self.window_size) + self.camera_head_kv_cache_list[i][j][ + k + ] = camera_head_kv_cache_list[i][j][k][ + :, :, start_idx: + ].contiguous() + return + + raise ValueError(f"Unsupported attention mode: {self.mode}") + + def _get_cache(self): + return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list + + def get_all_predictions(self): + predictions = {} + for key, chunks in self.sequence_predictions.items(): + if not chunks: + continue + predictions[key] = ( + chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=1) + ) + predictions.update(self.scalar_predictions) + return predictions + + def get_last_prediction(self): + last_predictions = {} + keys_to_extract = [ + "pose_enc", + "rel_pose_enc", + "world_points", + "world_points_conf", + "depth", + "depth_conf", + "predicted_scale_factor", + ] + for k in keys_to_extract: + if k in self.sequence_predictions and self.sequence_predictions[k]: + last_predictions[k] = self.sequence_predictions[k][-1][:, -1:] + elif k in self.scalar_predictions: + last_predictions[k] = self.scalar_predictions[k] + return last_predictions + + def clear(self): + self._clear_predictions() + self._clear_cache() + if self.use_rel_pose_head: + if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"): + self.core_model.rel_pose_head._keyframe_tokens_cache = {} + if hasattr(self.core_model.rel_pose_head, "_current_frame_id"): + self.core_model.rel_pose_head._current_frame_id = 0 + if hasattr(self.core_model.rel_pose_head, "_frame_info"): + self.core_model.rel_pose_head._frame_info = [] + + def clear_cache_only(self): + self._clear_cache() + if self.use_rel_pose_head: + if hasattr(self.core_model.rel_pose_head, "_keyframe_tokens_cache"): + self.core_model.rel_pose_head._keyframe_tokens_cache = {} + if hasattr(self.core_model.rel_pose_head, "_current_frame_id"): + self.core_model.rel_pose_head._current_frame_id = 0 + if hasattr(self.core_model.rel_pose_head, "_frame_info"): + self.core_model.rel_pose_head._frame_info = [] + + def forward_stream( + self, images, is_keyframe=None, keyframe_indices=None, record: bool = True + ): + aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache() + + rel_pose_inputs = None + if ( + self.use_rel_pose_head + and is_keyframe is not None + and keyframe_indices is not None + ): + rel_pose_inputs = { + "is_keyframe": is_keyframe, + "keyframe_indices": keyframe_indices, + "kv_cache_list": self.rel_pose_kv_cache_list, + } + + outputs = self.model( + images=images, + mode=self.mode, + aggregator_kv_cache_list=aggregator_kv_cache_list, + camera_head_kv_cache_list=camera_head_kv_cache_list, + rel_pose_inputs=rel_pose_inputs, + is_keyframe=is_keyframe, + ) + + if record: + self._update_predictions(outputs) + + camera_head_kv_cache_list = outputs.get("camera_head_kv_cache_list", None) + depth_hw = ( + outputs["depth"].shape[2:4] if "depth" in outputs else images.shape[-2:] + ) + self._update_cache( + outputs["aggregator_kv_cache_list"], camera_head_kv_cache_list, depth_hw + ) + + if self.use_rel_pose_head and "rel_pose_kv_cache_list" in outputs: + rel_pose_kv_cache = outputs["rel_pose_kv_cache_list"] + if self.mode == "causal": + self.rel_pose_kv_cache_list = rel_pose_kv_cache + elif self.mode == "window": + for k in range(2): + for i in range(self.rel_pose_head_iterations): + for j in range(self.rel_pose_head_trunk_depth): + if rel_pose_kv_cache[i][j][k] is None: + continue + cache_len = rel_pose_kv_cache[i][j][k].size(2) + if self.keep_first_frame_anchor: + if cache_len <= 1: + self.rel_pose_kv_cache_list[i][j][ + k + ] = rel_pose_kv_cache[i][j][k].contiguous() + elif cache_len <= self.window_size: + self.rel_pose_kv_cache_list[i][j][ + k + ] = rel_pose_kv_cache[i][j][k].contiguous() + else: + anchor = rel_pose_kv_cache[i][j][k][:, :, :1] + recent_start = cache_len - (self.window_size - 1) + recent = rel_pose_kv_cache[i][j][k][ + :, :, recent_start: + ] + self.rel_pose_kv_cache_list[i][j][k] = torch.cat( + [anchor, recent], dim=2 + ).contiguous() + else: + start_idx = max(0, cache_len - self.window_size) + self.rel_pose_kv_cache_list[i][j][ + k + ] = rel_pose_kv_cache[i][j][k][ + :, :, start_idx: + ].contiguous() + + return outputs diff --git a/longstream/utils/__init__.py b/longstream/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/longstream/utils/camera.py b/longstream/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..825278374ce3d440c138148263004e6194a97c23 --- /dev/null +++ b/longstream/utils/camera.py @@ -0,0 +1,50 @@ +import torch +from longstream.utils.vendor.models.components.utils.rotation import ( + quat_to_mat, + mat_to_quat, +) + + +def compose_abs_from_rel( + rel_pose_enc: torch.Tensor, keyframe_indices: torch.Tensor +) -> torch.Tensor: + squeeze_batch = False + if rel_pose_enc.ndim == 2: + rel_pose_enc = rel_pose_enc.unsqueeze(0) + squeeze_batch = True + if keyframe_indices.ndim == 1: + keyframe_indices = keyframe_indices.unsqueeze(0) + if rel_pose_enc.ndim != 3 or keyframe_indices.ndim != 2: + raise ValueError( + f"Expected rel_pose_enc [B,S,D] or [S,D] and keyframe_indices [B,S] or [S], " + f"got {tuple(rel_pose_enc.shape)} and {tuple(keyframe_indices.shape)}" + ) + + B, S, _ = rel_pose_enc.shape + device = rel_pose_enc.device + dtype = rel_pose_enc.dtype + + rel_t = rel_pose_enc[..., :3] + rel_q = rel_pose_enc[..., 3:7] + rel_f = rel_pose_enc[..., 7:9] + rel_R = quat_to_mat(rel_q.reshape(-1, 4)).reshape(B, S, 3, 3) + + abs_R = torch.zeros(B, S, 3, 3, device=device, dtype=dtype) + abs_t = torch.zeros(B, S, 3, device=device, dtype=dtype) + abs_f = torch.zeros(B, S, 2, device=device, dtype=dtype) + + for b in range(B): + abs_R[b, 0] = rel_R[b, 0] + abs_t[b, 0] = rel_t[b, 0] + abs_f[b, 0] = rel_f[b, 0] + for s in range(1, S): + ref_idx = int(keyframe_indices[b, s].item()) + abs_R[b, s] = rel_R[b, s] @ abs_R[b, ref_idx] + abs_t[b, s] = rel_t[b, s] + rel_R[b, s] @ abs_t[b, ref_idx] + abs_f[b, s] = rel_f[b, s] + + abs_q = mat_to_quat(abs_R.reshape(-1, 3, 3)).reshape(B, S, 4) + abs_pose_enc = torch.cat([abs_t, abs_q, abs_f], dim=-1) + if squeeze_batch: + return abs_pose_enc[0] + return abs_pose_enc diff --git a/longstream/utils/depth.py b/longstream/utils/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4740a92f53e5a85dda8bff904e2f5af1eb299b51 --- /dev/null +++ b/longstream/utils/depth.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +import matplotlib.cm as cm + + +def colorize_depth(depth: torch.Tensor, cmap: str = "plasma") -> np.ndarray: + if torch.is_tensor(depth): + depth_np = depth.detach().cpu().numpy() + else: + depth_np = depth + d_min = np.nanmin(depth_np) + d_max = np.nanmax(depth_np) + if d_max - d_min < 1e-6: + d_max = d_min + 1e-6 + norm = (depth_np - d_min) / (d_max - d_min) + norm = np.clip(norm, 0.0, 1.0) + mapper = cm.get_cmap(cmap) + colored = mapper(norm)[..., :3] + return (colored * 255.0).astype(np.uint8) + + +def unproject_depth_to_points(depth: torch.Tensor, intri: torch.Tensor) -> torch.Tensor: + B, H, W = depth.shape + fx = intri[:, 0, 0].view(B, 1, 1) + fy = intri[:, 1, 1].view(B, 1, 1) + cx = intri[:, 0, 2].view(B, 1, 1) + cy = intri[:, 1, 2].view(B, 1, 1) + + ys = torch.arange(H, device=depth.device).view(1, H, 1).float() + xs = torch.arange(W, device=depth.device).view(1, 1, W).float() + + x = (xs - cx) * depth / fx + y = (ys - cy) * depth / fy + z = depth + pts = torch.stack([x, y, z], dim=-1) + return pts diff --git a/longstream/utils/hub.py b/longstream/utils/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..45036071c453350cf69641b1c4cc957855448b67 --- /dev/null +++ b/longstream/utils/hub.py @@ -0,0 +1,42 @@ +import os +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class HFSpec: + repo_id: str + filename: str + revision: Optional[str] = None + local_dir: str = "checkpoints" + + +def _is_nonempty_str(x) -> bool: + return isinstance(x, str) and len(x) > 0 + + +def resolve_checkpoint_path( + checkpoint: Optional[str], hf: Optional[dict] +) -> Optional[str]: + if _is_nonempty_str(checkpoint): + return checkpoint + if not isinstance(hf, dict): + return None + + repo_id = hf.get("repo_id") + filename = hf.get("filename") + revision = hf.get("revision", None) + local_dir = hf.get("local_dir", "checkpoints") + + if not _is_nonempty_str(repo_id) or not _is_nonempty_str(filename): + return None + + try: + from huggingface_hub import hf_hub_download + except Exception as e: + raise RuntimeError("huggingface_hub is required for auto-download") from e + + os.makedirs(local_dir, exist_ok=True) + return hf_hub_download( + repo_id=repo_id, filename=filename, revision=revision, local_dir=local_dir + ) diff --git a/longstream/utils/sky_mask.py b/longstream/utils/sky_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..6b5294866694f2769e398259b8a1bbf74481bbf2 --- /dev/null +++ b/longstream/utils/sky_mask.py @@ -0,0 +1,100 @@ +import os +import copy +import cv2 +import numpy as np +import shutil +import urllib.request + +try: + import onnxruntime +except Exception: + onnxruntime = None + +SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx" +SKYSEG_THRESHOLD = 0.5 + + +def run_skyseg(session, input_size, image): + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + input_name = session.get_inputs()[0].name + result_map = session.run(None, {input_name: x})[0] + return result_map[0, 0] + + +def _normalize_skyseg_output(result_map): + result_map = np.asarray(result_map, dtype=np.float32) + if result_map.size == 0: + return result_map + finite = np.isfinite(result_map) + if not np.any(finite): + return np.zeros_like(result_map, dtype=np.float32) + result_map = np.nan_to_num(result_map, nan=0.0, posinf=1.0, neginf=0.0) + max_value = float(result_map.max()) + min_value = float(result_map.min()) + if min_value >= 0.0 and max_value > 1.5: + result_map = result_map / 255.0 + return np.clip(result_map, 0.0, 1.0) + + +def sky_mask_filename(image_path): + parent = os.path.basename(os.path.dirname(image_path)) + name = os.path.basename(image_path) + if parent: + return f"{parent}__{name}" + return name + + +def segment_sky(image_path, session, mask_filename=None): + image = cv2.imread(image_path) + if image is None: + return None + result_map = run_skyseg(session, [320, 320], image) + result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0])) + result_map_original = _normalize_skyseg_output(result_map_original) + output_mask = np.zeros(result_map_original.shape, dtype=np.uint8) + output_mask[result_map_original < SKYSEG_THRESHOLD] = 255 + if mask_filename is not None: + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, output_mask) + return output_mask + + +def compute_sky_mask(image_paths, model_path: str, target_dir: str = None): + if onnxruntime is None: + return None + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True) + try: + print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True) + with urllib.request.urlopen(SKYSEG_URL) as src, open( + model_path, "wb" + ) as dst: + shutil.copyfileobj(src, dst) + except Exception as exc: + print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True) + return None + if not os.path.exists(model_path): + return None + session = onnxruntime.InferenceSession(model_path) + masks = [] + for image_path in image_paths: + mask_filepath = None + if target_dir is not None: + name = sky_mask_filename(image_path) + mask_filepath = os.path.join(target_dir, name) + if os.path.exists(mask_filepath): + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + sky_mask = segment_sky(image_path, session, mask_filepath) + else: + sky_mask = segment_sky(image_path, session, None) + masks.append(sky_mask) + return masks diff --git a/longstream/utils/vendor/__init__.py b/longstream/utils/vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/croco/LICENSE b/longstream/utils/vendor/croco/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c9342d78f441dccebe0fe4461ad6a791196ef484 --- /dev/null +++ b/longstream/utils/vendor/croco/LICENSE @@ -0,0 +1,52 @@ +CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. + +A summary of the CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/ + +The CC BY-NC-SA 4.0 license is located here: + https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode + + +SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +This software in this file incorporates parts of the following software available here: + +Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py +available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE + +MoCo v3: https://github.com/facebookresearch/moco-v3 +available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE + +DeiT: https://github.com/facebookresearch/deit +available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE + + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/facebookresearch/mae/blob/main/LICENSE + +Attribution-NonCommercial 4.0 International + +*************************** + +NOTICE WITH RESPECT TO THE FILE: models/blocks.py + +This software is being redistributed in a modifiled form. The original form is available here: + +https://github.com/rwightman/pytorch-image-models + +ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: + +https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ diff --git a/longstream/utils/vendor/croco/NOTICE b/longstream/utils/vendor/croco/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..2a44a6a89e9ace025db2b713442d91e47bfc4656 --- /dev/null +++ b/longstream/utils/vendor/croco/NOTICE @@ -0,0 +1,21 @@ +CroCo +Copyright 2022-present NAVER Corp. + +This project contains subcomponents with separate copyright notices and license terms. +Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. + +==== + +facebookresearch/mae +https://github.com/facebookresearch/mae + +Attribution-NonCommercial 4.0 International + +==== + +rwightman/pytorch-image-models +https://github.com/rwightman/pytorch-image-models + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ diff --git a/longstream/utils/vendor/croco/README.MD b/longstream/utils/vendor/croco/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..ecc8f8263b52a0ec3b826de0418e946a9783ed36 --- /dev/null +++ b/longstream/utils/vendor/croco/README.MD @@ -0,0 +1,124 @@ +# CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow + +[[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)] + +This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), refered to as CroCo v2: + +![image](assets/arch.jpg) + +```bibtex +@inproceedings{croco, + title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}}, + author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}}, + booktitle={{NeurIPS}}, + year={2022} +} + +@inproceedings{croco_v2, + title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}}, + author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me}, + booktitle={ICCV}, + year={2023} +} +``` + +## License + +The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information. +Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License. +Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license. + +## Preparation + +1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version. + +```bash +conda create -n croco python=3.7 cmake=3.14.0 +conda activate croco +conda install habitat-sim headless -c conda-forge -c aihabitat +conda install pytorch torchvision -c pytorch +conda install notebook ipykernel matplotlib +conda install ipywidgets widgetsnbextension +conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation + +``` + +2. Compile cuda kernels for RoPE + +CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels. +```bash +cd models/curope/ +python setup.py build_ext --inplace +cd ../../ +``` + +This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only. +You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation. + +In case you cannot provide, we also provide a slow pytorch version, which will be automatically loaded. + +3. Download pre-trained model + +We provide several pre-trained models: + +| modelname | pre-training data | pos. embed. | Encoder | Decoder | +|------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------| +| [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small | +| [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small | +| [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base | +| [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base | + +To download a specific model, i.e., the first one (`CroCo.pth`) +```bash +mkdir -p pretrained_models/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/ +``` + +## Reconstruction example + +Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`) +```bash +python demo.py +``` + +## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator + +First download the test scene from Habitat: +```bash +python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/ +``` + +Then, run the Notebook demo `interactive_demo.ipynb`. + +In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo. +![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg) + +## Pre-training + +### CroCo + +To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command: +``` +torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/ +``` + +Our CroCo pre-training was launched on a single server with 4 GPUs. +It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. +The first run can take a few minutes to start, to parse all available pre-training pairs. + +### CroCo v2 + +For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD). +Then, run the following command for the largest model (ViT-L encoder, Base decoder): +``` +torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/ +``` + +Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases. +The largest model should take around 12 days on A100. +Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case. + +## Stereo matching and Optical flow downstream tasks + +For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD). diff --git a/longstream/utils/vendor/croco/assets/arch.jpg b/longstream/utils/vendor/croco/assets/arch.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3f5b032729ddc58c06d890a0ebda1749276070c4 Binary files /dev/null and b/longstream/utils/vendor/croco/assets/arch.jpg differ diff --git a/longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb b/longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f6ef17447beedb4a858e96da5b480b454e69c1bd --- /dev/null +++ b/longstream/utils/vendor/croco/croco-stereo-flow-demo.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9bca0f41", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80653ef7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "4f033862", + "metadata": {}, + "source": [ + "First download the model(s) of your choice by running\n", + "```\n", + "bash stereoflow/download_model.sh crocostereo.pth\n", + "bash stereoflow/download_model.sh crocoflow.pth\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fb2e392", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "import matplotlib.pylab as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0e25d77", + "metadata": {}, + "outputs": [], + "source": [ + "from stereoflow.test import _load_model_and_criterion\n", + "from stereoflow.engine import tiled_pred\n", + "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n", + "from stereoflow.datasets_flow import flowToColor\n", + "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower" + ] + }, + { + "cell_type": "markdown", + "id": "86a921f5", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64e483cb", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0d04303", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dc14b5", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).squeeze(0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "583b9f16", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(vis_disparity(pred))\n", + "plt.axis('off')" + ] + }, + { + "cell_type": "markdown", + "id": "d2df5d70", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ee257a7", + "metadata": {}, + "outputs": [], + "source": [ + "image1 = np.asarray(Image.open(''))\n", + "image2 = np.asarray(Image.open(''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5edccf0", + "metadata": {}, + "outputs": [], + "source": [ + "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b19692c3", + "metadata": {}, + "outputs": [], + "source": [ + "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", + "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", + "with torch.inference_mode():\n", + " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", + "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26f79db3", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(flowToColor(pred))\n", + "plt.axis('off')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/longstream/utils/vendor/croco/datasets/__init__.py b/longstream/utils/vendor/croco/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/croco/datasets/crops/README.MD b/longstream/utils/vendor/croco/datasets/crops/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..71b97a61084536305bfca2cebabed89a16340e0a --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/crops/README.MD @@ -0,0 +1,104 @@ +## Generation of crops from the real datasets + +The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL. + +### Download the metadata of the crops to generate + +First, download the metadata and put them in `./data/`: +``` +mkdir -p data +cd data/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip +unzip crop_metadata.zip +rm crop_metadata.zip +cd .. +``` + +### Prepare the original datasets + +Second, download the original datasets in `./data/original_datasets/`. +``` +mkdir -p data/original_datasets +``` + +##### ARKitScenes + +Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`. +The resulting file structure should be like: +``` +./data/original_datasets/ARKitScenes/ +└───Training + └───40753679 + │ │ ultrawide + │ │ ... + └───40753686 + │ + ... +``` + +##### MegaDepth + +Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`. +The resulting file structure should be like: + +``` +./data/original_datasets/MegaDepth/ +└───0000 +│ └───images +│ │ │ 1000557903_87fa96b8a4_o.jpg +│ │ └ ... +│ └─── ... +└───0001 +│ │ +│ └ ... +└─── ... +``` + +##### 3DStreetView + +Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`. +The resulting file structure should be like: + +``` +./data/original_datasets/3DStreetView/ +└───dataset_aligned +│ └───0002 +│ │ │ 0000002_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +└───dataset_unaligned +│ └───0003 +│ │ │ 0000003_0000001_0000002_0000001.jpg +│ │ └ ... +│ └─── ... +``` + +##### IndoorVL + +Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture). + +``` +pip install kapture +mkdir -p ./data/original_datasets/IndoorVL +cd ./data/original_datasets/IndoorVL +kapture_download_dataset.py update +kapture_download_dataset.py install "HyundaiDepartmentStore_*" +kapture_download_dataset.py install "GangnamStation_*" +cd - +``` + +### Extract the crops + +Now, extract the crops for each of the dataset: +``` +for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL; +do + python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500; +done +``` + +##### Note for IndoorVL + +Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper. +To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively. +The impact on the performance is negligible. diff --git a/longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py b/longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py new file mode 100644 index 0000000000000000000000000000000000000000..1c689bd218803fe840d4b36da3861e44011772db --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/crops/extract_crops_from_images.py @@ -0,0 +1,175 @@ +import argparse +import functools +import math +import os +from multiprocessing import Pool + +from PIL import Image +from tqdm import tqdm + + +def arg_parser(): + parser = argparse.ArgumentParser( + "Generate cropped image pairs from image crop list" + ) + + parser.add_argument("--crops", type=str, required=True, help="crop file") + parser.add_argument("--root-dir", type=str, required=True, help="root directory") + parser.add_argument( + "--output-dir", type=str, required=True, help="output directory" + ) + parser.add_argument("--imsize", type=int, default=256, help="size of the crops") + parser.add_argument( + "--nthread", type=int, required=True, help="number of simultaneous threads" + ) + parser.add_argument( + "--max-subdir-levels", + type=int, + default=5, + help="maximum number of subdirectories", + ) + parser.add_argument( + "--ideal-number-pairs-in-dir", + type=int, + default=500, + help="number of pairs stored in a dir", + ) + return parser + + +def main(args): + listing_path = os.path.join(args.output_dir, "listing.txt") + + print(f"Loading list of crops ... ({args.nthread} threads)") + crops, num_crops_to_generate = load_crop_file(args.crops) + + print(f"Preparing jobs ({len(crops)} candidate image pairs)...") + num_levels = min( + math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), + args.max_subdir_levels, + ) + num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels)) + + jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) + del crops + + os.makedirs(args.output_dir, exist_ok=True) + mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map + call = functools.partial(save_image_crops, args) + + print(f"Generating cropped images to {args.output_dir} ...") + with open(listing_path, "w") as listing: + listing.write("# pair_path\n") + for results in tqdm(mmap(call, jobs), total=len(jobs)): + for path in results: + listing.write(f"{path}\n") + print("Finished writing listing to", listing_path) + + +def load_crop_file(path): + data = open(path).read().splitlines() + pairs = [] + num_crops_to_generate = 0 + for line in tqdm(data): + if line.startswith("#"): + continue + line = line.split(", ") + if len(line) < 8: + img1, img2, rotation = line + pairs.append((img1, img2, int(rotation), [])) + else: + l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) + rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) + pairs[-1][-1].append((rect1, rect2)) + num_crops_to_generate += 1 + return pairs, num_crops_to_generate + + +def prepare_jobs(pairs, num_levels, num_pairs_in_dir): + jobs = [] + powers = [num_pairs_in_dir ** level for level in reversed(range(num_levels))] + + def get_path(idx): + idx_array = [] + d = idx + for level in range(num_levels - 1): + idx_array.append(idx // powers[level]) + idx = idx % powers[level] + idx_array.append(d) + return "/".join(map(lambda x: hex(x)[2:], idx_array)) + + idx = 0 + for pair_data in tqdm(pairs): + img1, img2, rotation, crops = pair_data + if -60 <= rotation and rotation <= 60: + rotation = 0 + paths = [get_path(idx + k) for k in range(len(crops))] + idx += len(crops) + jobs.append(((img1, img2), rotation, crops, paths)) + return jobs + + +def load_image(path): + try: + return Image.open(path).convert("RGB") + except Exception as e: + print("skipping", path, e) + raise OSError() + + +def save_image_crops(args, data): + + img_pair, rot, crops, paths = data + try: + img1, img2 = [ + load_image(os.path.join(args.root_dir, impath)) for impath in img_pair + ] + except OSError as e: + return [] + + def area(sz): + return sz[0] * sz[1] + + tgt_size = (args.imsize, args.imsize) + + def prepare_crop(img, rect, rot=0): + + img = img.crop(rect) + + interp = ( + Image.Resampling.LANCZOS + if area(img.size) > 4 * area(tgt_size) + else Image.Resampling.BICUBIC + ) + img = img.resize(tgt_size, resample=interp) + + rot90 = (round(rot / 90) % 4) * 90 + if rot90 == 90: + img = img.transpose(Image.Transpose.ROTATE_90) + elif rot90 == 180: + img = img.transpose(Image.Transpose.ROTATE_180) + elif rot90 == 270: + img = img.transpose(Image.Transpose.ROTATE_270) + return img + + results = [] + for (rect1, rect2), path in zip(crops, paths): + crop1 = prepare_crop(img1, rect1) + crop2 = prepare_crop(img2, rect2, rot) + + fullpath1 = os.path.join(args.output_dir, path + "_1.jpg") + fullpath2 = os.path.join(args.output_dir, path + "_2.jpg") + os.makedirs(os.path.dirname(fullpath1), exist_ok=True) + + assert not os.path.isfile(fullpath1), fullpath1 + assert not os.path.isfile(fullpath2), fullpath2 + crop1.save(fullpath1) + crop2.save(fullpath2) + results.append(path) + + return results + + +if __name__ == "__main__": + args = arg_parser().parse_args() + main(args) diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/README.MD b/longstream/utils/vendor/croco/datasets/habitat_sim/README.MD new file mode 100644 index 0000000000000000000000000000000000000000..a505781ff9eb91bce7f1d189e848f8ba1c560940 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/README.MD @@ -0,0 +1,76 @@ +## Generation of synthetic image pairs using Habitat-Sim + +These instructions allow to generate pre-training pairs from the Habitat simulator. +As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent. + +### Download Habitat-Sim scenes +Download Habitat-Sim scenes: +- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md +- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets. +- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`. +``` +./data/ +└──habitat-sim-data/ + └──scene_datasets/ + ├──hm3d/ + ├──gibson/ + ├──habitat-test-scenes/ + ├──replica_cad_baked_lighting/ + ├──replica_cad/ + ├──ReplicaDataset/ + └──scannet/ +``` + +### Image pairs generation +We provide metadata to generate reproducible images pairs for pretraining and validation. +Experiments described in the paper used similar data, but whose generation was not reproducible at the time. + +Specifications: +- 256x256 resolution images, with 60 degrees field of view . +- Up to 1000 image pairs per scene. +- Number of scenes considered/number of images pairs per dataset: + - Scannet: 1097 scenes / 985 209 pairs + - HM3D: + - hm3d/train: 800 / 800k pairs + - hm3d/val: 100 scenes / 100k pairs + - hm3d/minival: 10 scenes / 10k pairs + - habitat-test-scenes: 3 scenes / 3k pairs + - replica_cad_baked_lighting: 13 scenes / 13k pairs + +- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes. + +Download metadata and extract it: +```bash +mkdir -p data/habitat_release_metadata/ +cd data/habitat_release_metadata/ +wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz +tar -xvf multiview_habitat_metadata.tar.gz +cd ../.. +# Location of the metadata +METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata" +``` + +Generate image pairs from metadata: +- The following command will print a list of commandlines to generate image pairs for each scene: +```bash +# Target output directory +PAIRS_DATASET_DIR="./data/habitat_release/" +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR +``` +- One can launch multiple of such commands in parallel e.g. using GNU Parallel: +```bash +python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16 +``` + +## Metadata generation + +Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible: +```bash +# Print commandlines to generate image pairs from the different scenes available. +PAIRS_DATASET_DIR=MY_CUSTOM_PATH +python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR + +# Once a dataset is generated, pack metadata files for reproducibility. +METADATA_DIR=MY_CUSTON_PATH +python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR +``` diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py b/longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd33000d933e3b44f422835db7046f8ad8bfa71 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata.py @@ -0,0 +1,121 @@ +""" +Script to generate image pairs for a given scene reproducing poses provided in a metadata file. +""" +import argparse +import json +import os + +import cv2 +import PIL.Image +import quaternion +from datasets.habitat_sim.multiview_habitat_sim_generator import ( + MultiviewHabitatSimGenerator, +) +from datasets.habitat_sim.paths import SCENES_DATASET +from tqdm import tqdm + + +def generate_multiview_images_from_metadata( + metadata_filename, + output_dir, + overload_params=dict(), + scene_datasets_paths=None, + exist_ok=False, +): + """ + Generate images from a metadata file for reproducibility purposes. + """ + + if scene_datasets_paths is not None: + scene_datasets_paths = dict( + sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True) + ) + + with open(metadata_filename, "r") as f: + input_metadata = json.load(f) + metadata = dict() + for key, value in input_metadata.items(): + + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + if scene_datasets_paths is not None: + for dataset_label, dataset_path in scene_datasets_paths.items(): + if value.startswith(dataset_label): + value = os.path.normpath( + os.path.join( + dataset_path, os.path.relpath(value, dataset_label) + ) + ) + break + metadata[key] = value + + for key, value in overload_params.items(): + metadata[key] = value + + generation_entries = dict( + [ + (key, value) + for key, value in metadata.items() + if not (key in ("multiviews", "output_dir", "generate_depth")) + ] + ) + generate_depth = metadata["generate_depth"] + + os.makedirs(output_dir, exist_ok=exist_ok) + + generator = MultiviewHabitatSimGenerator(**generation_entries) + + for idx_label, data in tqdm(metadata["multiviews"].items()): + positions = data["positions"] + orientations = data["orientations"] + n = len(positions) + for oidx in range(n): + observation = generator.render_viewpoint( + positions[oidx], quaternion.from_float_array(orientations[oidx]) + ) + observation_label = f"{oidx + 1}" + + img = PIL.Image.fromarray(observation["color"][:, :, :3]) + filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") + img.save(filename) + if generate_depth: + + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_depth.exr" + ) + cv2.imwrite( + filename, + observation["depth"], + [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF], + ) + + camera_params = dict( + [ + (key, observation[key].tolist()) + for key in ("camera_intrinsics", "R_cam2world", "t_cam2world") + ] + ) + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_camera_params.json" + ) + with open(filename, "w") as f: + json.dump(camera_params, f) + + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(metadata, f) + + generator.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--metadata_filename", required=True) + parser.add_argument("--output_dir", required=True) + args = parser.parse_args() + + generate_multiview_images_from_metadata( + metadata_filename=args.metadata_filename, + output_dir=args.output_dir, + scene_datasets_paths=SCENES_DATASET, + overload_params=dict(), + exist_ok=True, + ) diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ba13b6c62a25f803bc8d79891e4c727881a587 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_from_metadata_files.py @@ -0,0 +1,34 @@ +""" +Script generating commandlines to generate image pairs from metadata files. +""" +import argparse +import glob +import os + +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_dir", required=True) + parser.add_argument( + "--prefix", + default="", + help="Commanline prefix, useful e.g. to setup environment.", + ) + args = parser.parse_args() + + input_metadata_filenames = glob.iglob( + f"{args.input_dir}/**/metadata.json", recursive=True + ) + + for metadata_filename in tqdm(input_metadata_filenames): + output_dir = os.path.join( + args.output_dir, + os.path.relpath(os.path.dirname(metadata_filename), args.input_dir), + ) + + if os.path.exists(os.path.join(output_dir, "metadata.json")): + continue + commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" + print(commandline) diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/generate_multiview_images.py b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_multiview_images.py new file mode 100644 index 0000000000000000000000000000000000000000..02b78050f76e3bb0a6498c2cbac343656e8e4e05 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/generate_multiview_images.py @@ -0,0 +1,224 @@ +import argparse +import json +import os + +import cv2 +import numpy as np +import PIL.Image +import quaternion +from datasets.habitat_sim.multiview_habitat_sim_generator import ( + MultiviewHabitatSimGenerator, + NoNaviguableSpaceError, +) +from datasets.habitat_sim.paths import list_scenes_available +from tqdm import tqdm + + +def generate_multiview_images_for_scene( + scene_dataset_config_file, + scene, + navmesh, + output_dir, + views_count, + size, + exist_ok=False, + generate_depth=False, + **kwargs, +): + """ + Generate tuples of overlapping views for a given scene. + generate_depth: generate depth images and camera parameters. + """ + if os.path.exists(output_dir) and not exist_ok: + print(f"Scene {scene}: data already generated. Ignoring generation.") + return + try: + print(f"Scene {scene}: {size} multiview acquisitions to generate...") + os.makedirs(output_dir, exist_ok=exist_ok) + + metadata_filename = os.path.join(output_dir, "metadata.json") + + metadata_template = dict( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + generate_depth=generate_depth, + **kwargs, + ) + metadata_template["multiviews"] = dict() + + if os.path.exists(metadata_filename): + print("Metadata file already exists:", metadata_filename) + print("Loading already generated metadata file...") + with open(metadata_filename, "r") as f: + metadata = json.load(f) + + for key in metadata_template.keys(): + if key != "multiviews": + assert ( + metadata_template[key] == metadata[key] + ), f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}." + else: + print("No temporary file found. Starting generation from scratch...") + metadata = metadata_template + + starting_id = len(metadata["multiviews"]) + print(f"Starting generation from index {starting_id}/{size}...") + if starting_id >= size: + print("Generation already done.") + return + + generator = MultiviewHabitatSimGenerator( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + views_count=views_count, + size=size, + **kwargs, + ) + + for idx in tqdm(range(starting_id, size)): + + try: + data = generator[idx] + observations = data["observations"] + positions = data["positions"] + orientations = data["orientations"] + + idx_label = f"{idx:08}" + for oidx, observation in enumerate(observations): + observation_label = f"{oidx + 1}" + + img = PIL.Image.fromarray(observation["color"][:, :, :3]) + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}.jpeg" + ) + img.save(filename) + if generate_depth: + + filename = os.path.join( + output_dir, f"{idx_label}_{observation_label}_depth.exr" + ) + cv2.imwrite( + filename, + observation["depth"], + [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF], + ) + + camera_params = dict( + [ + (key, observation[key].tolist()) + for key in ( + "camera_intrinsics", + "R_cam2world", + "t_cam2world", + ) + ] + ) + filename = os.path.join( + output_dir, + f"{idx_label}_{observation_label}_camera_params.json", + ) + with open(filename, "w") as f: + json.dump(camera_params, f) + metadata["multiviews"][idx_label] = { + "positions": positions.tolist(), + "orientations": orientations.tolist(), + "covisibility_ratios": data["covisibility_ratios"].tolist(), + "valid_fractions": data["valid_fractions"].tolist(), + "pairwise_visibility_ratios": data[ + "pairwise_visibility_ratios" + ].tolist(), + } + except RecursionError: + print( + "Recursion error: unable to sample observations for this scene. We will stop there." + ) + break + + if idx % 10 == 0: + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + with open(metadata_filename, "w") as f: + json.dump(metadata, f) + + generator.close() + except NoNaviguableSpaceError: + pass + + +def create_commandline(scene_data, generate_depth, exist_ok=False): + """ + Create a commandline string to generate a scene. + """ + + def my_formatting(val): + if val is None or val == "": + return '""' + else: + return val + + commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)} + --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)} + --navmesh {my_formatting(scene_data.navmesh)} + --output_dir {my_formatting(scene_data.output_dir)} + --generate_depth {int(generate_depth)} + --exist_ok {int(exist_ok)} + """ + commandline = " ".join(commandline.split()) + return commandline + + +if __name__ == "__main__": + os.umask(2) + + parser = argparse.ArgumentParser( + description="""Example of use -- listing commands to generate data for scenes available: + > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands + """ + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--list_commands", action="store_true", help="list commandlines to run if true" + ) + parser.add_argument("--scene", type=str, default="") + parser.add_argument("--scene_dataset_config_file", type=str, default="") + parser.add_argument("--navmesh", type=str, default="") + + parser.add_argument("--generate_depth", type=int, default=1) + parser.add_argument("--exist_ok", type=int, default=0) + + kwargs = dict(resolution=(256, 256), hfov=60, views_count=2, size=1000) + + args = parser.parse_args() + generate_depth = bool(args.generate_depth) + exist_ok = bool(args.exist_ok) + + if args.list_commands: + + scenes_data = list_scenes_available(base_output_dir=args.output_dir) + + for scene_data in scenes_data: + print( + create_commandline( + scene_data, generate_depth=generate_depth, exist_ok=exist_ok + ) + ) + else: + if args.scene == "" or args.output_dir == "": + print("Missing scene or output dir argument!") + print(parser.format_help()) + else: + generate_multiview_images_for_scene( + scene=args.scene, + scene_dataset_config_file=args.scene_dataset_config_file, + navmesh=args.navmesh, + output_dir=args.output_dir, + exist_ok=exist_ok, + generate_depth=generate_depth, + **kwargs, + ) diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py b/longstream/utils/vendor/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..dc27fdfa36b912a6323ea5b1afe59ce66b061644 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py @@ -0,0 +1,479 @@ +import cv2 +import habitat_sim +import numpy as np +import quaternion +from sklearn.neighbors import NearestNeighbors + +R_OPENCV2HABITAT = np.stack( + (habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0 +) +R_HABITAT2OPENCV = R_OPENCV2HABITAT.T +DEG2RAD = np.pi / 180 + + +def compute_camera_intrinsics(height, width, hfov): + f = width / 2 / np.tan(hfov / 2 * np.pi / 180) + cu, cv = width / 2, height / 2 + return f, cu, cv + + +def compute_camera_pose_opencv_convention(camera_position, camera_orientation): + R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT + t_cam2world = np.asarray(camera_position) + return R_cam2world, t_cam2world + + +def compute_pointmap(depthmap, hfov): + """Compute a HxWx3 pointmap in camera frame from a HxW depth map.""" + height, width = depthmap.shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + + z_cam = depthmap + u, v = np.meshgrid(range(width), range(height)) + x_cam = (u - cu) / f * z_cam + y_cam = (v - cv) / f * z_cam + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1) + return X_cam + + +def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation): + """Return a 3D point cloud corresponding to valid pixels of the depth map""" + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention( + camera_position, camera_rotation + ) + + X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov) + valid_mask = X_cam[:, :, 2] != 0.0 + + X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()] + X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3) + return X_world + + +def compute_pointcloud_overlaps_scikit( + pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False +): + """ + Compute 'overlapping' metrics based on a distance threshold between two point clouds. + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud2) + distances, indices = nbrs.kneighbors(pointcloud1) + intersection1 = np.count_nonzero(distances.flatten() < distance_threshold) + + data = {"intersection1": intersection1, "size1": len(pointcloud1)} + if compute_symmetric: + nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud1) + distances, indices = nbrs.kneighbors(pointcloud2) + intersection2 = np.count_nonzero(distances.flatten() < distance_threshold) + data["intersection2"] = intersection2 + data["size2"] = len(pointcloud2) + + return data + + +def _append_camera_parameters(observation, hfov, camera_location, camera_rotation): + """ + Add camera parameters to the observation dictionnary produced by Habitat-Sim + In-place modifications. + """ + R_cam2world, t_cam2world = compute_camera_pose_opencv_convention( + camera_location, camera_rotation + ) + height, width = observation["depth"].shape + f, cu, cv = compute_camera_intrinsics(height, width, hfov) + K = np.asarray([[f, 0, cu], [0, f, cv], [0, 0, 1.0]]) + observation["camera_intrinsics"] = K + observation["t_cam2world"] = t_cam2world + observation["R_cam2world"] = R_cam2world + + +def look_at(eye, center, up, return_cam2world=True): + """ + Return camera pose looking at a given center point. + Analogous of gluLookAt function, using OpenCV camera convention. + """ + z = center - eye + z /= np.linalg.norm(z, axis=-1, keepdims=True) + y = -up + y = y - np.sum(y * z, axis=-1, keepdims=True) * z + y /= np.linalg.norm(y, axis=-1, keepdims=True) + x = np.cross(y, z, axis=-1) + + if return_cam2world: + R = np.stack((x, y, z), axis=-1) + t = eye + else: + + R = np.stack((x, y, z), axis=-2) + t = -np.einsum("...ij, ...j", R, eye) + return R, t + + +def look_at_for_habitat(eye, center, up, return_cam2world=True): + R, t = look_at(eye, center, up) + orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T) + return orientation, t + + +def generate_orientation_noise(pan_range, tilt_range, roll_range): + return ( + quaternion.from_rotation_vector( + np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP + ) + * quaternion.from_rotation_vector( + np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT + ) + * quaternion.from_rotation_vector( + np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT + ) + ) + + +class NoNaviguableSpaceError(RuntimeError): + def __init__(self, *args): + super().__init__(*args) + + +class MultiviewHabitatSimGenerator: + def __init__( + self, + scene, + navmesh, + scene_dataset_config_file, + resolution=(240, 320), + views_count=2, + hfov=60, + gpu_id=0, + size=10000, + minimum_covisibility=0.5, + transform=None, + ): + self.scene = scene + self.navmesh = navmesh + self.scene_dataset_config_file = scene_dataset_config_file + self.resolution = resolution + self.views_count = views_count + assert self.views_count >= 1 + self.hfov = hfov + self.gpu_id = gpu_id + self.size = size + self.transform = transform + + self.pan_range = (-3, 3) + self.tilt_range = (-10, 10) + self.roll_range = (-5, 5) + + self.height_range = (1.2, 1.8) + + self.random_steps_count = 5 + self.random_step_variance = 2.0 + + self.minimum_valid_fraction = 0.7 + + self.distance_threshold = 0.05 + + self.minimum_covisibility = minimum_covisibility + + self.max_attempts_count = 100 + + self.seed = None + self._lazy_initialization() + + def _lazy_initialization(self): + + if self.seed == None: + + np.random.seed() + self.seed = np.random.randint(2 ** 32 - 1) + sim_cfg = habitat_sim.SimulatorConfiguration() + sim_cfg.scene_id = self.scene + if ( + self.scene_dataset_config_file is not None + and self.scene_dataset_config_file != "" + ): + sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file + sim_cfg.random_seed = self.seed + sim_cfg.load_semantic_mesh = False + sim_cfg.gpu_device_id = self.gpu_id + + depth_sensor_spec = habitat_sim.CameraSensorSpec() + depth_sensor_spec.uuid = "depth" + depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH + depth_sensor_spec.resolution = self.resolution + depth_sensor_spec.hfov = self.hfov + depth_sensor_spec.position = [0.0, 0.0, 0] + depth_sensor_spec.orientation + + rgb_sensor_spec = habitat_sim.CameraSensorSpec() + rgb_sensor_spec.uuid = "color" + rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR + rgb_sensor_spec.resolution = self.resolution + rgb_sensor_spec.hfov = self.hfov + rgb_sensor_spec.position = [0.0, 0.0, 0] + agent_cfg = habitat_sim.agent.AgentConfiguration( + sensor_specifications=[rgb_sensor_spec, depth_sensor_spec] + ) + + cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg]) + self.sim = habitat_sim.Simulator(cfg) + if self.navmesh is not None and self.navmesh != "": + + self.sim.pathfinder.load_nav_mesh(self.navmesh) + + if not self.sim.pathfinder.is_loaded: + + navmesh_settings = habitat_sim.NavMeshSettings() + navmesh_settings.set_defaults() + self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True) + + if not self.sim.pathfinder.is_loaded: + raise NoNaviguableSpaceError( + f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})" + ) + + self.agent = self.sim.initialize_agent(agent_id=0) + + def close(self): + self.sim.close() + + def __del__(self): + self.sim.close() + + def __len__(self): + return self.size + + def sample_random_viewpoint(self): + """Sample a random viewpoint using the navmesh""" + nav_point = self.sim.pathfinder.get_random_navigable_point() + + viewpoint_height = np.random.uniform(*self.height_range) + viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP + viewpoint_orientation = quaternion.from_rotation_vector( + np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP + ) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range) + return viewpoint_position, viewpoint_orientation, nav_point + + def sample_other_random_viewpoint(self, observed_point, nav_point): + """Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point.""" + other_nav_point = nav_point + + walk_directions = self.random_step_variance * np.asarray([1, 0, 1]) + for i in range(self.random_steps_count): + temp = self.sim.pathfinder.snap_point( + other_nav_point + walk_directions * np.random.normal(size=3) + ) + + if not np.isnan(temp[0]): + other_nav_point = temp + + other_viewpoint_height = np.random.uniform(*self.height_range) + other_viewpoint_position = ( + other_nav_point + other_viewpoint_height * habitat_sim.geo.UP + ) + + rotation, position = look_at_for_habitat( + eye=other_viewpoint_position, + center=observed_point, + up=habitat_sim.geo.UP, + return_cam2world=True, + ) + rotation = rotation * generate_orientation_noise( + self.pan_range, self.tilt_range, self.roll_range + ) + return position, rotation, other_nav_point + + def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud): + """Check if a viewpoint is valid and overlaps significantly with a reference one.""" + + pixels_count = self.resolution[0] * self.resolution[1] + valid_fraction = len(other_pointcloud) / pixels_count + assert valid_fraction <= 1.0 and valid_fraction >= 0.0 + overlap = compute_pointcloud_overlaps_scikit( + ref_pointcloud, + other_pointcloud, + self.distance_threshold, + compute_symmetric=True, + ) + covisibility = min( + overlap["intersection1"] / pixels_count, + overlap["intersection2"] / pixels_count, + ) + is_valid = (valid_fraction >= self.minimum_valid_fraction) and ( + covisibility >= self.minimum_covisibility + ) + return is_valid, valid_fraction, covisibility + + def is_other_viewpoint_overlapping( + self, ref_pointcloud, observation, position, rotation + ): + """Check if a viewpoint is valid and overlaps significantly with a reference one.""" + + other_pointcloud = compute_pointcloud( + observation["depth"], self.hfov, position, rotation + ) + return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud) + + def render_viewpoint(self, viewpoint_position, viewpoint_orientation): + agent_state = habitat_sim.AgentState() + agent_state.position = viewpoint_position + agent_state.rotation = viewpoint_orientation + self.agent.set_state(agent_state) + viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0) + _append_camera_parameters( + viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation + ) + return viewpoint_observations + + def __getitem__(self, useless_idx): + ref_position, ref_orientation, nav_point = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + + ref_pointcloud = compute_pointcloud( + depthmap=ref_observations["depth"], + hfov=self.hfov, + camera_position=ref_position, + camera_rotation=ref_orientation, + ) + + pixels_count = self.resolution[0] * self.resolution[1] + ref_valid_fraction = len(ref_pointcloud) / pixels_count + assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0 + if ref_valid_fraction < self.minimum_valid_fraction: + + return self[0] + + observed_point = np.mean(ref_pointcloud, axis=0) + + viewpoints_observations = [ref_observations] + viewpoints_covisibility = [ref_valid_fraction] + viewpoints_positions = [ref_position] + viewpoints_orientations = [quaternion.as_float_array(ref_orientation)] + viewpoints_clouds = [ref_pointcloud] + viewpoints_valid_fractions = [ref_valid_fraction] + + for _ in range(self.views_count - 1): + + successful_sampling = False + for sampling_attempt in range(self.max_attempts_count): + position, rotation, _ = self.sample_other_random_viewpoint( + observed_point, nav_point + ) + + other_viewpoint_observations = self.render_viewpoint(position, rotation) + other_pointcloud = compute_pointcloud( + other_viewpoint_observations["depth"], self.hfov, position, rotation + ) + + ( + is_valid, + valid_fraction, + covisibility, + ) = self.is_other_pointcloud_overlapping( + ref_pointcloud, other_pointcloud + ) + if is_valid: + successful_sampling = True + break + if not successful_sampling: + print("WARNING: Maximum number of attempts reached.") + + return self[0] + viewpoints_observations.append(other_viewpoint_observations) + viewpoints_covisibility.append(covisibility) + viewpoints_positions.append(position) + viewpoints_orientations.append(quaternion.as_float_array(rotation)) + viewpoints_clouds.append(other_pointcloud) + viewpoints_valid_fractions.append(valid_fraction) + + pairwise_visibility_ratios = np.ones( + (len(viewpoints_observations), len(viewpoints_observations)) + ) + for i in range(len(viewpoints_observations)): + pairwise_visibility_ratios[i, i] = viewpoints_valid_fractions[i] + for j in range(i + 1, len(viewpoints_observations)): + overlap = compute_pointcloud_overlaps_scikit( + viewpoints_clouds[i], + viewpoints_clouds[j], + self.distance_threshold, + compute_symmetric=True, + ) + pairwise_visibility_ratios[i, j] = ( + overlap["intersection1"] / pixels_count + ) + pairwise_visibility_ratios[j, i] = ( + overlap["intersection2"] / pixels_count + ) + + data = { + "observations": viewpoints_observations, + "positions": np.asarray(viewpoints_positions), + "orientations": np.asarray(viewpoints_orientations), + "covisibility_ratios": np.asarray(viewpoints_covisibility), + "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float), + "pairwise_visibility_ratios": np.asarray( + pairwise_visibility_ratios, dtype=float + ), + } + + if self.transform is not None: + data = self.transform(data) + return data + + def generate_random_spiral_trajectory( + self, + images_count=100, + max_radius=0.5, + half_turns=5, + use_constant_orientation=False, + ): + """ + Return a list of images corresponding to a spiral trajectory from a random starting point. + Useful to generate nice visualisations. + Use an even number of half turns to get a nice "C1-continuous" loop effect + """ + ref_position, ref_orientation, navpoint = self.sample_random_viewpoint() + ref_observations = self.render_viewpoint(ref_position, ref_orientation) + ref_pointcloud = compute_pointcloud( + depthmap=ref_observations["depth"], + hfov=self.hfov, + camera_position=ref_position, + camera_rotation=ref_orientation, + ) + pixels_count = self.resolution[0] * self.resolution[1] + if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction: + + return self.generate_random_spiral_trajectory( + images_count, max_radius, half_turns, use_constant_orientation + ) + + observed_point = np.mean(ref_pointcloud, axis=0) + ref_R, ref_t = compute_camera_pose_opencv_convention( + ref_position, ref_orientation + ) + + images = [] + is_valid = [] + + for i, alpha in enumerate(np.linspace(0, 1, images_count)): + r = max_radius * np.abs(np.sin(alpha * np.pi)) + theta = alpha * half_turns * np.pi + x = r * np.cos(theta) + y = r * np.sin(theta) + z = 0.0 + position = ( + ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3, 1)).flatten() + ) + if use_constant_orientation: + orientation = ref_orientation + else: + + orientation, position = look_at_for_habitat( + eye=position, center=observed_point, up=habitat_sim.geo.UP + ) + observations = self.render_viewpoint(position, orientation) + images.append(observations["color"][..., :3]) + _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping( + ref_pointcloud, observations, position, orientation + ) + is_valid.append(_is_valid) + return images, np.all(is_valid) diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/pack_metadata_files.py b/longstream/utils/vendor/croco/datasets/habitat_sim/pack_metadata_files.py new file mode 100644 index 0000000000000000000000000000000000000000..38bb11ae6f1fbb368fff3967a490a7fc70ecf4e4 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/pack_metadata_files.py @@ -0,0 +1,74 @@ +""" +Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. +""" +import argparse +import collections +import glob +import json +import os + +from datasets.habitat_sim.paths import SCENES_DATASET +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + args = parser.parse_args() + + input_dirname = args.input_dir + output_dirname = args.output_dir + + input_metadata_filenames = glob.iglob( + f"{input_dirname}/**/metadata.json", recursive=True + ) + + images_count = collections.defaultdict(lambda: 0) + + os.makedirs(output_dirname) + for input_filename in tqdm(input_metadata_filenames): + + with open(input_filename, "r") as f: + original_metadata = json.load(f) + if ( + "multiviews" not in original_metadata + or len(original_metadata["multiviews"]) == 0 + ): + print("No views in", input_filename) + continue + + relpath = os.path.relpath(input_filename, input_dirname) + print(relpath) + + scenes_dataset_paths = dict( + sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True) + ) + metadata = dict() + for key, value in original_metadata.items(): + if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": + known_path = False + for dataset, dataset_path in scenes_dataset_paths.items(): + if value.startswith(dataset_path): + value = os.path.join( + dataset, os.path.relpath(value, dataset_path) + ) + known_path = True + break + if not known_path: + raise KeyError("Unknown path:" + value) + metadata[key] = value + + scene_split = metadata["scene"].split("/") + upper_level = ( + "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] + ) + images_count[upper_level] += len(metadata["multiviews"]) + + output_filename = os.path.join(output_dirname, relpath) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + with open(output_filename, "w") as f: + json.dump(metadata, f) + + print("Images count:") + for upper_level, count in images_count.items(): + print(f"- {upper_level}: {count}") diff --git a/longstream/utils/vendor/croco/datasets/habitat_sim/paths.py b/longstream/utils/vendor/croco/datasets/habitat_sim/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..be919fd58a8f1ac2d4a0e3dfe5695f103d209d14 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/habitat_sim/paths.py @@ -0,0 +1,166 @@ +""" +Paths to Habitat-Sim scenes +""" + +import collections +import os + +from tqdm import tqdm + +SCENES_DATASET = { + "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", + "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", + "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", + "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", + "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", + "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", + "scannet": "./data/habitat-sim/scene_datasets/scannet/", +} + +SceneData = collections.namedtuple( + "SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"] +) + + +def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): + scene_dataset_config_file = os.path.join( + base_path, "replicaCAD.scene_dataset_config.json" + ) + scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] + navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + [ + "empty_stage.navmesh" + ] + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) + + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scenes[idx] + ".scene_instance.json", + navmesh=os.path.join(base_path, navmeshes[idx]), + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_replica_cad_baked_lighting_scenes( + base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"] +): + scene_dataset_config_file = os.path.join( + base_path, "replicaCAD_baked.scene_dataset_config.json" + ) + scenes = sum( + [[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [] + ) + navmeshes = "" + scenes_data = [] + for idx in range(len(scenes)): + output_dir = os.path.join( + base_output_dir, "replica_cad_baked_lighting", scenes[idx] + ) + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scenes[idx], + navmesh="", + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_replica_scenes(base_output_dir, base_path): + scenes_data = [] + for scene_id in os.listdir(base_path): + scene = os.path.join(base_path, scene_id, "mesh.ply") + navmesh = os.path.join( + base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh" + ) + scene_dataset_config_file = "" + output_dir = os.path.join(base_output_dir, scene_id) + + data = SceneData( + scene_dataset_config_file=scene_dataset_config_file, + scene=scene, + navmesh=navmesh, + output_dir=output_dir, + ) + scenes_data.append(data) + return scenes_data + + +def list_scenes(base_output_dir, base_path): + """ + Generic method iterating through a base_path folder to find scenes. + """ + scenes_data = [] + for root, dirs, files in os.walk(base_path, followlinks=True): + folder_scenes_data = [] + for file in files: + name, ext = os.path.splitext(file) + if ext == ".glb": + scene = os.path.join(root, name + ".glb") + navmesh = os.path.join(root, name + ".navmesh") + if not os.path.exists(navmesh): + navmesh = "" + relpath = os.path.relpath(root, base_path) + output_dir = os.path.abspath( + os.path.join(base_output_dir, relpath, name) + ) + data = SceneData( + scene_dataset_config_file="", + scene=scene, + navmesh=navmesh, + output_dir=output_dir, + ) + folder_scenes_data.append(data) + + basis_scenes = [ + data.scene[: -len(".basis.glb")] + for data in folder_scenes_data + if data.scene.endswith(".basis.glb") + ] + if len(basis_scenes) != 0: + folder_scenes_data = [ + data + for data in folder_scenes_data + if not (data.scene[: -len(".glb")] in basis_scenes) + ] + + scenes_data.extend(folder_scenes_data) + return scenes_data + + +def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): + scenes_data = [] + + for split in ("minival", "train", "val", "examples"): + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), + base_path=f"{scenes_dataset_paths['hm3d']}/{split}", + ) + + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "gibson"), + base_path=scenes_dataset_paths["gibson"], + ) + + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), + base_path=scenes_dataset_paths["habitat-test-scenes"], + ) + + scenes_data += list_replica_cad_baked_lighting_scenes( + base_output_dir=base_output_dir + ) + + scenes_data += list_scenes( + base_output_dir=os.path.join(base_output_dir, "scannet"), + base_path=scenes_dataset_paths["scannet"], + ) + + list_replica_scenes( + base_output_dir=os.path.join(base_output_dir, "replica"), + base_path=scenes_dataset_paths["replica"], + ) + return scenes_data diff --git a/longstream/utils/vendor/croco/datasets/pairs_dataset.py b/longstream/utils/vendor/croco/datasets/pairs_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..086500a2d21b2058793c36bc29e640e63680de67 --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/pairs_dataset.py @@ -0,0 +1,158 @@ +import os + +from datasets.transforms import get_pair_transforms +from PIL import Image +from torch.utils.data import Dataset + + +def load_image(impath): + return Image.open(impath) + + +def load_pairs_from_cache_file(fname, root=""): + assert os.path.isfile( + fname + ), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, "r") as fid: + lines = fid.read().strip().splitlines() + pairs = [ + (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1])) + for l in lines + ] + return pairs + + +def load_pairs_from_list_file(fname, root=""): + assert os.path.isfile( + fname + ), "cannot parse pairs from {:s}, file does not exist".format(fname) + with open(fname, "r") as fid: + lines = fid.read().strip().splitlines() + pairs = [ + (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg")) + for l in lines + if not l.startswith("#") + ] + return pairs + + +def write_cache_file(fname, pairs, root=""): + if len(root) > 0: + if not root.endswith("/"): + root += "/" + assert os.path.isdir(root) + s = "" + for im1, im2 in pairs: + if len(root) > 0: + assert im1.startswith(root), im1 + assert im2.startswith(root), im2 + s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :]) + with open(fname, "w") as fid: + fid.write(s[:-1]) + + +def parse_and_cache_all_pairs(dname, data_dir="./data/"): + if dname == "habitat_release": + dirname = os.path.join(data_dir, "habitat_release") + assert os.path.isdir(dirname), ( + "cannot find folder for habitat_release pairs: " + dirname + ) + cache_file = os.path.join(dirname, "pairs.txt") + assert not os.path.isfile(cache_file), ( + "cache file already exists: " + cache_file + ) + + print("Parsing pairs for dataset: " + dname) + pairs = [] + for root, dirs, files in os.walk(dirname): + if "val" in root: + continue + dirs.sort() + pairs += [ + ( + os.path.join(root, f), + os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"), + ) + for f in sorted(files) + if f.endswith("_1.jpeg") + ] + print("Found {:,} pairs".format(len(pairs))) + print("Writing cache to: " + cache_file) + write_cache_file(cache_file, pairs, root=dirname) + + else: + raise NotImplementedError("Unknown dataset: " + dname) + + +def dnames_to_image_pairs(dnames, data_dir="./data/"): + """ + dnames: list of datasets with image pairs, separated by + + """ + all_pairs = [] + for dname in dnames.split("+"): + if dname == "habitat_release": + dirname = os.path.join(data_dir, "habitat_release") + assert os.path.isdir(dirname), ( + "cannot find folder for habitat_release pairs: " + dirname + ) + cache_file = os.path.join(dirname, "pairs.txt") + assert os.path.isfile(cache_file), ( + "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. " + + cache_file + ) + pairs = load_pairs_from_cache_file(cache_file, root=dirname) + elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]: + dirname = os.path.join(data_dir, dname + "_crops") + assert os.path.isdir( + dirname + ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) + list_file = os.path.join(dirname, "listing.txt") + assert os.path.isfile( + list_file + ), "cannot find list file for {:s} pairs, see instructions. {:s}".format( + dname, list_file + ) + pairs = load_pairs_from_list_file(list_file, root=dirname) + print(" {:s}: {:,} pairs".format(dname, len(pairs))) + all_pairs += pairs + if "+" in dnames: + print(" Total: {:,} pairs".format(len(all_pairs))) + return all_pairs + + +class PairsDataset(Dataset): + def __init__( + self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/" + ): + super().__init__() + self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) + self.transforms = get_pair_transforms( + transform_str=trfs, totensor=totensor, normalize=normalize + ) + + def __len__(self): + return len(self.image_pairs) + + def __getitem__(self, index): + im1path, im2path = self.image_pairs[index] + im1 = load_image(im1path) + im2 = load_image(im2path) + if self.transforms is not None: + im1, im2 = self.transforms(im1, im2) + return im1, im2 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + prog="Computing and caching list of pairs for a given dataset" + ) + parser.add_argument( + "--data_dir", default="./data/", type=str, help="path where data are stored" + ) + parser.add_argument( + "--dataset", default="habitat_release", type=str, help="name of the dataset" + ) + args = parser.parse_args() + parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) diff --git a/longstream/utils/vendor/croco/datasets/transforms.py b/longstream/utils/vendor/croco/datasets/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6491435e87a66fe52b19f1f524ee9402b7bdcd --- /dev/null +++ b/longstream/utils/vendor/croco/datasets/transforms.py @@ -0,0 +1,130 @@ +import torch +import torchvision.transforms +import torchvision.transforms.functional as F + + +class ComposePair(torchvision.transforms.Compose): + def __call__(self, img1, img2): + for t in self.transforms: + img1, img2 = t(img1, img2) + return img1, img2 + + +class NormalizeBoth(torchvision.transforms.Normalize): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + + +class ToTensorBoth(torchvision.transforms.ToTensor): + def __call__(self, img1, img2): + img1 = super().__call__(img1) + img2 = super().__call__(img2) + return img1, img2 + + +class RandomCropPair(torchvision.transforms.RandomCrop): + def forward(self, img1, img2): + img1 = super().forward(img1) + img2 = super().forward(img2) + return img1, img2 + + +class ColorJitterPair(torchvision.transforms.ColorJitter): + def __init__(self, assymetric_prob, **kwargs): + super().__init__(**kwargs) + self.assymetric_prob = assymetric_prob + + def jitter_one( + self, + img, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return img + + def forward(self, img1, img2): + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + img1 = self.jitter_one( + img1, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) + if torch.rand(1) < self.assymetric_prob: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + img2 = self.jitter_one( + img2, + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) + return img1, img2 + + +def get_pair_transforms(transform_str, totensor=True, normalize=True): + + trfs = [] + for s in transform_str.split("+"): + if s.startswith("crop"): + size = int(s[len("crop") :]) + trfs.append(RandomCropPair(size)) + elif s == "acolor": + trfs.append( + ColorJitterPair( + assymetric_prob=1.0, + brightness=(0.6, 1.4), + contrast=(0.6, 1.4), + saturation=(0.6, 1.4), + hue=0.0, + ) + ) + elif s == "": + pass + else: + raise NotImplementedError("Unknown augmentation: " + s) + + if totensor: + trfs.append(ToTensorBoth()) + if normalize: + trfs.append( + NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) + + if len(trfs) == 0: + return None + elif len(trfs) == 1: + return trfs + else: + return ComposePair(trfs) diff --git a/longstream/utils/vendor/croco/interactive_demo.ipynb b/longstream/utils/vendor/croco/interactive_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..da6a61a77cfdf8ce77999b5f80449857a2f4a706 --- /dev/null +++ b/longstream/utils/vendor/croco/interactive_demo.ipynb @@ -0,0 +1,248 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from models.croco import CroCoNet\n", + "from ipywidgets import interact, interactive, fixed, interact_manual\n", + "import ipywidgets as widgets\n", + "import matplotlib.pyplot as plt\n", + "import quaternion\n", + "import models.masking" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n", + "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n", + "msg = model.load_state_dict(ckpt['model'], strict=True)\n", + "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", + "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", + "model = model.eval()\n", + "model = model.to(device=device)\n", + "print(msg)\n", + "\n", + "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n", + " \"\"\"\n", + " Perform Cross-View completion using two input images, specified using Numpy arrays.\n", + " \"\"\"\n", + " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n", + "\n", + " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n", + " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n", + "\n", + " normalize_input_colors = True\n", + " is_output_normalized = True\n", + " with torch.no_grad():\n", + " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n", + "\n", + " if normalize_input_colors:\n", + " ref_image = (ref_image - imagenet_mean) / imagenet_std\n", + " target_image = (target_image - imagenet_mean) / imagenet_std\n", + "\n", + " out, mask, _ = model(target_image, ref_image)\n", + " if not is_output_normalized:\n", + " predicted_image = model.unpatchify(out)\n", + " else:\n", + " patchified = model.patchify(target_image)\n", + " mean = patchified.mean(dim=-1, keepdim=True)\n", + " var = patchified.var(dim=-1, keepdim=True)\n", + " pred_renorm = out * (var + 1.e-6)**.5 + mean\n", + " predicted_image = model.unpatchify(pred_renorm)\n", + "\n", + " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n", + " masked_target_image = (1 - image_masks) * target_image\n", + " \n", + " if not reconstruct_unmasked_patches:\n", + " predicted_image = predicted_image * image_masks + masked_target_image\n", + "\n", + " if normalize_input_colors:\n", + " predicted_image = predicted_image * imagenet_std + imagenet_mean\n", + " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n", + " \n", + " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n", + " return masked_target_image, predicted_image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n", + "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n", + "import habitat_sim\n", + "\n", + "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n", + "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n", + "\n", + "sim_cfg = habitat_sim.SimulatorConfiguration()\n", + "if use_gpu: sim_cfg.gpu_device_id = 0\n", + "sim_cfg.scene_id = scene\n", + "sim_cfg.load_semantic_mesh = False\n", + "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n", + "rgb_sensor_spec.uuid = \"color\"\n", + "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n", + "rgb_sensor_spec.resolution = (224,224)\n", + "rgb_sensor_spec.hfov = 56.56\n", + "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n", + "rgb_sensor_spec.orientation = [0, 0, 0]\n", + "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n", + "\n", + "\n", + "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n", + "sim = habitat_sim.Simulator(cfg)\n", + "if navmesh is not None:\n", + " sim.pathfinder.load_nav_mesh(navmesh)\n", + "agent = sim.initialize_agent(agent_id=0)\n", + "\n", + "def sample_random_viewpoint():\n", + " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n", + " nav_point = sim.pathfinder.get_random_navigable_point()\n", + " viewpoint_height = np.random.uniform(1.0, 1.6)\n", + " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n", + " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n", + " return viewpoint_position, viewpoint_orientation\n", + "\n", + "def render_viewpoint(position, orientation):\n", + " agent_state = habitat_sim.AgentState()\n", + " agent_state.position = position\n", + " agent_state.rotation = orientation\n", + " agent.set_state(agent_state)\n", + " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n", + " image = viewpoint_observations['color'][:,:,:3]\n", + " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n", + " return image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ref_position, ref_orientation = sample_random_viewpoint()\n", + "ref_image = render_viewpoint(ref_position, ref_orientation)\n", + "plt.clf()\n", + "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n", + "axes[0,0].imshow(ref_image)\n", + "for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstruct_unmasked_patches = False\n", + "\n", + "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n", + " R = quaternion.as_rotation_matrix(ref_orientation)\n", + " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n", + " target_orientation = (ref_orientation\n", + " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n", + " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n", + " \n", + " ref_image = render_viewpoint(ref_position, ref_orientation)\n", + " target_image = render_viewpoint(target_position, target_orientation)\n", + "\n", + " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n", + "\n", + " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n", + " axes[0].imshow(ref_image)\n", + " axes[0].set_xlabel(\"Reference\")\n", + " axes[1].imshow(masked_target_image)\n", + " axes[1].set_xlabel(\"Masked target\")\n", + " axes[2].imshow(predicted_image)\n", + " axes[2].set_xlabel(\"Reconstruction\") \n", + " axes[3].imshow(target_image)\n", + " axes[3].set_xlabel(\"Target\")\n", + " for ax in axes.flatten():\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + "\n", + "interact(show_demo,\n", + " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n", + " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n", + " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n", + " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "vscode": { + "interpreter": { + "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/longstream/utils/vendor/croco/models/blocks.py b/longstream/utils/vendor/croco/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..d030730b1bd2dd7d297a3c80aaeb0620c4620974 --- /dev/null +++ b/longstream/utils/vendor/croco/models/blocks.py @@ -0,0 +1,515 @@ +import collections.abc +from itertools import repeat +import math + +import torch +import torch.nn as nn +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import SDPBackend + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + rope=None, + num_heads=8, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + attn_mask=None, + is_causal=False, + attn_implementation="pytorch_naive", + attn_bias_for_inference_enabled=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.attn_bias_for_inference_enabled = attn_bias_for_inference_enabled + gamma = 1.0 + train_seqlen = 20 + inference_seqlen = 137 + self.attn_bias_scale = ( + head_dim ** -0.5 + * (gamma * math.log(inference_seqlen) / math.log(train_seqlen)) ** 0.5 + ) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.dropout_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + self.attn_mask = attn_mask + self.is_causal = is_causal + self.attn_implementation = attn_implementation + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .transpose(1, 3) + ) + q, k, v = [qkv[:, :, i] for i in range(3)] + + if self.rope is not None: + with torch.autocast( + device_type=next(self.parameters()).device.type, dtype=torch.float32 + ): + q = self.rope(q, xpos) if xpos is not None else q + k = self.rope(k, xpos) if xpos is not None else k + + if not self.training and self.attn_bias_for_inference_enabled: + scale = self.attn_bias_scale + else: + scale = self.scale + + if self.attn_implementation == "pytorch_naive": + assert ( + self.attn_mask is None + ), "attn_mask not supported for pytorch_naive implementation of scaled dot product attention" + assert ( + self.is_causal is False + ), "is_causal not supported for pytorch_naive implementation of scaled dot product attention" + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = (q @ k.transpose(-2, -1)) * scale + x = x.softmax(dim=-1) + x = self.attn_drop(x) + if dtype == torch.float32: + x = x.to(torch.float32) + x = (x @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "flash_attention": + with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = scaled_dot_product_attention( + q, + k, + v, + attn_mask=self.attn_mask, + dropout_p=self.dropout_p, + is_causal=self.is_causal, + scale=scale, + ) + if dtype == torch.float32: + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "pytorch_auto": + with torch.nn.attention.sdpa_kernel( + [ + SDPBackend.EFFICIENT_ATTENTION, + ] + ): + dtype = k.dtype + with torch.autocast("cuda", dtype=torch.bfloat16): + x = scaled_dot_product_attention( + q, + k, + v, + attn_mask=self.attn_mask, + dropout_p=self.dropout_p, + is_causal=self.is_causal, + scale=scale, + ) + if dtype == torch.float32: + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + else: + raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}") + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + rope=None, + attn_implementation="pytorch_naive", + attn_bias_for_inference_enabled=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + attn_bias_for_inference_enabled=attn_bias_for_inference_enabled, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + rope=None, + num_heads=8, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + attn_mask=None, + is_causal=False, + attn_implementation="pytorch_naive", + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.dropout_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + + self.attn_mask = attn_mask + self.is_causal = is_causal + self.attn_implementation = attn_implementation + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = ( + self.projq(query) + .reshape(B, Nq, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + k = ( + self.projk(key) + .reshape(B, Nk, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + v = ( + self.projv(value) + .reshape(B, Nv, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + + if self.rope is not None: + with torch.autocast( + device_type=next(self.parameters()).device.type, dtype=torch.float32 + ): + q = self.rope(q, qpos) if qpos is not None else q + k = self.rope(k, kpos) if kpos is not None else k + + if self.attn_implementation == "pytorch_naive": + assert ( + self.attn_mask is None + ), "attn_mask not supported for pytorch_naive implementation of scaled dot product attention" + assert ( + self.is_causal is False + ), "is_causal not supported for pytorch_naive implementation of scaled dot product attention" + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + elif self.attn_implementation == "flash_attention": + with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + x = scaled_dot_product_attention( + q, + k, + v, + attn_mask=self.attn_mask, + dropout_p=self.dropout_p, + is_causal=self.is_causal, + scale=self.scale, + ) + + x = x.to(torch.float32) + x = x.transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + else: + raise ValueError(f"Unknown attn_implementation: {self.attn_implementation}") + + return x + + +class TrackingAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_mem=True, + rope=None, + attn_implementation="pytorch_naive", + ): + super().__init__() + self.cross_attn = CrossAttention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x + + +class DecoderBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_mem=True, + rope=None, + attn_implementation="pytorch_naive", + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.cross_attn = CrossAttention( + dim, + rope=rope, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + attn_implementation=attn_implementation, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y + + +class PositionGetter(object): + """return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos + + +class PatchEmbed(nn.Module): + """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + + def forward(self, x): + B, C, H, W = x.shape + torch._assert( + H == self.img_size[0], + f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", + ) + torch._assert( + W == self.img_size[1], + f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", + ) + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2).contiguous() + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) diff --git a/longstream/utils/vendor/croco/models/criterion.py b/longstream/utils/vendor/croco/models/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..6f75d94a6770aa5ba1a9cfe63c95b528135c65fe --- /dev/null +++ b/longstream/utils/vendor/croco/models/criterion.py @@ -0,0 +1,26 @@ +import torch + + +class MaskedMSE(torch.nn.Module): + def __init__(self, norm_pix_loss=False, masked=True): + """ + norm_pix_loss: normalize each patch by their pixel mean and variance + masked: compute loss over the masked patches only + """ + super().__init__() + self.norm_pix_loss = norm_pix_loss + self.masked = masked + + def forward(self, pred, mask, target): + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) + if self.masked: + loss = (loss * mask).sum() / mask.sum() + else: + loss = loss.mean() + return loss diff --git a/longstream/utils/vendor/croco/models/croco.py b/longstream/utils/vendor/croco/models/croco.py new file mode 100644 index 0000000000000000000000000000000000000000..64f47bf5b93a815e4b1b80a429624b24c643498a --- /dev/null +++ b/longstream/utils/vendor/croco/models/croco.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn + +torch.backends.cuda.matmul.allow_tf32 = True +from functools import partial + +from longstream.utils.vendor.croco.models.blocks import Block, DecoderBlock, PatchEmbed +from longstream.utils.vendor.croco.models.masking import RandomMask +from longstream.utils.vendor.croco.models.pos_embed import ( + RoPE2D, + get_2d_sincos_pos_embed, +) + + +class CroCoNet(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + mask_ratio=0.9, + enc_embed_dim=768, + enc_depth=12, + enc_num_heads=12, + dec_embed_dim=512, + dec_depth=8, + dec_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_im2_in_dec=True, + pos_embed="cosine", + attn_implementation="pytorch_naive", + ): + super(CroCoNet, self).__init__() + + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + self._set_mask_generator(self.patch_embed.num_patches, mask_ratio) + + self.pos_embed = pos_embed + if pos_embed == "cosine": + + enc_pos_embed = get_2d_sincos_pos_embed( + enc_embed_dim, int(self.patch_embed.num_patches ** 0.5), n_cls_token=0 + ) + self.register_buffer( + "enc_pos_embed", torch.from_numpy(enc_pos_embed).float() + ) + + dec_pos_embed = get_2d_sincos_pos_embed( + dec_embed_dim, int(self.patch_embed.num_patches ** 0.5), n_cls_token=0 + ) + self.register_buffer( + "dec_pos_embed", torch.from_numpy(dec_pos_embed).float() + ) + + self.rope = None + elif pos_embed.startswith("RoPE"): + self.enc_pos_embed = None + self.dec_pos_embed = None + if RoPE2D is None: + raise ImportError( + "Cannot find cuRoPE2D, please install it following the README instructions" + ) + freq = float(pos_embed[len("RoPE") :]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError("Unknown pos_embed " + pos_embed) + + self.attn_implementation = attn_implementation + + self.enc_depth = enc_depth + self.enc_embed_dim = enc_embed_dim + self.enc_blocks = nn.ModuleList( + [ + Block( + enc_embed_dim, + enc_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + rope=self.rope, + attn_implementation=attn_implementation, + ) + for i in range(enc_depth) + ] + ) + self.enc_norm = norm_layer(enc_embed_dim) + + self._set_mask_token(dec_embed_dim) + + self._set_decoder( + enc_embed_dim, + dec_embed_dim, + dec_num_heads, + dec_depth, + mlp_ratio, + norm_layer, + norm_im2_in_dec, + ) + + self._set_prediction_head(dec_embed_dim, patch_size) + + self.initialize_weights() + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim) + + def _set_mask_generator(self, num_patches, mask_ratio): + self.mask_generator = RandomMask(num_patches, mask_ratio) + + def _set_mask_token(self, dec_embed_dim): + self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim)) + + def _set_decoder( + self, + enc_embed_dim, + dec_embed_dim, + dec_num_heads, + dec_depth, + mlp_ratio, + norm_layer, + norm_im2_in_dec, + ): + self.dec_depth = dec_depth + self.dec_embed_dim = dec_embed_dim + + self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True) + + self.dec_blocks = nn.ModuleList( + [ + DecoderBlock( + dec_embed_dim, + dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + norm_mem=norm_im2_in_dec, + rope=self.rope, + attn_implementation=self.attn_implementation, + ) + for i in range(dec_depth) + ] + ) + + self.dec_norm = norm_layer(dec_embed_dim) + + def _set_prediction_head(self, dec_embed_dim, patch_size): + self.prediction_head = nn.Linear(dec_embed_dim, patch_size ** 2 * 3, bias=True) + + def initialize_weights(self): + + self.patch_embed._init_weights() + + if self.mask_token is not None: + torch.nn.init.normal_(self.mask_token, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _encode_image(self, image, do_mask=False, return_all_blocks=False): + """ + image has B x 3 x img_size x img_size + do_mask: whether to perform masking or not + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + """ + + x, pos = self.patch_embed(image) + + if self.enc_pos_embed is not None: + x = x + self.enc_pos_embed[None, ...] + + B, N, C = x.size() + if do_mask: + masks = self.mask_generator(x) + x = x[~masks].view(B, -1, C) + posvis = pos[~masks].view(B, -1, 2) + else: + B, N, C = x.size() + masks = torch.zeros((B, N), dtype=bool) + posvis = pos + + if return_all_blocks: + out = [] + for blk in self.enc_blocks: + x = blk(x, posvis) + out.append(x) + out[-1] = self.enc_norm(out[-1]) + return out, pos, masks + else: + for blk in self.enc_blocks: + x = blk(x, posvis) + x = self.enc_norm(x) + return x, pos, masks + + def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False): + """ + return_all_blocks: if True, return the features at the end of every block + instead of just the features from the last block (eg for some prediction heads) + + masks1 can be None => assume image1 fully visible + """ + + visf1 = self.decoder_embed(feat1) + f2 = self.decoder_embed(feat2) + + B, Nenc, C = visf1.size() + if masks1 is None: + f1_ = visf1 + else: + Ntotal = masks1.size(1) + f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype) + f1_[~masks1] = visf1.view(B * Nenc, C) + + if self.dec_pos_embed is not None: + f1_ = f1_ + self.dec_pos_embed + f2 = f2 + self.dec_pos_embed + + out = f1_ + out2 = f2 + if return_all_blocks: + _out, out = out, [] + for blk in self.dec_blocks: + _out, out2 = blk(_out, out2, pos1, pos2) + out.append(_out) + out[-1] = self.dec_norm(out[-1]) + else: + for blk in self.dec_blocks: + out, out2 = blk(out, out2, pos1, pos2) + out = self.dec_norm(out) + return out + + def patchify(self, imgs): + """ + imgs: (B, 3, H, W) + x: (B, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + + return x + + def unpatchify(self, x, channels=3): + """ + x: (N, L, patch_size**2 *channels) + imgs: (N, 3, H, W) + """ + patch_size = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size)) + return imgs + + def forward(self, img1, img2): + """ + img1: tensor of size B x 3 x img_size x img_size + img2: tensor of size B x 3 x img_size x img_size + + out will be B x N x (3*patch_size*patch_size) + masks are also returned as B x N just in case + """ + + feat1, pos1, mask1 = self._encode_image(img1, do_mask=True) + + feat2, pos2, _ = self._encode_image(img2, do_mask=False) + + decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2) + + out = self.prediction_head(decfeat) + + target = self.patchify(img1) + return out, mask1, target diff --git a/longstream/utils/vendor/croco/models/croco_downstream.py b/longstream/utils/vendor/croco/models/croco_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..761f8c7f036ca165f5122a440e87148e01bda66f --- /dev/null +++ b/longstream/utils/vendor/croco/models/croco_downstream.py @@ -0,0 +1,123 @@ +import torch + +from .croco import CroCoNet + + +def croco_args_from_ckpt(ckpt): + if "croco_kwargs" in ckpt: + return ckpt["croco_kwargs"] + elif "args" in ckpt and hasattr(ckpt["args"], "model"): + s = ckpt["args"].model + assert s.startswith("CroCoNet(") + return eval("dict" + s[len("CroCoNet") :]) + else: + return dict() + + +class CroCoDownstreamMonocularEncoder(CroCoNet): + def __init__(self, head, **kwargs): + """Build network for monocular downstream task, only using the encoder. + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + NOTE: It works by *calling super().__init__() but with redefined setters + + """ + super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """No mask generator""" + return + + def _set_mask_token(self, *args, **kwargs): + """No mask token""" + self.mask_token = None + return + + def _set_decoder(self, *args, **kwargs): + """No decoder""" + return + + def _set_prediction_head(self, *args, **kwargs): + """No 'prediction head' for downstream tasks.""" + return + + def forward(self, img): + """ + img if of size batch_size x 3 x h x w + """ + B, C, H, W = img.size() + img_info = {"height": H, "width": W} + need_all_layers = ( + hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks + ) + out, _, _ = self._encode_image( + img, do_mask=False, return_all_blocks=need_all_layers + ) + return self.head(out, img_info) + + +class CroCoDownstreamBinocular(CroCoNet): + def __init__(self, head, **kwargs): + """Build network for binocular downstream task + It takes an extra argument head, that is called with the features + and a dictionary img_info containing 'width' and 'height' keys + The head is setup with the croconet arguments in this init function + """ + super(CroCoDownstreamBinocular, self).__init__(**kwargs) + head.setup(self) + self.head = head + + def _set_mask_generator(self, *args, **kwargs): + """No mask generator""" + return + + def _set_mask_token(self, *args, **kwargs): + """No mask token""" + self.mask_token = None + return + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head for downstream tasks, define your own head""" + return + + def encode_image_pairs(self, img1, img2, return_all_blocks=False): + """run encoder for a pair of images + it is actually ~5% faster to concatenate the images along the batch dimension + than to encode them separately + """ + + out, pos, _ = self._encode_image( + torch.cat((img1, img2), dim=0), + do_mask=False, + return_all_blocks=return_all_blocks, + ) + if return_all_blocks: + out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) + out2 = out2[-1] + else: + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + return out, out2, pos, pos2 + + def forward(self, img1, img2): + B, C, H, W = img1.size() + img_info = {"height": H, "width": W} + return_all_blocks = ( + hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks + ) + out, out2, pos, pos2 = self.encode_image_pairs( + img1, img2, return_all_blocks=return_all_blocks + ) + if return_all_blocks: + decout = self._decoder( + out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks + ) + decout = out + decout + else: + decout = self._decoder( + out, pos, None, out2, pos2, return_all_blocks=return_all_blocks + ) + return self.head(decout, img_info) diff --git a/longstream/utils/vendor/croco/models/curope/__init__.py b/longstream/utils/vendor/croco/models/curope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdfc72933cc1c3ec5ce285f12882d1934333ecf --- /dev/null +++ b/longstream/utils/vendor/croco/models/curope/__init__.py @@ -0,0 +1 @@ +from .curope2d import cuRoPE2D diff --git a/longstream/utils/vendor/croco/models/curope/curope.cpp b/longstream/utils/vendor/croco/models/curope/curope.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc4ec136b477c37f3b142bd305863f4a748c4900 --- /dev/null +++ b/longstream/utils/vendor/croco/models/curope/curope.cpp @@ -0,0 +1,60 @@ + +#include + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/longstream/utils/vendor/croco/models/curope/curope2d.py b/longstream/utils/vendor/croco/models/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cfca9ee39cd392e7ec9ac5c452fdb0b63a6d03 --- /dev/null +++ b/longstream/utils/vendor/croco/models/curope/curope2d.py @@ -0,0 +1,38 @@ +import torch + +try: + import curope as _kernels +except ModuleNotFoundError: + from . import curope as _kernels + + +class cuRoPE2D_func(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda") + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + + _kernels.rope_2d(tokens, positions, base, F0) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d(grad_res, positions, base, -F0) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0) + return tokens diff --git a/longstream/utils/vendor/croco/models/curope/kernels.cu b/longstream/utils/vendor/croco/models/curope/kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..a9cad40fa88e390cd65e938ea4042809be476b59 --- /dev/null +++ b/longstream/utils/vendor/croco/models/curope/kernels.cu @@ -0,0 +1,80 @@ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) +{ + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + + for (int h = 0; h < H; h++) + { + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/longstream/utils/vendor/croco/models/curope/setup.py b/longstream/utils/vendor/croco/models/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe9203fe15c17d2eed8df02ffd868d722b12f5d --- /dev/null +++ b/longstream/utils/vendor/croco/models/curope/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +all_cuda_archs = cuda.get_gencode_flags().replace("compute=", "arch=").split() + +setup( + name="curope", + ext_modules=[ + CUDAExtension( + name="curope", + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args=dict( + nvcc=["-O3", "--ptxas-options=-v", "--use_fast_math"] + all_cuda_archs, + cxx=["-O3"], + ), + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/longstream/utils/vendor/croco/models/dpt_block.py b/longstream/utils/vendor/croco/models/dpt_block.py new file mode 100644 index 0000000000000000000000000000000000000000..cff4e02bab1631d06150236ab165f0c26ebe289c --- /dev/null +++ b/longstream/utils/vendor/croco/models/dpt_block.py @@ -0,0 +1,506 @@ +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList( + [ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ] + ) + + return scratch + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs, max_chunk_size=100): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate( + res, size=(output.shape[2], output.shape[3]), mode="bilinear" + ) + + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate( + output, size=(2 * output.shape[2], shape), mode="bilinear" + ) + else: + + chunks = torch.split(output, max_chunk_size, dim=0) + outputs = [] + + for chunk in chunks: + out_chunk = nn.functional.interpolate( + chunk, + scale_factor=2, + mode="bilinear", + align_corners=self.align_corners, + ) + outputs.append(out_chunk) + + output = torch.cat(outputs, dim=0) + + output = self.out_conv(output) + return output + + +def make_fusion_block(features, use_bn, width_ratio=1): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__( + self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ("rgb",), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = "regression", + output_width_ratio=1, + **kwargs + ): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = ( + dim_tokens_enc * len(self.main_tasks) + if dim_tokens_enc is not None + else None + ) + self.head_type = head_type + + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet2 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet3 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + self.scratch.refinenet4 = make_fusion_block( + feature_dim, use_bn, output_width_ratio + ) + + if self.head_type == "regression": + + self.head = nn.Sequential( + nn.Conv2d( + feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1 + ), + Interpolate( + scale_factor=self.patch_size[0] / 8, + mode="bilinear", + align_corners=True, + ), + nn.Conv2d( + feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1 + ), + nn.ReLU(True), + nn.Conv2d( + last_dim, self.num_channels, kernel_size=1, stride=1, padding=0 + ), + ) + elif self.head_type == "semseg": + + self.head = nn.Sequential( + nn.Conv2d( + feature_dim, feature_dim, kernel_size=3, padding=1, bias=False + ), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, + stride=1, + padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + self.act_postprocess = nn.ModuleList( + [ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess, + ] + ) + + def adapt_tokens(self, encoder_tokens): + + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + + assert ( + self.dim_tokens_enc is not None + ), "Need to call init(dim_tokens_enc) function first" + H, W = image_size + + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + layers = [encoder_tokens[hook] for hook in self.hooks] + + layers = [self.adapt_tokens(l) for l in layers] + + layers = [ + rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers + ] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + out = self.head(path_1) + + return out diff --git a/longstream/utils/vendor/croco/models/head_downstream.py b/longstream/utils/vendor/croco/models/head_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..78715c7c71cecf4dbb70d176f422b94519b02460 --- /dev/null +++ b/longstream/utils/vendor/croco/models/head_downstream.py @@ -0,0 +1,75 @@ +""" +A head is a module where the __init__ defines only the head hyperparameters. +A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. +The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' +""" + +import torch +import torch.nn as nn + +from .dpt_block import DPTOutputAdapter + + +class PixelwiseTaskWithDPT(nn.Module): + """DPT module for CroCo. + by default, hooks_idx will be equal to: + * for encoder-only: 4 equally spread layers + * for encoder+decoder: last encoder + 3 equally spread layers of the decoder + """ + + def __init__( + self, + *, + hooks_idx=None, + layer_dims=[96, 192, 384, 768], + output_width_ratio=1, + num_channels=1, + postprocess=None, + **kwargs, + ): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_blocks = True + self.postprocess = postprocess + self.output_width_ratio = output_width_ratio + self.num_channels = num_channels + self.hooks_idx = hooks_idx + self.layer_dims = layer_dims + + def setup(self, croconet): + dpt_args = { + "output_width_ratio": self.output_width_ratio, + "num_channels": self.num_channels, + } + if self.hooks_idx is None: + if hasattr(croconet, "dec_blocks"): + step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] + hooks_idx = [ + croconet.dec_depth + croconet.enc_depth - 1 - i * step + for i in range(3, -1, -1) + ] + else: + step = croconet.enc_depth // 4 + hooks_idx = [ + croconet.enc_depth - 1 - i * step for i in range(3, -1, -1) + ] + self.hooks_idx = hooks_idx + print( + f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}" + ) + dpt_args["hooks"] = self.hooks_idx + dpt_args["layer_dims"] = self.layer_dims + self.dpt = DPTOutputAdapter(**dpt_args) + dim_tokens = [ + croconet.enc_embed_dim + if hook < croconet.enc_depth + else croconet.dec_embed_dim + for hook in self.hooks_idx + ] + dpt_init_args = {"dim_tokens_enc": dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info["height"], img_info["width"])) + if self.postprocess: + out = self.postprocess(out) + return out diff --git a/longstream/utils/vendor/croco/models/masking.py b/longstream/utils/vendor/croco/models/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc73e9b6798a3fb1d10967b6939b9775e016b7d --- /dev/null +++ b/longstream/utils/vendor/croco/models/masking.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + + +class RandomMask(nn.Module): + """ + random masking + """ + + def __init__(self, num_patches, mask_ratio): + super().__init__() + self.num_patches = num_patches + self.num_mask = int(mask_ratio * self.num_patches) + + def __call__(self, x): + noise = torch.rand(x.size(0), self.num_patches, device=x.device) + argsort = torch.argsort(noise, dim=1) + return argsort < self.num_mask diff --git a/longstream/utils/vendor/croco/models/perceiver_block.py b/longstream/utils/vendor/croco/models/perceiver_block.py new file mode 100644 index 0000000000000000000000000000000000000000..b4280bf57ce4ca3dd50c5db7ba826de4476ae5bc --- /dev/null +++ b/longstream/utils/vendor/croco/models/perceiver_block.py @@ -0,0 +1,129 @@ +import numpy as np +import torch +import torch.nn as nn + +from blocks import CrossAttention, Block +from pos_embed import get_1d_sincos_pos_embed_from_grid + + +class PerceiverCompressor(nn.Module): + def __init__( + self, + token_dim, + latent_dim, + num_latents, + num_cross_layers, + num_latent_transformer_layers, + num_heads=8, + dropout=0.0, + norm_layer=nn.LayerNorm, + ): + super(PerceiverCompressor, self).__init__() + self.token_dim = token_dim + self.latent_dim = latent_dim + self.num_latents = num_latents + self.num_heads = num_heads + self.num_cross_layers = num_cross_layers + self.num_latent_transformer_layers = num_latent_transformer_layers + + self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) + + self.cross_attention_layers = nn.ModuleList() + for _ in range(num_cross_layers): + cross_attn_layer = nn.ModuleDict( + { + "cross_attn": CrossAttention( + dim=latent_dim, + num_heads=num_heads, + qkv_bias=True, + attn_drop=dropout, + proj_drop=dropout, + ), + "latent_transformer": nn.ModuleList( + [ + Block( + dim=latent_dim, + num_heads=num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop=dropout, + attn_drop=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_latent_transformer_layers) + ] + ), + "norm1": norm_layer(latent_dim), + "norm2": norm_layer(latent_dim), + "norm_x": norm_layer(latent_dim), + } + ) + self.cross_attention_layers.append(cross_attn_layer) + + def forward(self, x, pos, image_ids): + """ + Args: + x (torch.Tensor): Input tensor of shape [B, P, C] where + B - batch size + P - total number of patches from all images + C - dimension of each visual token + pos (torch.Tensor): Positional tensor of shape [B, P, 2] indicating positions + image_ids (torch.Tensor): Tensor of shape [B, P] specifying which image each patch belongs to + Returns: + torch.Tensor: Compressed latent representation of shape [B, L, D] where + L - number of latents + D - dimension of each latent representation + """ + B, P, C = x.shape + + latents = self.latents.unsqueeze(0).expand(B, -1, -1) + + num_images = (torch.max(image_ids) + 1).cpu().item() + image_pos_emb = ( + torch.from_numpy( + get_1d_sincos_pos_embed_from_grid(self.token_dim, np.arange(num_images)) + ) + .float() + .to(x.device) + ) + + image_pos = image_pos_emb[image_ids] + x += image_pos + + for layer in self.cross_attention_layers: + + latents = layer["cross_attn"]( + query=layer["norm1"](latents), + key=layer["norm_x"](x), + value=layer["norm_x"](x), + qpos=None, + kpos=pos, + ) + + for latent_transformer_layer in layer["latent_transformer"]: + latents = latent_transformer_layer(x=layer["norm2"](latents), xpos=None) + + return latents + + +B, P, C = 2, 100 * 256, 768 +L, D = 1000, 768 +num_cross_layers = 4 +num_latent_transformer_layers = 2 +num_heads = 8 +dropout = 0.1 + +compressor = PerceiverCompressor( + token_dim=C, + latent_dim=D, + num_latents=L, + num_cross_layers=num_cross_layers, + num_latent_transformer_layers=num_latent_transformer_layers, + num_heads=num_heads, + dropout=dropout, +).cuda() +input_tensor = torch.randn(B, P, C).cuda() +pos_tensor = torch.randn(B, P, 2).cuda() +image_ids = torch.tensor([[i] * 256 for i in range(100)] * B).cuda().reshape(B, -1) +output_tensor = compressor(input_tensor, pos_tensor, image_ids) +print(output_tensor.shape) diff --git a/longstream/utils/vendor/croco/models/pos_embed.py b/longstream/utils/vendor/croco/models/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..5eba9bbd5046f68c39332c72360b182d8f9a23a2 --- /dev/null +++ b/longstream/utils/vendor/croco/models/pos_embed.py @@ -0,0 +1,150 @@ +import numpy as np +import torch + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token > 0: + pos_embed = np.concatenate( + [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + + new_size = int(num_patches ** 0.5) + + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +try: + from longstream.utils.vendor.croco.models.curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + print( + "Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead" + ) + + class RoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, D, 2).float().to(device) / D) + ) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert ( + tokens.size(3) % 2 == 0 + ), "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 + cos, sin = self.get_cos_sin( + D, int(positions.max()) + 1, tokens.device, tokens.dtype + ) + + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens diff --git a/longstream/utils/vendor/croco/utils/misc.py b/longstream/utils/vendor/croco/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cc82352e4b6e2e56405dcb9efe752a636f5f3376 --- /dev/null +++ b/longstream/utils/vendor/croco/utils/misc.py @@ -0,0 +1,506 @@ +import builtins +import datetime +import json +import math +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +from torch import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None, max_iter=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable) + space_fmt = ":" + str(len(str(len_iterable))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for it, obj in enumerate(iterable): + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len_iterable - 1: + eta_seconds = iter_time.global_avg * (len_iterable - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len_iterable, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if max_iter and it >= max_iter: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len_iterable + ) + ) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + nodist = args.nodist if hasattr(args, "nodist") else False + if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + else: + print("Not using distributed mode") + setup_for_distributed(is_master=True) + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}): {}, gpu {}".format( + args.rank, args.dist_url, args.gpu + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self, enabled=True): + self._scaler = torch.cuda.amp.GradScaler(enabled=enabled) + + def __call__( + self, + loss, + optimizer, + clip_grad=None, + parameters=None, + create_graph=False, + update_grad=True, + ): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + + norm = None + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.0) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm( + torch.stack( + [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] + ), + norm_type, + ) + return total_norm + + +def save_model( + args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None +): + output_dir = Path(args.output_dir) + if fname is None: + fname = str(epoch) + checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname) + to_save = { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "scaler": loss_scaler.state_dict(), + "args": args, + "epoch": epoch, + } + if best_so_far is not None: + to_save["best_so_far"] = best_so_far + print(f">> Saving model to {checkpoint_path} ...") + save_on_master(to_save, checkpoint_path) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + args.start_epoch = 0 + best_so_far = None + if args.resume is not None: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + print("Resume checkpoint %s" % args.resume) + model_without_ddp.load_state_dict(checkpoint["model"], strict=False) + args.start_epoch = checkpoint["epoch"] + 1 + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scaler" in checkpoint: + loss_scaler.load_state_dict(checkpoint["scaler"]) + if "best_so_far" in checkpoint: + best_so_far = checkpoint["best_so_far"] + print(" & best_so_far={:g}".format(best_so_far)) + else: + print("") + print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end="") + return best_so_far + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + + +def _replace(text, src, tgt, rm=""): + """Advanced string replacement. + Given a text: + - replace all elements in src by the corresponding element in tgt + - remove all elements in rm + """ + if len(tgt) == 1: + tgt = tgt * len(src) + assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len" + for s, t in zip(src, tgt): + text = text.replace(s, t) + for c in rm: + text = text.replace(c, "") + return text + + +def filename(obj): + """transform a python obj or cmd into a proper filename. + - \1 gets replaced by slash '/' + - \2 gets replaced by comma ',' + """ + if not isinstance(obj, str): + obj = repr(obj) + obj = str(obj).replace("()", "") + obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"") + assert all(len(s) < 256 for s in obj.split(os.sep)), ( + "filename too long (>256 characters):\n" + obj + ) + return obj + + +def _get_num_layer_for_vit(var_name, enc_depth, dec_depth): + if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("enc_blocks"): + layer_id = int(var_name.split(".")[1]) + return layer_id + 1 + elif var_name.startswith("decoder_embed") or var_name.startswith("enc_norm"): + return enc_depth + elif var_name.startswith("dec_blocks"): + layer_id = int(var_name.split(".")[1]) + return enc_depth + layer_id + 1 + elif var_name.startswith("dec_norm"): + return enc_depth + dec_depth + elif any(var_name.startswith(k) for k in ["head", "prediction_head"]): + return enc_depth + dec_depth + 1 + else: + raise NotImplementedError(var_name) + + +def get_parameter_groups( + model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[] +): + parameter_group_names = {} + parameter_group_vars = {} + enc_depth, dec_depth = None, None + + assert layer_decay == 1.0 or 0.0 < layer_decay < 1.0 + if layer_decay < 1.0: + enc_depth = model.enc_depth + dec_depth = model.dec_depth if hasattr(model, "dec_blocks") else 0 + num_layers = enc_depth + dec_depth + layer_decay_values = list( + layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2) + ) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + group_name = "no_decay" + this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = weight_decay + + if layer_decay < 1.0: + skip_scale = False + layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth) + group_name = "layer_%d_%s" % (layer_id, group_name) + if name in no_lr_scale_list: + skip_scale = True + group_name = f"{group_name}_no_lr_scale" + else: + layer_id = 0 + skip_scale = True + + if group_name not in parameter_group_names: + if not skip_scale: + scale = layer_decay_values[layer_id] + else: + scale = 1.0 + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()) + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( + 1.0 + + math.cos( + math.pi + * (epoch - args.warmup_epochs) + / (args.epochs - args.warmup_epochs) + ) + ) + + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + + return lr diff --git a/longstream/utils/vendor/dust3r/__init__.py b/longstream/utils/vendor/dust3r/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/dust3r/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/dust3r/heads/__init__.py b/longstream/utils/vendor/dust3r/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3cb0150a84edd819c5aa744755ef074685d1b0 --- /dev/null +++ b/longstream/utils/vendor/dust3r/heads/__init__.py @@ -0,0 +1,12 @@ +from .dpt_head import create_dpt_head +from .linear_head import LinearPts3d + + +def head_factory(head_type, output_mode, net, has_conf=False): + """ " build a prediction head for the decoder""" + if head_type == "linear" and output_mode == "pts3d": + return LinearPts3d(net, has_conf) + elif head_type == "dpt" and output_mode == "pts3d": + return create_dpt_head(net, has_conf=has_conf) + else: + raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") diff --git a/longstream/utils/vendor/dust3r/heads/dpt_head.py b/longstream/utils/vendor/dust3r/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4682cd1590e3d2ebfb2f1ad615d469c6d0f9e3 --- /dev/null +++ b/longstream/utils/vendor/dust3r/heads/dpt_head.py @@ -0,0 +1,224 @@ +from typing import List + +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +from longstream.utils.vendor.croco.models.dpt_block import DPTOutputAdapter +from longstream.utils.vendor.croco.models.blocks import Mlp + +import longstream.utils.vendor.dust3r.utils.path_to_croco +from longstream.utils.vendor.dust3r.heads.postprocess import postprocess + + +class DPTOutputAdapter_fix(DPTOutputAdapter): + """ + Adapt croco's DPTOutputAdapter implementation for dust3r: + remove duplicated weigths, and fix forward for dust3r + """ + + def init(self, dim_tokens_enc=768): + super().init(dim_tokens_enc) + + del self.act_1_postprocess + del self.act_2_postprocess + del self.act_3_postprocess + del self.act_4_postprocess + + def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): + assert ( + self.dim_tokens_enc is not None + ), "Need to call init(dim_tokens_enc) function first" + + image_size = self.image_size if image_size is None else image_size + H, W = image_size + + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + layers = [encoder_tokens[hook] for hook in self.hooks] + + layers = [self.adapt_tokens(l) for l in layers] + + layers = [ + rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers + ] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + path_4 = self.scratch.refinenet4(layers[3])[ + :, :, : layers[2].shape[2], : layers[2].shape[3] + ] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + if self.training: + max_chunk_size = 1 + else: + max_chunk_size = 50 + chunks = torch.split(path_1, max_chunk_size, dim=0) + outputs = [] + + for chunk in chunks: + out_chunk = self.head(chunk) + outputs.append(out_chunk) + + out = torch.cat(outputs, dim=0) + return out + + +class PixelwiseTaskWithDPT(nn.Module): + """DPT module for dust3r, can return 3D points + confidence for all pixels""" + + def __init__( + self, + *, + n_cls_token=0, + hooks_idx=None, + dim_tokens=None, + output_width_ratio=1, + num_channels=1, + postprocess=None, + depth_mode=None, + conf_mode=None, + vis_mode=None, + **kwargs + ): + super(PixelwiseTaskWithDPT, self).__init__() + self.return_all_layers = True + self.postprocess = postprocess + self.depth_mode = depth_mode + self.conf_mode = conf_mode + self.vis_mode = vis_mode + + assert n_cls_token == 0, "Not implemented" + dpt_args = dict( + output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs + ) + if hooks_idx is not None: + dpt_args.update(hooks=hooks_idx) + self.dpt = DPTOutputAdapter_fix(**dpt_args) + dpt_init_args = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens} + self.dpt.init(**dpt_init_args) + + def forward(self, x, img_info): + out = self.dpt(x, image_size=(img_info[0], img_info[1])) + if self.postprocess: + out = self.postprocess(out, self.depth_mode, self.conf_mode, self.vis_mode) + return out + + +class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT): + """Mixture between MLP and DPT head that outputs 3d points and local features (with MLP). + The input for both heads is a concatenation of Encoder and Decoder outputs + """ + + def __init__( + self, + desc_head_args, + local_feat_dim=16, + hidden_dim_factor=4.0, + hooks_idx=None, + dim_tokens=None, + num_channels=1, + postprocess=None, + feature_dim=256, + last_dim=32, + depth_mode=None, + conf_mode=None, + head_type="regression", + **kwargs + ): + super().__init__( + num_channels=num_channels, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=hooks_idx, + dim_tokens=dim_tokens, + depth_mode=depth_mode, + postprocess=postprocess, + conf_mode=conf_mode, + head_type=head_type, + ) + self.local_feat_dim = local_feat_dim + + patch_size = desc_head_args["patch_size"] + if isinstance(patch_size, tuple): + assert ( + len(patch_size) == 2 + and isinstance(patch_size[0], int) + and isinstance(patch_size[1], int) + ), "What is your patchsize format? Expected a single int or a tuple of two ints." + assert ( + patch_size[0] == patch_size[1] + ), "Error, non square patches not managed" + patch_size = patch_size[0] + self.patch_size = patch_size + + self.desc_mode = desc_head_args["desc_mode"] + self.two_confs = desc_head_args["two_confs"] + self.desc_conf_mode = desc_head_args["desc_conf_mode"] + idim = desc_head_args["enc_embed_dim"] + desc_head_args["dec_embed_dim"] + + self.features_head = Mlp( + in_features=idim, + hidden_features=int(hidden_dim_factor * idim), + out_features=(self.local_feat_dim + self.two_confs) * self.patch_size ** 2, + ) + + def forward(self, decout, img_shape): + + pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1])) + + enc_output, dec_output = decout[0], decout[-1] + cat_output = torch.cat([enc_output, dec_output], dim=-1) + H, W = img_shape + B, S, D = cat_output.shape + + local_features = self.features_head(cat_output) + local_features = local_features.transpose(-1, -2).view( + B, -1, H // self.patch_size, W // self.patch_size + ) + local_features = F.pixel_shuffle(local_features, self.patch_size) + + out = torch.cat([pts3d, local_features], dim=1) + if self.postprocess: + out = self.postprocess( + out, + depth_mode=self.depth_mode, + conf_mode=self.conf_mode, + desc_dim=self.local_feat_dim, + desc_mode=self.desc_mode, + two_confs=self.two_confs, + desc_conf_mode=self.desc_conf_mode, + ) + + return out + + +def create_dpt_head(net, has_conf=False): + """ + return PixelwiseTaskWithDPT for given net params + """ + assert net.dec_depth > 9 + l2 = net.dec_depth + feature_dim = 256 + last_dim = feature_dim // 2 + out_nchan = 3 + ed = net.enc_embed_dim + dd = net.dec_embed_dim + return PixelwiseTaskWithDPT( + num_channels=out_nchan + has_conf, + feature_dim=feature_dim, + last_dim=last_dim, + hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], + dim_tokens=[ed, dd, dd, dd], + postprocess=postprocess, + depth_mode=net.depth_mode, + conf_mode=net.conf_mode, + head_type="regression", + ) diff --git a/longstream/utils/vendor/dust3r/heads/linear_head.py b/longstream/utils/vendor/dust3r/heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ea9dfcc5a54c22bc2019bcdcf1ff9118414b87 --- /dev/null +++ b/longstream/utils/vendor/dust3r/heads/linear_head.py @@ -0,0 +1,36 @@ +import torch.nn as nn +import torch.nn.functional as F + +from longstream.utils.vendor.dust3r.heads.postprocess import postprocess + + +class LinearPts3d(nn.Module): + """ + Linear head for dust3r + Each token outputs: - 16x16 3D points (+ confidence) + """ + + def __init__(self, net, has_conf=False): + super().__init__() + self.patch_size = net.patch_embed.patch_size[0] + self.depth_mode = net.depth_mode + self.conf_mode = net.conf_mode + self.has_conf = has_conf + + self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf) * self.patch_size ** 2) + + def setup(self, croconet): + pass + + def forward(self, decout, img_shape): + H, W = img_shape + tokens = decout[-1] + B, S, D = tokens.shape + + feat = self.proj(tokens) + feat = feat.transpose(-1, -2).view( + B, -1, H // self.patch_size, W // self.patch_size + ) + feat = F.pixel_shuffle(feat, self.patch_size) + + return postprocess(feat, self.depth_mode, self.conf_mode) diff --git a/longstream/utils/vendor/dust3r/heads/postprocess.py b/longstream/utils/vendor/dust3r/heads/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e998f218bcc15eaf84227cb4d44527be33dd7794 --- /dev/null +++ b/longstream/utils/vendor/dust3r/heads/postprocess.py @@ -0,0 +1,92 @@ +import torch + + +def reg_desc(desc, mode): + if "norm" in mode: + desc = desc / desc.norm(dim=-1, keepdim=True) + else: + raise ValueError(f"Unknown desc mode {mode}") + return desc + + +def postprocess_with_feature( + out, + depth_mode, + conf_mode, + desc_dim=None, + desc_mode="norm", + two_confs=False, + desc_conf_mode=None, +): + if desc_conf_mode is None: + desc_conf_mode = conf_mode + fmap = out.permute(0, 2, 3, 1) + res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode)) + if conf_mode is not None: + res["conf"] = reg_dense_conf(fmap[..., 3], mode=conf_mode) + if desc_dim is not None: + start = 3 + int(conf_mode is not None) + res["desc"] = reg_desc(fmap[..., start : start + desc_dim], mode=desc_mode) + if two_confs: + res["desc_conf"] = reg_dense_conf( + fmap[..., start + desc_dim], mode=desc_conf_mode + ) + else: + res["desc_conf"] = res["conf"].clone() + return res + + +def postprocess(out, depth_mode, conf_mode, vis_mode): + """ + extract 3D points/confidence from prediction head output + """ + fmap = out.permute(0, 2, 3, 1) + res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) + + if conf_mode is not None: + res["conf"] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) + + if vis_mode is not None: + res["vis"] = reg_dense_conf(fmap[:, :, :, 3:4], mode=vis_mode) + + return res + + +def reg_dense_depth(xyz, mode): + """ + extract 3D points from prediction head output + """ + mode, vmin, vmax = mode + + no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + assert no_bounds + + if mode == "linear": + if no_bounds: + return xyz + return xyz.clip(min=vmin, max=vmax) + + d = xyz.norm(dim=-1, keepdim=True) + xyz = xyz / d.clip(min=1e-8) + + if mode == "square": + return xyz * d.square() + + if mode == "exp": + return xyz * torch.expm1(d) + + raise ValueError(f"bad {mode=}") + + +def reg_dense_conf(x, mode): + """ + extract confidence from prediction head output + """ + mode, vmin, vmax = mode + if mode == "exp": + return vmin + x.exp().clip(max=vmax - vmin) + if mode == "sigmoid": + return (vmax - vmin) * torch.sigmoid(x) + vmin + if mode == "none": + return x + raise ValueError(f"bad {mode=}") diff --git a/longstream/utils/vendor/dust3r/model.py b/longstream/utils/vendor/dust3r/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2f71033d3269096d6b3e065e278fcf9514a8f75b --- /dev/null +++ b/longstream/utils/vendor/dust3r/model.py @@ -0,0 +1,604 @@ +import os +from copy import deepcopy + +import huggingface_hub +import torch +import torch.distributed +import torch.nn as nn +import numpy as np +from longstream.utils.vendor.croco.models.croco import CroCoNet +from longstream.utils.vendor.croco.models.blocks import Block +from longstream.utils.vendor.croco.models.pos_embed import ( + get_1d_sincos_pos_embed_from_grid, +) +from packaging import version + +from longstream.utils.vendor.dust3r.patch_embed import get_patch_embed + +from .heads import head_factory +from .utils.misc import ( + fill_default_args, + freeze_all_params, + interleave, + is_symmetrized, + transpose_to_landscape, +) +import torch.autograd.profiler as profiler + +inf = float("inf") + +hf_version_number = huggingface_hub.__version__ +assert version.parse(hf_version_number) >= version.parse( + "0.22.0" +), "Outdated huggingface_hub version, please reinstall requirements.txt" + + +def load_model(model_path, device, verbose=True): + if verbose: + print("... loading model from", model_path) + ckpt = torch.load(model_path, map_location="cpu") + args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") + if "landscape_only" not in args: + args = args[:-1] + ", landscape_only=False)" + else: + args = args.replace(" ", "").replace( + "landscape_only=True", "landscape_only=False" + ) + assert "landscape_only=False" in args + if verbose: + print(f"instantiating : {args}") + net = eval(args) + s = net.load_state_dict(ckpt["model"], strict=False) + if verbose: + print(s) + return net.to(device) + + +class AsymmetricCroCo3DStereo( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """Two siamese encoders, followed by two decoders. + The goal is to output 3d points directly, both images in view1's frame + (hence the asymmetry). + """ + + def __init__( + self, + output_mode="pts3d", + head_type="linear", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + freeze="none", + landscape_only=True, + patch_embed_cls="PatchEmbedDust3R", + **croco_kwargs, + ): + self.patch_embed_cls = patch_embed_cls + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + super().__init__(**croco_kwargs) + + self.dec_blocks2 = deepcopy(self.dec_blocks) + self.set_downstream_head( + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + **croco_kwargs, + ) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device="cpu") + else: + return super(AsymmetricCroCo3DStereo, cls).from_pretrained( + pretrained_model_name_or_path, **kw + ) + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed( + self.patch_embed_cls, img_size, patch_size, enc_embed_dim + ) + + def load_state_dict(self, ckpt, **kw): + + new_ckpt = dict(ckpt) + if not any(k.startswith("dec_blocks2") for k in ckpt): + for key, value in ckpt.items(): + if key.startswith("dec_blocks"): + new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value + return super().load_state_dict(new_ckpt, **kw) + + def set_freeze(self, freeze): + self.freeze = freeze + to_be_frozen = { + "none": [], + "mask": [self.mask_token], + "encoder": [self.mask_token, self.patch_embed, self.enc_blocks], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + patch_size, + img_size, + **kw, + ): + assert ( + img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0 + ), f"{img_size=} must be multiple of {patch_size=}" + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + + self.downstream_head1 = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + self.downstream_head2 = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + + self.head1 = transpose_to_landscape( + self.downstream_head1, activate=landscape_only + ) + self.head2 = transpose_to_landscape( + self.downstream_head2, activate=landscape_only + ) + + def _encode_image(self, image, true_shape): + + x, pos = self.patch_embed(image, true_shape=true_shape) + + assert self.enc_pos_embed is None + + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos, None + + def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2): + if img1.shape[-2:] == img2.shape[-2:]: + out, pos, _ = self._encode_image( + torch.cat((img1, img2), dim=0), + torch.cat((true_shape1, true_shape2), dim=0), + ) + out, out2 = out.chunk(2, dim=0) + pos, pos2 = pos.chunk(2, dim=0) + else: + out, pos, _ = self._encode_image(img1, true_shape1) + out2, pos2, _ = self._encode_image(img2, true_shape2) + return out, out2, pos, pos2 + + def _encode_symmetrized(self, view1, view2): + img1 = view1["img"] + img2 = view2["img"] + B = img1.shape[0] + + shape1 = view1.get( + "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1) + ) + shape2 = view2.get( + "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1) + ) + + if is_symmetrized(view1, view2): + + feat1, feat2, pos1, pos2 = self._encode_image_pairs( + img1[::2], img2[::2], shape1[::2], shape2[::2] + ) + feat1, feat2 = interleave(feat1, feat2) + pos1, pos2 = interleave(pos1, pos2) + else: + feat1, feat2, pos1, pos2 = self._encode_image_pairs( + img1, img2, shape1, shape2 + ) + + return (shape1, shape2), (feat1, feat2), (pos1, pos2) + + def _decoder(self, f1, pos1, f2, pos2): + final_output = [(f1, f2)] + + f1 = self.decoder_embed(f1) + f2 = self.decoder_embed(f2) + + final_output.append((f1, f2)) + for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2): + + f1, _ = blk1(*final_output[-1][::+1], pos1, pos2) + + f2, _ = blk2(*final_output[-1][::-1], pos2, pos1) + + final_output.append((f1, f2)) + + del final_output[1] + final_output[-1] = tuple(map(self.dec_norm, final_output[-1])) + return zip(*final_output) + + def _downstream_head(self, head_num, decout, img_shape): + B, S, D = decout[-1].shape + + head = getattr(self, f"head{head_num}") + return head(decout, img_shape) + + def forward(self, view1, view2): + + (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized( + view1, view2 + ) + + dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2) + + with torch.cuda.amp.autocast(enabled=False): + res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) + res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2) + + res2["pts3d_in_other_view"] = res2.pop("pts3d") + return res1, res2 + + +class FlashDUSt3R( + CroCoNet, + huggingface_hub.PyTorchModelHubMixin, + library_name="dust3r", + repo_url="https://github.com/naver/dust3r", + tags=["image-to-3d"], +): + """Two siamese encoders, followed by a single large decoder. + The goal is to output 3d points directly, processing multiple views. + """ + + def __init__( + self, + output_mode="pts3d", + head_type="linear", + depth_mode=("exp", -inf, inf), + conf_mode=("exp", 1, inf), + freeze="none", + landscape_only=True, + patch_embed_cls="PatchEmbedDust3R", + decoder_pos_embed_type="sinusoidal", + attn_implementation="pytorch_naive", + random_image_idx_embedding=False, + **croco_kwargs, + ): + self.patch_embed_cls = patch_embed_cls + self.random_image_idx_embedding = random_image_idx_embedding + self.croco_args = fill_default_args(croco_kwargs, super().__init__) + croco_kwargs["attn_implementation"] = attn_implementation + super().__init__(**croco_kwargs) + + self.register_buffer( + "image_idx_emb", + torch.from_numpy( + get_1d_sincos_pos_embed_from_grid(self.dec_embed_dim, np.arange(1000)) + ).float(), + persistent=False, + ) + + del self.dec_blocks + torch.cuda.empty_cache() + + self.decoder_pos_embed_type = decoder_pos_embed_type + self.multiview_dec_blocks = nn.ModuleList( + [ + Block( + dim=self.dec_embed_dim, + num_heads=8, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + norm_layer=nn.LayerNorm, + attn_implementation=attn_implementation, + ) + for _ in range(12) + ] + ) + self.set_downstream_head( + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + **croco_kwargs, + ) + self.set_freeze(freeze) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kw): + if os.path.isfile(pretrained_model_name_or_path): + return load_model(pretrained_model_name_or_path, device="cpu") + else: + return super(FlashDUSt3R, cls).from_pretrained( + pretrained_model_name_or_path, **kw + ) + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + self.patch_embed = get_patch_embed( + self.patch_embed_cls, img_size, patch_size, enc_embed_dim + ) + + def load_state_dict(self, ckpt, **kw): + return super().load_state_dict(ckpt, **kw) + + def set_freeze(self, freeze): + self.freeze = freeze + to_be_frozen = { + "none": [], + "mask": [self.mask_token], + "encoder": [self.mask_token, self.patch_embed, self.enc_blocks], + "sandwich": [ + self.mask_token, + self.patch_embed, + self.enc_blocks, + self.downstream_head, + ], + } + freeze_all_params(to_be_frozen[freeze]) + + def _set_prediction_head(self, *args, **kwargs): + """No prediction head""" + return + + def set_downstream_head( + self, + output_mode, + head_type, + landscape_only, + depth_mode, + conf_mode, + patch_size, + img_size, + **kw, + ): + assert ( + img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0 + ), f"{img_size=} must be multiple of {patch_size=}" + self.output_mode = output_mode + self.head_type = head_type + self.depth_mode = depth_mode + self.conf_mode = conf_mode + + self.downstream_head = head_factory( + head_type, output_mode, self, has_conf=bool(conf_mode) + ) + + self.head = transpose_to_landscape( + self.downstream_head, activate=landscape_only + ) + + def _encode_image(self, image, true_shape): + + x, pos = self.patch_embed(image, true_shape=true_shape) + + assert self.enc_pos_embed is None + + for blk in self.enc_blocks: + x = blk(x, pos) + + x = self.enc_norm(x) + return x, pos + + def _encode_images(self, views): + B = views[0]["img"].shape[0] + encoded_feats, positions, shapes = [], [], [] + + for view in views: + img = view["img"] + true_shape = view.get( + "true_shape", torch.tensor(img.shape[-2:])[None].repeat(B, 1) + ) + feat, pos = self._encode_image(img, true_shape) + encoded_feats.append(feat) + positions.append(pos) + shapes.append(true_shape) + + return encoded_feats, positions, shapes + + def _generate_per_rank_generator(self): + + per_forward_pass_seed = torch.randint(0, 2 ** 32, (1,)).item() + world_rank = ( + torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + ) + per_rank_seed = per_forward_pass_seed + world_rank + + per_rank_generator = torch.Generator() + per_rank_generator.manual_seed(per_rank_seed) + return per_rank_generator + + def _get_random_image_pos( + self, encoded_feats, batch_size, num_views, max_image_idx, device + ): + """ + Generates non-repeating random image indices for each sample, retrieves corresponding + positional embeddings for each view, and concatenates them. + + Args: + encoded_feats (list of tensors): Encoded features for each view. + batch_size (int): Number of samples in the batch. + num_views (int): Number of views per sample. + max_image_idx (int): Maximum image index for embedding. + device (torch.device): Device to move data to. + + Returns: + Tensor: Concatenated positional embeddings for the entire batch. + """ + + image_ids = torch.zeros(batch_size, num_views, dtype=torch.long) + + image_ids[:, 0] = 0 + + per_rank_generator = self._generate_per_rank_generator() + + for b in range(batch_size): + + random_ids = ( + torch.randperm(max_image_idx, generator=per_rank_generator)[ + : num_views - 1 + ] + + 1 + ) + image_ids[b, 1:] = random_ids + + image_ids = image_ids.to(device) + + image_pos_list = [] + + for i in range(num_views): + + num_patches = encoded_feats[i].shape[1] + + image_pos_for_view = self.image_idx_emb[image_ids[:, i]] + + image_pos_for_view = image_pos_for_view.unsqueeze(1).repeat( + 1, num_patches, 1 + ) + + image_pos_list.append(image_pos_for_view) + + image_pos = torch.cat(image_pos_list, dim=1) + + return image_pos + + def _decoder(self, encoded_feats, positions, image_ids): + x = torch.cat(encoded_feats, dim=1) + pos = torch.cat(positions, dim=1) + + final_output = [x] + + x = self.decoder_embed(x) + + if self.random_image_idx_embedding: + + image_pos = self._get_random_image_pos( + encoded_feats=encoded_feats, + batch_size=encoded_feats[0].shape[0], + num_views=len(encoded_feats), + max_image_idx=self.image_idx_emb.shape[0] - 1, + device=x.device, + ) + else: + + num_images = (torch.max(image_ids) + 1).cpu().item() + image_idx_emb = self.image_idx_emb[:num_images] + image_pos = image_idx_emb[image_ids] + + x += image_pos + + for blk in self.multiview_dec_blocks: + x = blk(x, pos) + final_output.append(x) + + x = self.dec_norm(x) + final_output[-1] = x + return final_output + + def forward(self, views): + """ + Args: + views (list[dict]): a list of views, each view is a dict of tensors, the tensors are batched + + Returns: + list[dict]: a list of results for each view + """ + + encoded_feats, positions, shapes = self._encode_images(views) + + num_images = len(views) + B, _, _ = encoded_feats[0].shape + + different_resolution_across_views = not all( + encoded_feats[0].shape[1] == encoded_feat.shape[1] + for encoded_feat in encoded_feats + ) + + image_ids = [] + + for i, encoded_feat in enumerate(encoded_feats): + num_patches = encoded_feat.shape[1] + + image_ids.extend([i] * num_patches) + + image_ids = ( + torch.tensor(image_ids * B).reshape(B, -1).to(encoded_feats[0].device) + ) + + dec_output = self._decoder(encoded_feats, positions, image_ids) + + final_results = [{} for _ in range(num_images)] + + with profiler.record_function("head: gathered outputs"): + + gathered_outputs_list = [] + if different_resolution_across_views: + for img_id in range(num_images): + gathered_outputs_per_view = [] + for layer_output in dec_output: + B, P, D = layer_output.shape + mask = image_ids == img_id + gathered_output = layer_output[mask].view(B, -1, D) + gathered_outputs_per_view.append(gathered_output) + gathered_outputs_list.append(gathered_outputs_per_view) + else: + for layer_output in dec_output: + B, P, D = layer_output.shape + gathered_outputs_per_view = [] + for img_id in range(num_images): + mask = image_ids == img_id + gathered_output = layer_output[mask].view(B, -1, D) + gathered_outputs_per_view.append(gathered_output) + gathered_outputs_list.append( + torch.cat(gathered_outputs_per_view, dim=0) + ) + + with profiler.record_function("head: forward pass"): + if different_resolution_across_views: + + final_results = [{} for _ in range(num_images)] + for img_id in range(num_images): + img_result = self.head( + gathered_outputs_list[img_id], shapes[img_id] + ) + + for key in img_result.keys(): + if key == "pts3d": + final_results[img_id]["pts3d_in_other_view"] = img_result[ + key + ] + else: + final_results[img_id][key] = img_result[key] + else: + + concatenated_shapes = torch.cat(shapes, dim=0) + + result = self.head(gathered_outputs_list, concatenated_shapes) + + final_results = [{} for _ in range(num_images)] + + for key in result.keys(): + for img_id in range(num_images): + img_result = result[key][img_id * B : (img_id + 1) * B] + if key == "pts3d": + final_results[img_id]["pts3d_in_other_view"] = img_result + else: + final_results[img_id][key] = img_result + + return final_results diff --git a/longstream/utils/vendor/dust3r/optim_factory.py b/longstream/utils/vendor/dust3r/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6ebe532867be5f7d5c7c7b67075df0120989a2 --- /dev/null +++ b/longstream/utils/vendor/dust3r/optim_factory.py @@ -0,0 +1,6 @@ +def adjust_learning_rate_by_lr(optimizer, lr): + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr diff --git a/longstream/utils/vendor/dust3r/patch_embed.py b/longstream/utils/vendor/dust3r/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..913ffb9f14f80129a71900180b4cfa3d4e7b57b2 --- /dev/null +++ b/longstream/utils/vendor/dust3r/patch_embed.py @@ -0,0 +1,117 @@ +import torch +from longstream.utils.vendor.croco.models.blocks import PatchEmbed + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ["PatchEmbedDust3R", "ManyAR_PatchEmbed"] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PatchEmbedDust3R(PatchEmbed): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed(PatchEmbed): + """Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + ): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) + + def forward(self, img, true_shape): + + if not self.training: + x = img + B, C, H, W = x.shape + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x, pos + + B, C, H, W = img.shape + + assert W >= H, f"img should be in landscape mode, but got {W=} {H=}" + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + assert true_shape.shape == ( + B, + 2, + ), f"true_shape has the wrong shape={true_shape.shape}" + + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + if is_landscape.any(): + new_landscape_content = self.proj(img[is_landscape]) + new_landscape_content = new_landscape_content.permute(0, 2, 3, 1).flatten( + 1, 2 + ) + if is_portrait.any(): + new_protrait_content = self.proj(img[is_portrait].swapaxes(-1, -2)) + new_protrait_content = new_protrait_content.permute(0, 2, 3, 1).flatten( + 1, 2 + ) + + x = img.new_empty( + (B, n_tokens, self.embed_dim), dtype=next(self.named_parameters())[1].dtype + ) + + if is_landscape.any(): + x[is_landscape] = new_landscape_content.to(x.dtype) + if is_portrait.any(): + x[is_portrait] = new_protrait_content.to(x.dtype) + + pos = img.new_empty((B, n_tokens, 2), dtype=torch.int64) + if is_landscape.any(): + pos[is_landscape] = self.position_getter(1, H, W, pos.device).expand( + is_landscape.sum(), -1, -1 + ) + if is_portrait.any(): + pos[is_portrait] = self.position_getter(1, W, H, pos.device).expand( + is_portrait.sum(), -1, -1 + ) + + x = self.norm(x) + return x, pos diff --git a/longstream/utils/vendor/dust3r/post_process.py b/longstream/utils/vendor/dust3r/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6f37cd57b88a78a34e31d76dce04b503250c0a --- /dev/null +++ b/longstream/utils/vendor/dust3r/post_process.py @@ -0,0 +1,114 @@ +import numpy as np +import torch + +from longstream.utils.vendor.dust3r.utils.geometry import xy_grid + + +def estimate_focal_knowing_depth( + pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf +): + """Reprojection method, for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) + pts3d = pts3d.flatten(1, 2) + + if focal_mode == "median": + with torch.no_grad(): + + u, v = pixels.unbind(dim=-1) + x, y, z = pts3d.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) + focal = torch.nanmedian(f_votes, dim=-1).values + + elif focal_mode == "weiszfeld": + + xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) + + dot_xy_px = (xy_over_z * pixels).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) + + for iter in range(10): + + dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) + + w = dis.clip(min=1e-8).reciprocal() + + focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) + else: + raise ValueError(f"bad {focal_mode=}") + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) + focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base) + + return focal + + +def estimate_focal_knowing_depth_and_confidence_mask( + pts3d, pp, conf_mask, focal_mode="median", min_focal=0.0, max_focal=np.inf +): + """Reprojection method for when the absolute depth is known: + 1) estimate the camera focal using a robust estimator + 2) reproject points onto true rays, minimizing a certain error + This function considers only points where conf_mask is True. + """ + B, H, W, THREE = pts3d.shape + assert THREE == 3 + + pixels = xy_grid(W, H, device=pts3d.device).view(1, H, W, 2) - pp.view(-1, 1, 1, 2) + + conf_mask = conf_mask.view(B, H, W) + valid_indices = conf_mask + + pts3d_valid = pts3d[valid_indices] + pixels_valid = pixels[valid_indices] + + if pts3d_valid.numel() == 0: + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) + return torch.tensor([focal_base]) + + if focal_mode == "median": + with torch.no_grad(): + + u, v = pixels_valid.unbind(dim=-1) + x, y, z = pts3d_valid.unbind(dim=-1) + fx_votes = (u * z) / x + fy_votes = (v * z) / y + + f_votes = torch.cat((fx_votes.view(-1), fy_votes.view(-1)), dim=-1) + focal = torch.nanmedian(f_votes).unsqueeze(0) + + elif focal_mode == "weiszfeld": + + xy_over_z = (pts3d_valid[..., :2] / pts3d_valid[..., 2:3]).nan_to_num( + posinf=0, neginf=0 + ) + + dot_xy_px = (xy_over_z * pixels_valid).sum(dim=-1) + dot_xy_xy = xy_over_z.square().sum(dim=-1) + + focal = dot_xy_px.mean() / dot_xy_xy.mean() + + for _ in range(100): + + dis = (pixels_valid - focal * xy_over_z).norm(dim=-1) + w = dis.clip(min=1e-8).reciprocal() + + focal = (w * dot_xy_px).sum() / (w * dot_xy_xy).sum() + focal = focal.unsqueeze(0) + else: + raise ValueError(f"bad focal_mode={focal_mode}") + + focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) + focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base) + return focal diff --git a/longstream/utils/vendor/dust3r/utils/__init__.py b/longstream/utils/vendor/dust3r/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/dust3r/utils/camera.py b/longstream/utils/vendor/dust3r/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..30bc6c926e51dcec5e22ef36258b6aa18e1df4ed --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/camera.py @@ -0,0 +1,203 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +inf = float("inf") + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + quaternions = F.normalize(quaternions, p=2, dim=-1) + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def camera_to_pose_encoding( + camera, + pose_encoding_type="absT_quaR", +): + """ + Inverse to pose_encoding_to_camera + camera: opencv, cam2world + """ + if pose_encoding_type == "absT_quaR": + + quaternion_R = matrix_to_quaternion(camera[:, :3, :3]) + + pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + return pose_encoding + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def pose_encoding_to_camera( + pose_encoding, + pose_encoding_type="absT_quaR", +): + """ + Args: + pose_encoding: A tensor of shape `BxC`, containing a batch of + `B` `C`-dimensional pose encodings. + pose_encoding_type: The type of pose encoding, + """ + + if pose_encoding_type == "absT_quaR": + + abs_T = pose_encoding[:, :3] + quaternion_R = pose_encoding[:, 3:7] + R = quaternion_to_matrix(quaternion_R) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device) + c2w_mats = c2w_mats[None].repeat(len(R), 1, 1) + c2w_mats[:, :3, :3] = R + c2w_mats[:, :3, 3] = abs_T + + return c2w_mats + + +def quaternion_conjugate(q): + """Compute the conjugate of quaternion q (w, x, y, z).""" + + q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) + return q_conj + + +def quaternion_multiply(q1, q2): + """Multiply two quaternions q1 and q2.""" + w1, x1, y1, z1 = q1.unbind(dim=-1) + w2, x2, y2, z2 = q2.unbind(dim=-1) + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + + return torch.stack((w, x, y, z), dim=-1) + + +def rotate_vector(q, v): + """Rotate vector v by quaternion q.""" + q_vec = q[..., 1:] + q_w = q[..., :1] + + t = 2.0 * torch.cross(q_vec, v, dim=-1) + v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1) + return v_rot + + +def relative_pose_absT_quatR(t1, q1, t2, q2): + """Compute the relative translation and quaternion between two poses.""" + + q1_inv = quaternion_conjugate(q1) + + q_rel = quaternion_multiply(q1_inv, q2) + + delta_t = t2 - t1 + t_rel = rotate_vector(q1_inv, delta_t) + return t_rel, q_rel diff --git a/longstream/utils/vendor/dust3r/utils/device.py b/longstream/utils/vendor/dust3r/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e15e0d3b6afa26523fc8f17ab5f0bbcc222292 --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/device.py @@ -0,0 +1,82 @@ +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). + + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice + + +def to_numpy(x): + return todevice(x, "numpy") + + +def to_cpu(x): + return todevice(x, "cpu") + + +def to_cuda(x): + return todevice(x, "cuda") + + +def collate_with_cat(whatever, lists=False): + if isinstance(whatever, dict): + return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} + + elif isinstance(whatever, (tuple, list)): + if len(whatever) == 0: + return whatever + elem = whatever[0] + T = type(whatever) + + if elem is None: + return None + if isinstance(elem, (bool, float, int, str)): + return whatever + if isinstance(elem, tuple): + return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) + if isinstance(elem, dict): + return { + k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem + } + + if isinstance(elem, torch.Tensor): + return listify(whatever) if lists else torch.cat(whatever) + if isinstance(elem, np.ndarray): + return ( + listify(whatever) + if lists + else torch.cat([torch.from_numpy(x) for x in whatever]) + ) + + return sum(whatever, T()) + + +def listify(elems): + return [x for e in elems for x in e] diff --git a/longstream/utils/vendor/dust3r/utils/geometry.py b/longstream/utils/vendor/dust3r/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2ef83fa59020905d61ee83de7bcfb58c0de64b --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/geometry.py @@ -0,0 +1,384 @@ +import numpy as np +import torch +from scipy.spatial import cKDTree as KDTree + +from longstream.utils.vendor.dust3r.utils.device import to_numpy +from longstream.utils.vendor.dust3r.utils.misc import invalid_to_nans, invalid_to_zeros + + +def xy_grid( + W, + H, + device=None, + origin=(0, 0), + unsqueeze=None, + cat_dim=-1, + homogeneous=False, + **arange_kw, +): + """Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + if ( + isinstance(Trf, torch.Tensor) + and isinstance(pts, torch.Tensor) + and Trf.ndim == 3 + and pts.ndim == 4 + ): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = ( + torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + + Trf[:, None, None, :d, d] + ) + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """Invert a torch or numpy matrix""" + if isinstance(mat, torch.Tensor): + + if mat.dtype == torch.bfloat16: + mat = mat.to(torch.float32) + mat = torch.linalg.inv(mat) + return mat + else: + mat = torch.linalg.inv(mat) + return mat + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + valid_mask = depthmap > 0.0 + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + X_world = ( + np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + ) + return X_world, valid_mask + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + return K + + +def normalize_pointcloud(pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None): + """renorm pointmaps pts1, pts2 with norm_mode""" + assert pts1.ndim >= 3 and pts1.shape[-1] == 3 + assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + if norm_mode == "avg": + + nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3) + nan_pts2, nnz2 = ( + invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0) + ) + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + elif dis_mode == "warp-log1p": + + log_dis = torch.log1p(all_dis) + warp_factor = log_dis / all_dis.clip(min=1e-8) + H1, W1 = pts1.shape[1:-1] + pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1) + if pts2 is not None: + H2, W2 = pts2.shape[1:-1] + pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1) + all_dis = log_dis + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8) + else: + + nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3) + nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None + all_pts = ( + torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1 + ) + + all_dis = all_pts.norm(dim=-1) + + if norm_mode == "avg": + norm_factor = all_dis.nanmean(dim=1) + elif norm_mode == "median": + norm_factor = all_dis.nanmedian(dim=1).values.detach() + elif norm_mode == "sqrt": + norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2 + else: + raise ValueError(f"bad {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts1.ndim: + norm_factor.unsqueeze_(-1) + + res = pts1 / norm_factor + if pts2 is not None: + res = (res, pts2 / norm_factor) + return res + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = ( + invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) + if z2 is not None + else None + ) + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + return shift_z + + +@torch.no_grad() +def get_joint_pointcloud_center_scale( + pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True +): + + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = ( + invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) + if pts2 is not None + else None + ) + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values + if z_only: + _center[..., :2] = 0 + + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)) + reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def get_med_dist_between_poses(poses): + from scipy.spatial.distance import pdist + + return np.median(pdist([to_numpy(p[:3, 3]) for p in poses])) diff --git a/longstream/utils/vendor/dust3r/utils/image.py b/longstream/utils/vendor/dust3r/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0becd51b881b36df47151ebef6525c8231059a --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/image.py @@ -0,0 +1,233 @@ +import os + +import numpy as np +import PIL.Image +import torch +import torchvision.transforms as tvf +from PIL.ImageOps import exif_transpose + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 + +try: + from pillow_heif import register_heif_opener + + register_heif_opener() + heif_support_enabled = True +except ImportError: + heif_support_enabled = False + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def rgb(ftensor, true_shape=None): + if isinstance(ftensor, list): + return [rgb(x, true_shape=true_shape) for x in ftensor] + if isinstance(ftensor, torch.Tensor): + ftensor = ftensor.detach().cpu().numpy() + if ftensor.ndim == 3 and ftensor.shape[0] == 3: + ftensor = ftensor.transpose(1, 2, 0) + elif ftensor.ndim == 4 and ftensor.shape[1] == 3: + ftensor = ftensor.transpose(0, 2, 3, 1) + if true_shape is not None: + H, W = true_shape + ftensor = ftensor[:H, :W] + if ftensor.dtype == np.uint8: + img = np.float32(ftensor) / 255 + else: + img = (ftensor * 0.5) + 0.5 + return img.clip(min=0, max=1) + + +def _resize_pil_image(img, long_edge_size): + S = max(img.size) + if S > long_edge_size: + interp = PIL.Image.LANCZOS + elif S <= long_edge_size: + interp = PIL.Image.BICUBIC + new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size) + return img.resize(new_size, interp) + + +def load_images( + folder_or_list, + size, + square_ok=False, + verbose=True, + rotate_clockwise_90=False, + crop_to_landscape=False, + patch_size=16, +): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for path in folder_content: + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + + if rotate_clockwise_90: + img = img.rotate(-90, expand=True) + + if crop_to_landscape: + + desired_aspect_ratio = 4 / 3 + width, height = img.size + current_aspect_ratio = width / height + + if current_aspect_ratio > desired_aspect_ratio: + + new_width = int(height * desired_aspect_ratio) + left = (width - new_width) // 2 + right = left + new_width + top = 0 + bottom = height + else: + + new_height = int(width / desired_aspect_ratio) + top = (height - new_height) // 2 + bottom = top + new_height + left = 0 + right = width + + img = img.crop((left, top, right, bottom)) + + W1, H1 = img.size + if size == 224: + + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: + + halfw, halfh = ((2 * cx) // patch_size) * patch_size // 2, ( + (2 * cy) // patch_size + ) * patch_size // 2 + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + + true_shape = [img.size[::-1]] + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32(true_shape), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs + + +def load_images_for_eval( + folder_or_list, size, square_ok=False, verbose=True, crop=True, patch_size=16 +): + """open and convert all images in a list or folder to proper input format for DUSt3R""" + if isinstance(folder_or_list, str): + if verbose: + print(f">> Loading images from {folder_or_list}") + root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) + + elif isinstance(folder_or_list, list): + if verbose: + print(f">> Loading a list of {len(folder_or_list)} images") + root, folder_content = "", folder_or_list + + else: + raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})") + + supported_images_extensions = [".jpg", ".jpeg", ".png"] + if heif_support_enabled: + supported_images_extensions += [".heic", ".heif"] + supported_images_extensions = tuple(supported_images_extensions) + + imgs = [] + for i, path in enumerate(folder_content): + if not path.lower().endswith(supported_images_extensions): + continue + img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB") + W1, H1 = img.size + if size == 224: + + img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1))) + else: + + img = _resize_pil_image(img, size) + W, H = img.size + cx, cy = W // 2, H // 2 + if size == 224: + half = min(cx, cy) + if crop: + img = img.crop((cx - half, cy - half, cx + half, cy + half)) + else: + img = img.resize((2 * half, 2 * half), PIL.Image.LANCZOS) + else: + halfw, halfh = ((2 * cx) // patch_size) * (patch_size // 2), ( + (2 * cy) // patch_size + ) * (patch_size // 2) + if not (square_ok) and W == H: + halfh = 3 * halfw / 4 + if crop: + img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh)) + else: + img = img.resize((2 * halfw, 2 * halfh), PIL.Image.LANCZOS) + W2, H2 = img.size + if verbose: + print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}") + + imgs.append( + dict( + img=ImgNorm(img)[None], + true_shape=np.int32([img.size[::-1]]), + idx=len(imgs), + instance=str(len(imgs)), + ) + ) + + assert imgs, "no images foud at " + root + if verbose: + print(f" (Found {len(imgs)} images)") + return imgs diff --git a/longstream/utils/vendor/dust3r/utils/misc.py b/longstream/utils/vendor/dust3r/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc665a8b773997352ba1002fdf0b34b8fd82d97 --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/misc.py @@ -0,0 +1,120 @@ +import torch + + +def fill_default_args(kwargs, func): + import inspect + + signature = inspect.signature(func) + + for k, v in signature.parameters.items(): + if v.default is inspect.Parameter.empty: + continue + kwargs.setdefault(k, v.default) + + return kwargs + + +def freeze_all_params(modules): + for module in modules: + try: + for n, param in module.named_parameters(): + param.requires_grad = False + except AttributeError: + + module.requires_grad = False + + +def is_symmetrized(gt1, gt2): + x = gt1["instance"] + y = gt2["instance"] + if len(x) == len(y) and len(x) == 1: + return False + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def flip(tensor): + """flip so that tensor[0::2] <=> tensor[1::2]""" + return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) + + +def interleave(tensor1, tensor2): + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def transpose_to_landscape(head, activate=True): + """Predict in the correct aspect-ratio, + then transpose the result in landscape + and stack everything back together. + """ + + def wrapper_no(decout, true_shape): + B = len(true_shape) + assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical" + H, W = true_shape[0].cpu().tolist() + res = head(decout, (H, W)) + return res + + def wrapper_yes(decout, true_shape): + + if not head.training: + return wrapper_no(decout, true_shape) + + B = len(true_shape) + + H, W = int(true_shape.min()), int(true_shape.max()) + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + if is_landscape.all(): + return head(decout, (H, W)) + if is_portrait.all(): + return transposed(head(decout, (W, H))) + + def selout(ar): + return [d[ar] for d in decout] + + l_result = head(selout(is_landscape), (H, W)) + p_result = transposed(head(selout(is_portrait), (W, H))) + + result = {} + for k in l_result | p_result: + x = l_result[k].new(B, *l_result[k].shape[1:]) + x[is_landscape] = l_result[k] + x[is_portrait] = p_result[k] + result[k] = x + + return result + + return wrapper_yes if activate else wrapper_no + + +def transposed(dic): + return {k: v.swapaxes(1, 2) for k, v in dic.items()} + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float("nan") + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz diff --git a/longstream/utils/vendor/dust3r/utils/parallel.py b/longstream/utils/vendor/dust3r/utils/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..38e39dbb967d892768d72c1a107d2f9843bd71af --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/parallel.py @@ -0,0 +1,81 @@ +from tqdm import tqdm +from multiprocessing.dummy import Pool as ThreadPool +from multiprocessing import cpu_count + + +def parallel_threads( + function, + args, + workers=0, + star_args=False, + kw_args=False, + front_num=1, + Pool=ThreadPool, + **tqdm_kw +): + """tqdm but with parallel execution. + + Will essentially return + res = [ function(arg) # default + function(*arg) # if star_args is True + function(**arg) # if kw_args is True + for arg in args] + + Note: + the first elements of args will not be parallelized. + This can be useful for debugging. + """ + while workers <= 0: + workers += cpu_count() + if workers == 1: + front_num = float("inf") + + try: + n_args_parallel = len(args) - front_num + except TypeError: + n_args_parallel = None + args = iter(args) + + front = [] + while len(front) < front_num: + try: + a = next(args) + except StopIteration: + return front + front.append( + function(*a) if star_args else function(**a) if kw_args else function(a) + ) + + out = [] + with Pool(workers) as pool: + + if star_args: + futures = pool.imap(starcall, [(function, a) for a in args]) + elif kw_args: + futures = pool.imap(starstarcall, [(function, a) for a in args]) + else: + futures = pool.imap(function, args) + + for f in tqdm(futures, total=n_args_parallel, **tqdm_kw): + out.append(f) + return front + out + + +def parallel_processes(*args, **kwargs): + """Same as parallel_threads, with processes""" + import multiprocessing as mp + + kwargs["Pool"] = mp.Pool + return parallel_threads(*args, **kwargs) + + +def starcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(*args) + + +def starstarcall(args): + """convenient wrapper for Process.Pool""" + function, args = args + return function(**args) diff --git a/longstream/utils/vendor/dust3r/utils/path_to_croco.py b/longstream/utils/vendor/dust3r/utils/path_to_croco.py new file mode 100644 index 0000000000000000000000000000000000000000..c657efb1c2d915e9d3b45b9ecccc4fcfbc90ddae --- /dev/null +++ b/longstream/utils/vendor/dust3r/utils/path_to_croco.py @@ -0,0 +1,15 @@ +import os.path as path +import sys + +HERE_PATH = path.normpath(path.dirname(__file__)) +CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, "../../croco")) +CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, "models") + +if path.isdir(CROCO_MODELS_PATH): + + sys.path.insert(0, CROCO_REPO_PATH) +else: + raise ImportError( + f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " + "Did you forget to run 'git submodule update --init --recursive' ?" + ) diff --git a/longstream/utils/vendor/losses/__init__.py b/longstream/utils/vendor/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c710569ac0e32343fa2e78c0d9d93104fa17fdf --- /dev/null +++ b/longstream/utils/vendor/losses/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + "feature_reprojection_energy", + "energy_drop_loss", +] + +from .reprojection import feature_reprojection_energy, energy_drop_loss diff --git a/longstream/utils/vendor/losses/reprojection.py b/longstream/utils/vendor/losses/reprojection.py new file mode 100644 index 0000000000000000000000000000000000000000..7d43dcac4adccd2e11fe9f12b0882d63638dcdb9 --- /dev/null +++ b/longstream/utils/vendor/losses/reprojection.py @@ -0,0 +1,101 @@ +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +def project_points( + points: torch.Tensor, K: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + z = points[..., 2].clamp(min=1e-6) + x = points[..., 0] / z + y = points[..., 1] / z + + fx = K[:, 0, 0].unsqueeze(-1) + fy = K[:, 1, 1].unsqueeze(-1) + cx = K[:, 0, 2].unsqueeze(-1) + cy = K[:, 1, 2].unsqueeze(-1) + + u = fx * x + cx + v = fy * y + cy + return torch.stack((u, v), dim=-1), z + + +def sample_features(features: torch.Tensor, uv: torch.Tensor) -> torch.Tensor: + """Bilinear sample feature maps at pixel-space coordinates.""" + B, C, H, W = features.shape + u, v = uv[..., 0], uv[..., 1] + grid_u = 2.0 * (u / (W - 1)) - 1.0 + grid_v = 2.0 * (v / (H - 1)) - 1.0 + grid = torch.stack((grid_u, grid_v), dim=-1).view(B, -1, 1, 2) + sampled = F.grid_sample( + features, + grid, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + return sampled.squeeze(-1).permute(0, 2, 1) + + +def feature_reprojection_energy( + feat1: torch.Tensor, + feat2: torch.Tensor, + depth1: torch.Tensor, + T1_to_2: torch.Tensor, + K1: torch.Tensor, + K2: torch.Tensor, + n_samples: int = 1024, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute a photometric/feature reprojection energy using randomly sampled pixels. + """ + B, C, H, W = feat1.shape + device = feat1.device + + ys = torch.randint(0, H, (B, n_samples), device=device) + xs = torch.randint(0, W, (B, n_samples), device=device) + flat_idx = ys * W + xs + + feat1_flat = feat1.view(B, C, -1).permute(0, 2, 1) + f1 = torch.gather(feat1_flat, 1, flat_idx.unsqueeze(-1).expand(-1, -1, C)) + + depth_flat = depth1.view(B, -1) + d1 = torch.gather(depth_flat, 1, flat_idx) + + u = xs.float() + v = ys.float() + fx = K1[:, 0, 0].unsqueeze(-1) + fy = K1[:, 1, 1].unsqueeze(-1) + cx = K1[:, 0, 2].unsqueeze(-1) + cy = K1[:, 1, 2].unsqueeze(-1) + + X = torch.stack(((u - cx) / fx * d1, (v - cy) / fy * d1, d1), dim=-1) + + R = T1_to_2[:, :3, :3] + t = T1_to_2[:, :3, 3] + X_transformed = (R @ X.transpose(1, 2)).transpose(1, 2) + t.unsqueeze(1) + + uv2, z2 = project_points(X_transformed, K2) + f2 = sample_features(feat2, uv2) + + diff = (f1 - f2).pow(2).sum(dim=-1) + + valid = z2 > 1e-6 + valid = valid & (uv2[..., 0] >= 0.0) & (uv2[..., 0] <= W - 1) + valid = valid & (uv2[..., 1] >= 0.0) & (uv2[..., 1] <= H - 1) + if mask is not None: + valid = valid & mask.bool() + + valid_f = valid.float() + energy = torch.sqrt(diff + 1e-4) * valid_f + denom = valid_f.sum(dim=-1).clamp(min=1.0) + energy = energy.sum(dim=-1) / denom + return energy.mean() + + +def energy_drop_loss( + energy_before: torch.Tensor, energy_after: torch.Tensor, margin: float = 0.0 +) -> torch.Tensor: + return torch.clamp(energy_after - energy_before + margin, min=0.0) diff --git a/longstream/utils/vendor/models/__init__.py b/longstream/utils/vendor/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/longstream/utils/vendor/models/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/longstream/utils/vendor/models/components/aggregator/streamaggregator.py b/longstream/utils/vendor/models/components/aggregator/streamaggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d2fb6f8775e0c4016be9c520f0e175b12b520c --- /dev/null +++ b/longstream/utils/vendor/models/components/aggregator/streamaggregator.py @@ -0,0 +1,857 @@ +import logging +import torch +import torch.nn as nn +from typing import Tuple, List, Optional, Union +from torch.utils.checkpoint import checkpoint + +from ..layers import PatchEmbed +from ..layers.block import Block +from ..layers.rope import ( + RotaryPositionEmbedding2D, + RotaryPositionEmbedding3D, + PositionGetter, +) +from ..layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class STreamAggregator(nn.Module): + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + use_role_embedding=True, + disable_keyframe_distinction=False, + keyframe_stride=8, + use_segment_mask=False, + use_3d_rope=False, + window_size=5000, + ): + super().__init__() + + self.__build_patch_embed__( + patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim + ) + + if rope_freq > 0: + if use_3d_rope: + self.rope = RotaryPositionEmbedding3D(frequency=rope_freq) + else: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + else: + self.rope = None + + self.position_getter = PositionGetter() if self.rope is not None else None + self.use_3d_rope = use_3d_rope + + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + if self.depth % self.aa_block_size != 0: + raise ValueError( + f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})" + ) + + self.aa_block_num = self.depth // self.aa_block_size + self.use_role_embedding = use_role_embedding + self.disable_keyframe_distinction = disable_keyframe_distinction + self.num_register_tokens = num_register_tokens + self.use_segment_mask = use_segment_mask + self.window_size = 50000 + + self.camera_token_norm = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.register_token_norm = nn.Parameter( + torch.randn(1, num_register_tokens, embed_dim) + ) + nn.init.normal_(self.camera_token_norm, std=1e-6) + nn.init.normal_(self.register_token_norm, std=1e-6) + + if not disable_keyframe_distinction: + self.camera_token_key = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.register_token_key = nn.Parameter( + torch.randn(1, num_register_tokens, embed_dim) + ) + nn.init.normal_(self.camera_token_key, std=1e-6) + nn.init.normal_(self.register_token_key, std=1e-6) + else: + + self.camera_token_key = None + self.register_token_key = None + + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter( + torch.randn(1, 2, num_register_tokens, embed_dim) + ) + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + if self.use_role_embedding: + self.role_embed_key = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.role_embed_norm = nn.Parameter(torch.randn(1, 1, embed_dim)) + nn.init.normal_(self.role_embed_key, std=0.02) + nn.init.normal_(self.role_embed_norm, std=0.02) + + self.patch_start_idx = 1 + num_register_tokens + + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.FloatTensor(value).view(1, 1, 3, 1, 1), + persistent=False, + ) + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. + """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + ) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def _create_attn_mask( + self, + S: int, + P: int, + mode: str, + dtype: torch.dtype, + device: torch.device, + reorder_indices: Optional[torch.Tensor] = None, + is_keyframe: Optional[torch.Tensor] = None, + keyframe_indices: Optional[torch.Tensor] = None, + ): + """ + Create attention mask based on mode and optionally adjust for reordering. + + Args: + S: Sequence length + P: Tokens per frame + mode: "causal", "window", or "full" + dtype: Data type + device: Device + reorder_indices: Optional reordering indices [B*S] for keyframe-first ordering + is_keyframe: Optional keyframe mask [B, S] + keyframe_indices: Optional reference keyframe indices [B, S] + + Returns: + Attention mask [N, N] where N = S * P (or [B, 1, N, N] if segment-aware) + """ + N = S * P + + if mode == "full": + return None + + if ( + self.use_segment_mask + and is_keyframe is not None + and keyframe_indices is not None + ): + B = is_keyframe.shape[0] + + should_print = False + if should_print: + print(f"\n[Aggregator Segment Mask DEBUG]") + print(f" use_segment_mask={self.use_segment_mask}") + print(f" Mode: {mode}") + print(f" Sequence length: {S}, Tokens per frame: {P}") + print(f" is_keyframe[0]: {is_keyframe[0].tolist()}") + print(f" keyframe_indices[0]: {keyframe_indices[0].tolist()}") + + ref = keyframe_indices + + idx = torch.arange(S, device=device) + j_indices = torch.arange(S, device=device).view(1, 1, S).expand(B, S, -1) + + is_ref_frame = j_indices == ref[:, :, None] + same_ref = ref[:, :, None] == ref[:, None, :] + can_attend_nonkf = is_ref_frame | same_ref + + if mode == "causal": + causal_mask = idx[None, :, None] >= idx[None, None, :] + + prev_kf_mask = torch.zeros(B, S, S, dtype=torch.bool, device=device) + for b in range(B): + kf_positions = [i for i in range(S) if is_keyframe[b, i]] + for kf_idx, kf_pos in enumerate(kf_positions): + if kf_idx == 0: + prev_kf_mask[b, kf_pos, : kf_pos + 1] = True + else: + prev_kf_pos = kf_positions[kf_idx - 1] + prev_kf_mask[b, kf_pos, prev_kf_pos : kf_pos + 1] = True + + can_attend_kf = causal_mask.expand(B, -1, -1) & prev_kf_mask + mode_constraint = causal_mask + elif mode == "window": + causal_mask = idx[None, :, None] >= idx[None, None, :] + window_mask = ( + idx[None, :, None] - idx[None, None, :] + ) < self.window_size + window_causal = causal_mask & window_mask + can_attend_kf = window_causal.expand(B, -1, -1) + mode_constraint = window_causal + else: + raise NotImplementedError(f"Unknown mode: {mode}") + + is_kf_expanded = is_keyframe[:, :, None].expand(-1, -1, S) + can_attend = torch.where(is_kf_expanded, can_attend_kf, can_attend_nonkf) + + mask_bool = can_attend & mode_constraint + + mask_bool_expanded = torch.zeros(B, N, N, dtype=torch.bool, device=device) + for i in range(S): + for j in range(S): + mask_bool_expanded[ + :, i * P : (i + 1) * P, j * P : (j + 1) * P + ] = mask_bool[:, i : i + 1, j : j + 1] + + zero = torch.zeros(1, dtype=dtype, device=device) + neg_inf = torch.tensor(float("-inf"), dtype=dtype, device=device) + segment_mask = torch.where(mask_bool_expanded, zero, neg_inf).unsqueeze(1) + + if should_print: + + kf_positions = [i for i in range(S) if is_keyframe[0, i]] + if len(kf_positions) >= 2: + + kf_pos = kf_positions[1] + visible_frames = [] + for j in range(S): + if mask_bool[0, kf_pos, j]: + visible_frames.append(j) + print( + f" Frame {kf_pos} (keyframe) can attend to frames: {visible_frames}" + ) + + if kf_pos + 1 < S: + post_pos = kf_pos + 1 + visible_frames = [] + for j in range(S): + if mask_bool[0, post_pos, j]: + visible_frames.append(j) + print( + f" Frame {post_pos} (post-switch) can attend to frames: {visible_frames}" + ) + print(f" ✅ Segment mask is working correctly!") + + return segment_mask + + if mode == "causal": + mask_original = torch.zeros((N, N), dtype=dtype, device=device) + for i in range(S): + curr_view_start = i * P + curr_view_end = (i + 1) * P + + mask_original[curr_view_start:curr_view_end, curr_view_end:] = float( + "-inf" + ) + elif mode == "window": + mask_original = torch.zeros((N, N), dtype=dtype, device=device) + for i in range(S): + curr_view_start = i * P + curr_view_end = (i + 1) * P + + mask_original[curr_view_start:curr_view_end, P:] = float("-inf") + + start_view = max(1, i - self.window_size + 1) + mask_original[ + curr_view_start:curr_view_end, start_view * P : (i + 1) * P + ] = 0 + else: + raise NotImplementedError(f"Unknown attention mode: {mode}") + + if reorder_indices is not None: + + mask_reordered = torch.zeros_like(mask_original) + + for new_i in range(S): + for new_j in range(S): + orig_i = reorder_indices[new_i].item() + orig_j = reorder_indices[new_j].item() + + mask_reordered[ + new_i * P : (new_i + 1) * P, new_j * P : (new_j + 1) * P + ] = mask_original[ + orig_i * P : (orig_i + 1) * P, orig_j * P : (orig_j + 1) * P + ] + + return mask_reordered + + return mask_original + + def forward( + self, + images: torch.Tensor, + mode: str = "causal", + kv_cache_list: Optional[List[List[torch.Tensor]]] = None, + is_keyframe: Optional[torch.Tensor] = None, + keyframe_indices: Optional[torch.Tensor] = None, + additional_tokens: Optional[torch.Tensor] = None, + reorder_keyframes_first: bool = False, + ) -> Union[ + Tuple[List[torch.Tensor], int], + Tuple[List[torch.Tensor], int, List[List[torch.Tensor]]], + ]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + mode (str): Global attention mode, could be either "causal", "window" or "full" + kv_cache_list (List[List[torch.Tensor]]): List of cached key-value pairs for + each global attention layer of the aggregator + is_keyframe (torch.Tensor): Boolean tensor indicating keyframes [B, S] + keyframe_indices (torch.Tensor): Reference keyframe indices for each frame [B, S] + additional_tokens (torch.Tensor): Additional tokens to insert (e.g., scale token) [B, C, T] + reorder_keyframes_first (bool): If True, reorder tokens so keyframes come first + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + images = (images - self._resnet_mean) / self._resnet_std + + images = images.view(B * S, C_in, H, W) + + patch_tokens = self.patch_embed(images) + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P, C = patch_tokens.shape + + if is_keyframe is not None: + + camera_token, register_token = self._select_role_tokens(B, S, is_keyframe) + else: + + is_anchor_exist = kv_cache_list is None or kv_cache_list[0][0] is None + camera_token = slice_expand_and_flatten( + self.camera_token, B, S, is_anchor_exist=is_anchor_exist + ) + register_token = slice_expand_and_flatten( + self.register_token, B, S, is_anchor_exist=is_anchor_exist + ) + + if additional_tokens is not None: + + T = additional_tokens.shape[-1] + additional_tokens_expanded = ( + additional_tokens.unsqueeze(1).repeat(1, S, 1, 1).view(B * S, T, C) + ) + + tokens = torch.cat( + [ + camera_token, + register_token, + additional_tokens_expanded, + patch_tokens, + ], + dim=1, + ) + patch_start_idx_with_additional = self.patch_start_idx + T + + else: + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + patch_start_idx_with_additional = self.patch_start_idx + + if ( + self.use_role_embedding + and not self.disable_keyframe_distinction + and is_keyframe is not None + ): + P_patch = patch_tokens.shape[1] + tokens = self._apply_role_embedding(tokens, B, S, P_patch, is_keyframe) + + pos = None + if self.rope is not None and self.position_getter is not None: + if self.use_3d_rope: + + pos = self.position_getter.get_3d_positions( + B, + S, + H // self.patch_size, + W // self.patch_size, + device=images.device, + ) + else: + + pos = self.position_getter( + B * S, + H // self.patch_size, + W // self.patch_size, + device=images.device, + ) + + if patch_start_idx_with_additional > 0 and pos is not None: + + pos = pos + 1 + pos_dim = 3 if self.use_3d_rope else 2 + pos_special = ( + torch.zeros(B * S, patch_start_idx_with_additional, pos_dim) + .to(images.device) + .to(pos.dtype) + ) + + if self.use_3d_rope: + + temporal_indices = torch.arange( + S, device=images.device, dtype=pos.dtype + ) + temporal_indices = temporal_indices.repeat_interleave(B).view( + B * S, 1, 1 + ) + temporal_indices = temporal_indices.expand( + -1, patch_start_idx_with_additional, -1 + ) + pos_special[:, :, 2:3] = temporal_indices + + pos = torch.cat([pos_special, pos], dim=1) + + _, P, C = tokens.shape + + reorder_indices = None + restore_indices = None + frame_reorder_map = None + if is_keyframe is not None and reorder_keyframes_first: + ( + tokens, + pos, + reorder_indices, + restore_indices, + ) = self._reorder_keyframes_first(tokens, pos, B, S, P, is_keyframe) + + if B > 0 and reorder_indices is not None: + frame_reorder_map = torch.zeros( + S, dtype=torch.long, device=tokens.device + ) + for new_s in range(S): + orig_frame_in_batch = reorder_indices[new_s].item() % S + frame_reorder_map[new_s] = orig_frame_in_batch + + attn_mask = None + if kv_cache_list is None: + attn_mask = self._create_attn_mask( + S, + P, + mode, + tokens.dtype, + tokens.device, + reorder_indices=frame_reorder_map, + is_keyframe=is_keyframe, + keyframe_indices=keyframe_indices, + ) + + frame_idx = 0 + global_idx = 0 + output_list = [] + + for block_idx in range(self.aa_block_num): + frame_intermediates = None + global_intermediates = None + + for attn_type in self.aa_order: + if attn_type == "frame": + ( + tokens, + frame_idx, + frame_intermediates, + ) = self._process_frame_attention( + tokens, B, S, P, C, frame_idx, pos=pos + ) + elif attn_type == "global": + if kv_cache_list is not None: + kv_cache = kv_cache_list[global_idx] + ( + tokens, + global_idx, + global_intermediates, + kv_cache, + ) = self._process_global_attention( + tokens, + B, + S, + P, + C, + global_idx, + pos=pos, + attn_mask=attn_mask, + kv_cache=kv_cache, + ) + kv_cache_list[global_idx - 1] = kv_cache + else: + ( + tokens, + global_idx, + global_intermediates, + ) = self._process_global_attention( + tokens, B, S, P, C, global_idx, pos=pos, attn_mask=attn_mask + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + if frame_intermediates is not None and global_intermediates is not None: + for i in range(len(frame_intermediates)): + + concat_inter = torch.cat( + [frame_intermediates[i], global_intermediates[i]], dim=-1 + ) + output_list.append(concat_inter) + + if kv_cache_list is not None: + return ( + output_list, + patch_start_idx_with_additional, + kv_cache_list, + restore_indices, + ) + else: + return output_list, patch_start_idx_with_additional, restore_indices + + def _process_frame_attention( + self, tokens, B, S, P, C, frame_idx, pos: Optional[torch.Tensor] = None + ): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None: + + pos_dim = pos.shape[-1] + expected_shape = (B * S, P, pos_dim) + if pos.shape != expected_shape: + pos = pos.view(B, S, P, pos_dim).view(B * S, P, pos_dim) + + intermediates = [] + + for _ in range(self.aa_block_size): + tokens = checkpoint( + self.frame_blocks[frame_idx], tokens, pos, use_reentrant=False + ) + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention( + self, + tokens, + B, + S, + P, + C, + global_idx, + pos: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + kv_cache: Optional[List[torch.Tensor]] = None, + ): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). + """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None: + + pos_dim = pos.shape[-1] + expected_shape = (B, S * P, pos_dim) + if pos.shape != expected_shape: + pos = pos.view(B, S, P, pos_dim).view(B, S * P, pos_dim) + + intermediates = [] + + for _ in range(self.aa_block_size): + if kv_cache is not None: + result = checkpoint( + self.global_blocks[global_idx], + tokens, + pos, + attn_mask, + kv_cache, + use_reentrant=False, + ) + tokens, kv_cache = result + else: + tokens = checkpoint( + self.global_blocks[global_idx], + tokens, + pos, + attn_mask, + use_reentrant=False, + ) + global_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + if kv_cache is not None: + return tokens, global_idx, intermediates, kv_cache + + return tokens, global_idx, intermediates + + def _select_role_tokens(self, B, S, is_keyframe): + """ + Select camera and register tokens based on keyframe mask. + + When disable_keyframe_distinction=True, all frames use the same tokens (camera_token_norm/register_token_norm). + When disable_keyframe_distinction=False, keyframes use _key tokens and normal frames use _norm tokens. + + Args: + B: Batch size + S: Sequence length + is_keyframe: Boolean tensor [B, S] indicating keyframes + + Returns: + camera_token: Selected camera tokens [B*S, 1, C] + register_token: Selected register tokens [B*S, num_register_tokens, C] + """ + + device = is_keyframe.device + is_keyframe = is_keyframe.bool() + + if self.disable_keyframe_distinction: + + camera_token = self.camera_token_norm.expand(B * S, -1, -1) + register_token = self.register_token_norm.expand(B * S, -1, -1) + return camera_token, register_token + + if self.camera_token_key is None or self.register_token_key is None: + raise RuntimeError( + "camera_token_key and register_token_key are not initialized. " + "This happens when disable_keyframe_distinction=True but is_keyframe distinction is requested. " + "Please set disable_keyframe_distinction=False in the configuration." + ) + + camera_tokens = [] + register_tokens = [] + + for b in range(B): + for s in range(S): + if is_keyframe[b, s]: + + camera_tokens.append(self.camera_token_key) + register_tokens.append(self.register_token_key) + else: + + camera_tokens.append(self.camera_token_norm) + register_tokens.append(self.register_token_norm) + + camera_token = torch.cat(camera_tokens, dim=0) + register_token = torch.cat(register_tokens, dim=0) + + return camera_token, register_token + + def _apply_role_embedding(self, tokens, B, S, P_patch, is_keyframe): + """ + Apply role embeddings to all tokens (including patches) for attention bias. + + 🔥 使用 FP32 进行 role embedding 计算,避免数值不稳定和 NaN + + Args: + tokens: Combined tokens [B*S, total_tokens, C] + B: Batch size + S: Sequence length + P_patch: Number of patch tokens per image + is_keyframe: Boolean tensor [B, S] indicating keyframes + + Returns: + tokens_with_role: Tokens with role embeddings added [B*S, total_tokens, C] + """ + + device = tokens.device + is_keyframe = is_keyframe.bool() + _, total_tokens, C = tokens.shape + + role_embeds = [] + + for b in range(B): + for s in range(S): + if is_keyframe[b, s]: + + role_embed = self.role_embed_key.expand(1, total_tokens, -1) + else: + + role_embed = self.role_embed_norm.expand(1, total_tokens, -1) + role_embeds.append(role_embed) + + role_embedding = torch.cat(role_embeds, dim=0) + + tokens_with_role = tokens + 0.1 * role_embedding + + return tokens_with_role + + def _reorder_keyframes_first(self, tokens, pos, B, S, P, is_keyframe): + """ + Reorder tokens so that keyframe tokens come first, followed by normal frame tokens. + + Args: + tokens: Combined tokens [B*S, P, C] + pos: Position embeddings [B*S, P, 2] or [B*S, P, 3] or None + (2D for spatial-only RoPE, 3D for spatial+temporal RoPE) + B: Batch size + S: Sequence length + P: Number of tokens per frame + is_keyframe: Boolean tensor [B, S] indicating keyframes + + Returns: + reordered_tokens: Tokens with keyframes first [B*S, P, C] + reordered_pos: Position embeddings reordered [B*S, P, 2/3] or None + reorder_indices: Indices used for reordering [B*S] + restore_indices: Indices to restore original order [B*S] + """ + device = tokens.device + is_keyframe = is_keyframe.bool() + + reorder_indices = [] + for b in range(B): + keyframe_indices = [] + normal_indices = [] + for s in range(S): + idx = b * S + s + if is_keyframe[b, s]: + keyframe_indices.append(idx) + else: + normal_indices.append(idx) + + reorder_indices.extend(keyframe_indices + normal_indices) + + reorder_indices = torch.tensor(reorder_indices, device=device, dtype=torch.long) + + restore_indices = torch.zeros_like(reorder_indices) + restore_indices[reorder_indices] = torch.arange( + B * S, device=device, dtype=torch.long + ) + + reordered_tokens = tokens[reorder_indices] + + reordered_pos = None + if pos is not None: + reordered_pos = pos[reorder_indices] + + return reordered_tokens, reordered_pos, reorder_indices, restore_indices + + +def slice_expand_and_flatten(token_tensor, B, S, is_anchor_exist=False): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + if is_anchor_exist: + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + else: + query = token_tensor[:, 1:, ...].expand(B, 1, *token_tensor.shape[2:]) + + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + + combined = torch.cat([query, others], dim=1) + + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/longstream/utils/vendor/models/components/heads/ba_refiner.py b/longstream/utils/vendor/models/components/heads/ba_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1fd48b667cf2514590c49352d234080500da5b --- /dev/null +++ b/longstream/utils/vendor/models/components/heads/ba_refiner.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +from typing import Dict, Optional, Tuple + + +class FeedForward(nn.Module): + def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0) -> None: + super().__init__() + hidden = int(dim * mlp_ratio) + self.net = nn.Sequential( + nn.Linear(dim, hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden, dim), + nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class TransformerBlock(nn.Module): + def __init__( + self, dim: int, nhead: int = 8, dropout: float = 0.0, mlp_ratio: float = 4.0 + ) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True, dropout=dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffn = FeedForward(dim, mlp_ratio=mlp_ratio, dropout=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + y, _ = self.attn(h, h, h, need_weights=False) + x = x + y + x = x + self.ffn(self.norm2(x)) + return x + + +class MeanBARefiner(nn.Module): + """Windowed BA refiner that predicts pose (SE3) and log-depth residuals.""" + + def __init__( + self, + dim_in: Optional[int] = None, + dim_hidden: int = 512, + depth: int = 3, + nhead: int = 8, + depth_mode: str = "grid", + hw: Optional[Tuple[int, int]] = None, + rank: Optional[int] = None, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + if depth_mode not in {"grid", "lowrank"}: + raise ValueError(f"Unsupported depth_mode: {depth_mode}") + if depth_mode == "lowrank" and rank is None: + raise ValueError("rank must be provided when depth_mode='lowrank'") + + self.depth_mode = depth_mode + self.default_hw = hw + self.rank = rank + self.dim_hidden = dim_hidden + + if dim_in is None: + self.input_proj = nn.LazyLinear(dim_hidden) + else: + self.input_proj = nn.Linear(dim_in, dim_hidden) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + dim_hidden, nhead=nhead, dropout=dropout, mlp_ratio=mlp_ratio + ) + for _ in range(depth) + ] + ) + self.output_norm = nn.LayerNorm(dim_hidden) + + self.pose_head = nn.Sequential( + nn.Linear(dim_hidden, dim_hidden), + nn.GELU(), + nn.Linear(dim_hidden, 6), + ) + self.pose_gate = nn.Parameter(torch.zeros(1)) + + self.depth_hidden = nn.Sequential( + nn.Linear(dim_hidden, dim_hidden), + nn.GELU(), + ) + if depth_mode == "lowrank": + self.depth_proj: Optional[nn.Linear] = nn.Linear(dim_hidden, rank) + else: + self.depth_proj = ( + nn.Linear(dim_hidden, hw[0] * hw[1]) if hw is not None else None + ) + self.depth_gate = nn.Parameter(torch.zeros(1)) + + def forward( + self, + cam_tokens: torch.Tensor, + frame_summaries: torch.Tensor, + pose0_rel_6d: torch.Tensor, + depth0_log_low: torch.Tensor, + extras: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + B, S, _ = cam_tokens.shape + flat_depth = depth0_log_low.reshape(B, S, -1) + + features = [cam_tokens, frame_summaries, pose0_rel_6d, flat_depth] + if extras: + for value in extras.values(): + if value is None: + continue + if value.dim() == 2: + value = value.unsqueeze(1) + if value.shape[0] != B: + if value.shape[0] == 1: + value = value.expand(B, *value.shape[1:]) + else: + raise ValueError("Extras must broadcast along batch dimension") + if value.dim() == 4: + value = value.reshape(B, S, -1) + features.append(value) + + fused = torch.cat(features, dim=-1) + hidden = self.input_proj(fused) + for blk in self.blocks: + hidden = blk(hidden) + hidden = self.output_norm(hidden) + + dpose = self.pose_head(hidden) + dpose = torch.tanh(dpose) * torch.sigmoid(self.pose_gate) + + depth_feat = self.depth_hidden(hidden) + if self.depth_mode == "grid": + depth_dim = flat_depth.shape[-1] + if self.depth_proj is None or self.depth_proj.out_features != depth_dim: + self.depth_proj = nn.Linear(self.dim_hidden, depth_dim) + self.depth_proj = self.depth_proj.to( + device=depth_feat.device, dtype=depth_feat.dtype + ) + ddepth = self.depth_proj(depth_feat) + h, w = depth0_log_low.shape[-2:] + ddepth = ddepth.view(B, S, h, w) + else: + self.depth_proj = self.depth_proj.to( + device=depth_feat.device, dtype=depth_feat.dtype + ) + ddepth = self.depth_proj(depth_feat) + + ddepth = torch.tanh(ddepth) * torch.sigmoid(self.depth_gate) + return dpose, ddepth diff --git a/longstream/utils/vendor/models/components/heads/camera_head.py b/longstream/utils/vendor/models/components/heads/camera_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a1cba0e903870491321f1d7f1147bba1f40d8f02 --- /dev/null +++ b/longstream/utils/vendor/models/components/heads/camera_head.py @@ -0,0 +1,681 @@ +from typing import List, Tuple, Optional, Union + +import torch +import torch.nn as nn + +from ..layers import Mlp +from ..layers.block import Block +from .head_act import activate_pose + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", + window_size: int = 5, + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + self.window_size = window_size + + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + self.poseLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True) + ) + + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + def _create_attn_mask( + self, S: int, mode: str, dtype: torch.dtype, device: torch.device + ) -> Optional[torch.Tensor]: + N = S + mask = torch.zeros((N, N), dtype=dtype, device=device) + + if mode == "causal": + for i in range(S): + curr_view_start = i + curr_view_end = i + 1 + mask[curr_view_start:curr_view_end, curr_view_end:] = float("-inf") + elif mode == "window": + for i in range(S): + curr_view_start = i + curr_view_end = i + 1 + mask[curr_view_start:curr_view_end, 1:] = float("-inf") + start_view = max(1, i - self.window_size + 1) + mask[curr_view_start:curr_view_end, start_view : (i + 1)] = 0 + elif mode == "full": + mask = None + else: + raise NotImplementedError(f"Unknown attention mode: {mode}") + + return mask + + def forward( + self, + aggregated_tokens_list: list, + num_iterations: int = 4, + mode: str = "causal", + kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None, + ) -> Union[list, Tuple[list, List[List[List[torch.Tensor]]]]]: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + mode (str): Global attention mode, could be either "causal", "window" or "full" + kv_cache_list (List[List[List[torch.Tensor]]]): List of cached key-value pairs for + each iterations and each attention layer of the camera head + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + + tokens = aggregated_tokens_list[-1] + + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + B, S, C = pose_tokens.shape + attn_mask = None + if kv_cache_list is None: + attn_mask = self._create_attn_mask( + S, mode, pose_tokens.dtype, pose_tokens.device + ) + + pred_pose_enc_list = self.trunk_fn( + pose_tokens, num_iterations, attn_mask, kv_cache_list + ) + return pred_pose_enc_list + + def trunk_fn( + self, + pose_tokens: torch.Tensor, + num_iterations: int, + attn_mask: Optional[torch.Tensor], + kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None, + ) -> Union[list, Tuple[list, List[List[List[torch.Tensor]]]]]: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape + pred_pose_enc = None + pred_pose_enc_list = [] + + for iter in range(num_iterations): + + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + + module_input = self.embed_pose(pred_pose_enc) + + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk( + 3, dim=-1 + ) + + adaln_output = self.adaln_norm(pose_tokens) + modulated_output = modulate(adaln_output, shift_msa, scale_msa) + gated_output = gate_msa * modulated_output + pose_tokens_modulated = gated_output + pose_tokens + + for i in range(self.trunk_depth): + if kv_cache_list is not None: + pose_tokens_modulated, kv_cache_list[iter][i] = self.trunk[i]( + pose_tokens_modulated, + attn_mask=attn_mask, + kv_cache=kv_cache_list[iter][i], + ) + else: + pose_tokens_modulated = self.trunk[i]( + pose_tokens_modulated, attn_mask=attn_mask + ) + + trunk_norm_output = self.trunk_norm(pose_tokens_modulated) + pred_pose_enc_delta = self.pose_branch(trunk_norm_output) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + pred_pose_enc_list.append(activated_pose) + + if kv_cache_list is not None: + return pred_pose_enc_list, kv_cache_list + else: + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. + """ + + return x * (1 + scale) + shift + + +class RelPoseHead(nn.Module): + """ + Enhanced Relative Pose Head for dynamic keyframe-based pose prediction. + + Key features: + 1. True relative pose prediction (not incremental from fixed anchor) + 2. Dynamic keyframe switching support + 3. SE(3) and Sim(3) pose modes + 4. Role-aware processing for keyframes vs non-keyframes + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_mode: str = "SE3", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", + use_global_scale: bool = False, + use_pair_cross_attn: bool = False, + detach_reference: bool = False, + xattn_temperature: float = 1.0, + use_precat: bool = False, + use_kf_role_embed: bool = True, + kf_role_embed_init_std: float = 0.02, + window_size: int = 50000, + ): + super().__init__() + + self.pose_mode = pose_mode + self.use_global_scale = use_global_scale and (pose_mode == "Sim3") + self.use_pair_cross_attn = use_pair_cross_attn + self.detach_reference = detach_reference + self.xattn_temperature = xattn_temperature + self.use_precat = use_precat + self.use_kf_role_embed = use_kf_role_embed + self.kf_role_embed_init_std = kf_role_embed_init_std + + self.target_dim = 9 + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + self.window_size = 50000 + + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_in, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + self.poseLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True) + ) + + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_in // 2, + out_features=self.target_dim, + drop=0, + ) + + if self.use_global_scale: + self.global_scale = nn.Parameter(torch.ones(1)) + + if self.use_pair_cross_attn: + self.xattn_q = nn.Linear(dim_in, dim_in, bias=False) + self.xattn_k = nn.Linear(dim_in, dim_in, bias=False) + self.xattn_v = nn.Linear(dim_in, dim_in, bias=False) + self.xattn_out = nn.Linear(dim_in, dim_in, bias=False) + + if self.use_precat: + self.precat_proj = nn.Linear(dim_in * 2, dim_in, bias=True) + + if self.use_kf_role_embed: + self.kf_role_embed = nn.Parameter(torch.randn(1, 1, dim_in)) + nn.init.normal_(self.kf_role_embed, std=self.kf_role_embed_init_std) + else: + self.kf_role_embed = None + + def _create_attn_mask( + self, S: int, mode: str, dtype: torch.dtype, device: torch.device + ) -> Optional[torch.Tensor]: + """Create attention mask for the given mode.""" + N = S + + if mode == "causal": + mask = torch.zeros((N, N), dtype=dtype, device=device) + for i in range(S): + mask[i, i + 1 :] = float("-inf") + return mask + elif mode == "window": + mask = torch.zeros((N, N), dtype=dtype, device=device) + for i in range(S): + mask[i, :] = float("-inf") + start = max(0, i - self.window_size + 1) + mask[i, start : i + 1] = 0 + return mask + elif mode == "full": + return None + else: + raise NotImplementedError(f"Unknown attention mode: {mode}") + + def forward( + self, + aggregated_tokens_list: list, + keyframe_indices: torch.Tensor, + is_keyframe: torch.Tensor, + num_iterations: int = 4, + mode: str = "causal", + kv_cache_list: Optional[List[List[List[torch.Tensor]]]] = None, + compute_switch_poses: bool = False, + ): + """ + Forward pass for relative pose prediction. + + Args: + aggregated_tokens_list: List of aggregated tokens from the network + keyframe_indices: Indices of reference keyframes for each frame [B, S] + is_keyframe: Boolean mask indicating keyframes [B, S] + num_iterations: Number of iterative refinement steps + mode: Attention mode ("causal", "window", or "full") + kv_cache_list: Optional KV cache for streaming + + Returns: + dict containing: + - pose_enc: Predicted relative poses [B, S, 9] + - is_keyframe: Keyframe mask [B, S] + - keyframe_indices: Reference keyframe indices [B, S] + - global_scale: Global scale for Sim(3) mode (if applicable) + """ + mode = "causal" + + tokens = aggregated_tokens_list[-1] + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + B, S, C = pose_tokens.shape + + if kv_cache_list is not None and S == 1: + + if not hasattr(self, "_keyframe_tokens_cache"): + + self._keyframe_tokens_cache = {} + self._current_frame_id = 0 + + self._frame_info = [] + + curr_is_kf = is_keyframe[0, 0].item() if is_keyframe is not None else True + curr_ref_idx = ( + keyframe_indices[0, 0].item() + if keyframe_indices is not None + else self._current_frame_id + ) + self._frame_info.append((curr_is_kf, curr_ref_idx)) + + if curr_is_kf: + + self._keyframe_tokens_cache[ + self._current_frame_id + ] = pose_tokens.squeeze(1) + + self._current_frame_id += 1 + + ref_tokens = None + if keyframe_indices is not None: + if kv_cache_list is not None and S == 1: + + ref_frame_id = keyframe_indices[0, 0].item() + + if ref_frame_id in self._keyframe_tokens_cache: + + ref_tokens = self._keyframe_tokens_cache[ref_frame_id].unsqueeze(1) + else: + + ref_tokens = pose_tokens + + if self.detach_reference: + ref_tokens = ref_tokens.detach() + else: + + total_frames = pose_tokens.shape[1] + ref_idx = ( + keyframe_indices.clamp(0, total_frames - 1) + .unsqueeze(-1) + .expand(-1, -1, C) + ) + ref_tokens = torch.gather(pose_tokens, dim=1, index=ref_idx) + if self.detach_reference: + ref_tokens = ref_tokens.detach() + + if ( + self.use_kf_role_embed + and ref_tokens is not None + and self.kf_role_embed is not None + ): + + current_indices = ( + torch.arange(S, device=keyframe_indices.device) + .unsqueeze(0) + .expand(B, -1) + ) + + is_self_ref = current_indices == keyframe_indices + + add_kf_embed_mask = (~is_self_ref).unsqueeze(-1).float() + + ref_tokens = ref_tokens + add_kf_embed_mask * self.kf_role_embed.expand( + B, S, -1 + ) + + if self.use_pair_cross_attn and (ref_tokens is not None): + q = self.xattn_q(pose_tokens) + k = self.xattn_k(ref_tokens) + v = self.xattn_v(ref_tokens) + + scale = (q * k).sum(dim=-1, keepdim=True) / (C ** 0.5) + gate = torch.sigmoid(scale / self.xattn_temperature) + + pair_info = self.xattn_out(gate * v) + pose_tokens = pose_tokens + pair_info + + if self.use_precat and (ref_tokens is not None): + pose_tokens = self.precat_proj(torch.cat([pose_tokens, ref_tokens], dim=-1)) + + attn_mask = None + if keyframe_indices is not None: + + ref = keyframe_indices + B = ref.shape[0] + + j_indices = ( + torch.arange(S, device=pose_tokens.device) + .view(1, 1, S) + .expand(B, S, -1) + ) + is_ref_frame = j_indices == ref[:, :, None] + same_ref = ref[:, :, None] == ref[:, None, :] + can_attend_nonkf = is_ref_frame | same_ref + + idx = torch.arange(S, device=pose_tokens.device) + + if mode == "causal": + + causal_mask = idx[None, :, None] >= idx[None, None, :] + + prev_kf_mask = torch.zeros( + B, S, S, dtype=torch.bool, device=pose_tokens.device + ) + for b in range(B): + kf_positions = [i for i in range(S) if is_keyframe[b, i]] + for kf_idx, kf_pos in enumerate(kf_positions): + if kf_idx == 0: + + prev_kf_mask[b, kf_pos, : kf_pos + 1] = True + else: + + prev_kf_pos = kf_positions[kf_idx - 1] + prev_kf_mask[b, kf_pos, prev_kf_pos : kf_pos + 1] = True + + can_attend_kf = causal_mask.expand(B, -1, -1) & prev_kf_mask + mode_constraint = causal_mask + elif mode == "window": + + causal_mask = idx[None, :, None] >= idx[None, None, :] + window_mask = ( + idx[None, :, None] - idx[None, None, :] + ) < self.window_size + window_causal = causal_mask & window_mask + can_attend_kf = window_causal.expand(B, -1, -1) + mode_constraint = window_causal + elif mode == "full": + + can_attend_kf = torch.ones( + B, S, S, dtype=torch.bool, device=pose_tokens.device + ) + mode_constraint = torch.ones( + 1, S, S, dtype=torch.bool, device=pose_tokens.device + ) + else: + raise NotImplementedError(f"Unknown mode: {mode}") + + is_kf_expanded = is_keyframe[:, :, None].expand(-1, -1, S) + can_attend = torch.where(is_kf_expanded, can_attend_kf, can_attend_nonkf) + + mask_bool = can_attend & mode_constraint + + zero = torch.zeros(1, dtype=pose_tokens.dtype, device=pose_tokens.device) + neg_inf = torch.full( + (1,), float("-inf"), dtype=pose_tokens.dtype, device=pose_tokens.device + ) + attn_mask = torch.where(mask_bool, zero, neg_inf)[:, None, :, :] + + if kv_cache_list is not None and len(kv_cache_list) > 0 and S == 1: + + k_cache = kv_cache_list[0][0][0] + if k_cache is not None: + cache_len = k_cache.shape[2] + + curr_idx = len(self._frame_info) - 1 + curr_is_kf, curr_ref_idx = self._frame_info[curr_idx] + + cache_mask_vals = [] + visible_frames = [] + for i in range(max(0, curr_idx - cache_len), curr_idx): + cache_is_kf, cache_ref_idx = self._frame_info[i] + + if curr_is_kf: + + prev_kf_idx = None + for j in range(curr_idx - 1, -1, -1): + if self._frame_info[j][0]: + prev_kf_idx = j + break + + if prev_kf_idx is not None: + + can_see = i >= prev_kf_idx + else: + + can_see = True + else: + + is_ref = i == curr_ref_idx + same_ref = cache_ref_idx == curr_ref_idx + can_see = is_ref or same_ref + + mask_val = zero if can_see else neg_inf + cache_mask_vals.append(mask_val) + if can_see: + visible_frames.append(i) + + if len(cache_mask_vals) > 0: + cache_mask = torch.stack(cache_mask_vals, dim=0).view( + 1, 1, 1, len(cache_mask_vals) + ) + cache_mask = cache_mask.expand(B, 1, S, len(cache_mask_vals)) + attn_mask = torch.cat([cache_mask, attn_mask], dim=-1) + + else: + + attn_mask = self._create_attn_mask( + S, mode, pose_tokens.dtype, pose_tokens.device + ) + + pred_pose_enc_list = [] + pred_pose_enc = None + + for iter_idx in range(num_iterations): + + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + + module_input = self.embed_pose(pred_pose_enc) + + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk( + 3, dim=-1 + ) + + pose_tokens_modulated = gate_msa * modulate( + self.adaln_norm(pose_tokens), shift_msa, scale_msa + ) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + for i in range(self.trunk_depth): + if ( + kv_cache_list is not None + and iter_idx < len(kv_cache_list) + and i < len(kv_cache_list[iter_idx]) + ): + pose_tokens_modulated, kv_cache_list[iter_idx][i] = self.trunk[i]( + pose_tokens_modulated, + attn_mask=attn_mask, + kv_cache=kv_cache_list[iter_idx][i], + ) + else: + pose_tokens_modulated = self.trunk[i]( + pose_tokens_modulated, attn_mask=attn_mask + ) + + trunk_norm_output = self.trunk_norm(pose_tokens_modulated) + pred_pose_enc_delta = self.pose_branch(trunk_norm_output) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + activated_pose = activate_pose( + pred_pose_enc, + trans_act=self.trans_act, + quat_act=self.quat_act, + fl_act=self.fl_act, + ) + + pred_pose_enc_list.append(activated_pose) + + final_pose_enc = pred_pose_enc_list[-1] + if final_pose_enc.dtype != torch.float32: + final_pose_enc = final_pose_enc.float() + + result = { + "pose_enc": final_pose_enc, + "is_keyframe": is_keyframe, + "keyframe_indices": keyframe_indices, + } + + if self.training and len(pred_pose_enc_list) > 0: + result["pose_enc_list"] = pred_pose_enc_list + + if compute_switch_poses: + switch_poses = self._compute_switch_poses( + pred_pose_enc_list[-1], keyframe_indices, is_keyframe + ) + result["switch_poses"] = switch_poses + + if self.use_global_scale: + result["global_scale"] = self.global_scale.expand(B, 1) + + if kv_cache_list is not None: + result["kv_cache_list"] = kv_cache_list + + return result + + def _compute_switch_poses(self, poses, keyframe_indices, is_keyframe): + """ + Compute T_{k'←k} for keyframe switches. + + Returns a dictionary mapping (k, k') pairs to the relative transformation. + """ + B, S, _ = poses.shape + switch_poses = {} + + for b in range(B): + prev_kf_idx = None + for s in range(S): + if is_keyframe[b, s]: + if prev_kf_idx is not None: + + key = (b, prev_kf_idx, s) + switch_poses[key] = poses[b, s].clone() + prev_kf_idx = s + + return switch_poses diff --git a/longstream/utils/vendor/models/components/heads/dpt_head.py b/longstream/utils/vendor/models/components/heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1decbb9ac1deb7dc00e732d5769409d990f2f518 --- /dev/null +++ b/longstream/utils/vendor/models/components/heads/dpt_head.py @@ -0,0 +1,517 @@ +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [4, 11, 17, 23], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=dim_in, + out_channels=oc, + kernel_size=1, + stride=1, + padding=0, + ) + for oc in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ] + ) + + self.scratch = _make_scratch( + out_channels, + features, + expand=False, + ) + + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1, kernel_size=3, stride=1, padding=1 + ) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, + head_features_1 // 2, + kernel_size=3, + stride=1, + padding=1, + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d( + conv2_in_channels, + head_features_2, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + head_features_2, output_dim, kernel_size=1, stride=1, padding=0 + ), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + B, S, _, H, W = images.shape + + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + assert frames_chunk_size > 0 + + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, + images, + patch_start_idx, + frames_start_idx, + frames_end_idx, + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if frames_start_idx is not None and frames_end_idx is not None: + images = images[:, frames_start_idx:frames_end_idx].contiguous() + + B, S, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + out = self.scratch_forward(out) + + out = custom_interpolate( + out, + ( + int(patch_h * self.patch_size / self.down_ratio), + int(patch_w * self.patch_size / self.down_ratio), + ), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.reshape(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head( + out, activation=self.activation, conf_activation=self.conf_activation + ) + + preds = preds.reshape(B, S, *preds.shape[1:]) + conf = conf.reshape(B, S, *conf.shape[1:]) + + return preds, conf + + def _apply_pos_embed( + self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1 + ) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid( + patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device + ) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +def _make_fusion_block( + features: int, size: int = None, has_residual: bool = True, groups: int = 1 +) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch( + in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=self.groups, + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. + """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate( + chunk, size=size, mode=mode, align_corners=align_corners + ) + for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate( + x, size=size, mode=mode, align_corners=align_corners + ) diff --git a/longstream/utils/vendor/models/components/heads/head_act.py b/longstream/utils/vendor/models/components/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..ae75d1d471db1ed532f2aabeadfc2df1c3267ab5 --- /dev/null +++ b/longstream/utils/vendor/models/components/heads/head_act.py @@ -0,0 +1,150 @@ +import torch +import torch.nn.functional as F + + +class SafeCat(torch.autograd.Function): + """A safe version of torch.cat that prevents NaN/Inf gradients from propagating. + + This function acts as a 'fuse' - it performs normal concatenation in forward pass, + but in backward pass it replaces any NaN/Inf gradients with zeros before passing + them back to the input tensors. + """ + + @staticmethod + def forward(ctx, *args): + + *xs, dim = args + ctx.dim = dim + ctx.sizes = [x.size(dim) for x in xs] + return torch.cat(xs, dim=dim) + + @staticmethod + def backward(ctx, g_out): + + dxs = [] + s = 0 + for k in ctx.sizes: + sl = [slice(None)] * g_out.ndim + sl[ctx.dim] = slice(s, s + k) + g = g_out[tuple(sl)] + + g = torch.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0) + dxs.append(g) + s += k + + return (*dxs, None) + + +def activate_pose( + pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear" +): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + + T = base_pose_act(pred_pose_enc[..., :3], trans_act) + quat = base_pose_act(pred_pose_enc[..., 3:7], quat_act) + fl = base_pose_act(pred_pose_enc[..., 7:], fl_act) + + quat = torch.nn.functional.normalize(quat, p=2, dim=-1, eps=1e-8) + + pred_pose_enc = SafeCat.apply(T, quat, fl, -1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + + fmap = out.permute(0, 2, 3, 1) + + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + pts3d = (xyz / d) * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/longstream/utils/vendor/models/components/heads/utils.py b/longstream/utils/vendor/models/components/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed6ca87d2619a5a8b8011e5cb0a7aa83ffc03ae --- /dev/null +++ b/longstream/utils/vendor/models/components/heads/utils.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + + +def position_grid_to_embed( + pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 +) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) + + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) + + emb = torch.cat([emb_x, emb_y], dim=-1) + + return emb.view(H, W, embed_dim) + + +def make_sincos_pos_embed( + embed_dim: int, pos: torch.Tensor, omega_0: float = 100 +) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange( + embed_dim // 2, + dtype=torch.float32 if device.type == "mps" else torch.double, + device=device, + ) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0 ** omega + + pos = pos.reshape(-1) + out = torch.einsum("m,d->md", pos, omega) + + emb_sin = torch.sin(out) + emb_cos = torch.cos(out) + + emb = torch.cat([emb_sin, emb_cos], dim=1) + return emb.float() + + +def create_uv_grid( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. + """ + + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/longstream/utils/vendor/models/components/layers/__init__.py b/longstream/utils/vendor/models/components/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3f0048703b51bcb8bf63d327b66bef0d51f8da --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/__init__.py @@ -0,0 +1,5 @@ +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/longstream/utils/vendor/models/components/layers/attention.py b/longstream/utils/vendor/models/components/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b05c47d6ab86f7f8cf2ec02be417650757a9c54a --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/attention.py @@ -0,0 +1,115 @@ +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None, kv_cache=None) -> Tensor: + B, N, C = x.shape + + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if kv_cache is not None: + k_cache, v_cache = kv_cache + if k_cache is not None and v_cache is not None: + k = torch.cat([k_cache, k], dim=2) + v = torch.cat([v_cache, v], dim=2) + kv_cache = [k, v] + + if self.fused_attn: + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + if attn_mask is not None: + attn_mask = attn_mask.contiguous() + + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if attn_mask is not None: + attn = attn + attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + + x = self.proj_drop(x) + + if kv_cache is not None: + return x, kv_cache + + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: + assert pos is None + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/longstream/utils/vendor/models/components/layers/block.py b/longstream/utils/vendor/models/components/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..af2c1d49294318a404c550fddc33ada4a17e52fb --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/block.py @@ -0,0 +1,312 @@ +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, attn_mask=None, kv_cache=None) -> Tensor: + def attn_residual_func( + x: Tensor, pos=None, attn_mask=None, kv_cache=None + ) -> Tensor: + if kv_cache is not None: + x, kv_cache = self.attn( + self.norm1(x), pos=pos, attn_mask=attn_mask, kv_cache=kv_cache + ) + return self.ls1(x), kv_cache + elif attn_mask is not None: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + else: + return self.ls1(self.attn(self.norm1(x), pos=pos)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + assert ( + attn_mask is None and kv_cache is None + ), "attn_mask and kv_cache are not supported for drop_add_residual_stochastic_depth yet" + + x = drop_add_residual_stochastic_depth( + x, + pos=pos, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x_input = x + + if kv_cache is not None: + delta_x, kv_cache = attn_residual_func( + x, pos=pos, attn_mask=attn_mask, kv_cache=kv_cache + ) + else: + delta_x = attn_residual_func(x, pos=pos, attn_mask=attn_mask) + + x_after_attn = x + self.drop_path1(delta_x) + + x = x_after_attn + self.drop_path1(ffn_residual_func(x_after_attn)) + else: + x_input = x + + if kv_cache is not None: + delta_x, kv_cache = attn_residual_func( + x, pos=pos, attn_mask=attn_mask, kv_cache=kv_cache + ) + else: + delta_x = attn_residual_func(x, pos=pos, attn_mask=attn_mask) + + x_after_attn = x + delta_x + + x = x_after_attn + ffn_residual_func(x_after_attn) + + if kv_cache is not None: + return x, kv_cache + else: + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos=None, +) -> Tensor: + + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + if pos is not None: + + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma + if isinstance(self.ls1, LayerScale) + else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/longstream/utils/vendor/models/components/layers/drop_path.py b/longstream/utils/vendor/models/components/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..310c6a9637d3996d4a6f8517dd15f9a2af6946ca --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/drop_path.py @@ -0,0 +1,24 @@ +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/longstream/utils/vendor/models/components/layers/layer_scale.py b/longstream/utils/vendor/models/components/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..9e83807bdec80588636761c42e4a245fda6a67ea --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/layer_scale.py @@ -0,0 +1,20 @@ +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/longstream/utils/vendor/models/components/layers/mlp.py b/longstream/utils/vendor/models/components/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..8d12860274f57e1aeb080d615421225474ee9bf7 --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/mlp.py @@ -0,0 +1,30 @@ +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/longstream/utils/vendor/models/components/layers/patch_embed.py b/longstream/utils/vendor/models/components/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..dec0da785a791d9e0c56b0185f759268622b16e4 --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/patch_embed.py @@ -0,0 +1,91 @@ +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/longstream/utils/vendor/models/components/layers/rope.py b/longstream/utils/vendor/models/components/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..771dd727385150775f470e82291c9ea6e34933fe --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/rope.py @@ -0,0 +1,381 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D/3D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. It supports both + 2D (spatial only) and 3D (spatial + temporal) position generation. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__( + self, batch_size: int, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return ( + cached_positions.view(1, height * width, 2) + .expand(batch_size, -1, -1) + .clone() + ) + + def get_3d_positions( + self, + batch_size: int, + seq_len: int, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + """Generates 3D positions (spatial + temporal) for a batch of frame sequences. + + Args: + batch_size: Number of samples in the batch (B). + seq_len: Number of frames in the sequence (S). + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size*seq_len, height*width, 3) containing y, x, t coordinates. + The temporal coordinate (t) is the frame index, shared by all patches in the same frame. + """ + + spatial_positions = self(1, height, width, device) + + temporal_indices = torch.arange(seq_len, device=device) + + batch_seq_size = batch_size * seq_len + + spatial_expanded = spatial_positions.expand(batch_seq_size, -1, -1) + + num_patches = height * width + temporal_column = temporal_indices.repeat(batch_size) + temporal_column = temporal_column.view(batch_seq_size, 1, 1) + temporal_column = temporal_column.expand(-1, num_patches, -1) + + positions_3d = torch.cat([spatial_expanded, temporal_column], dim=-1) + + return positions_3d + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exponents) + + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert ( + positions.ndim == 3 and positions.shape[-1] == 2 + ), "Positions must have shape (batch_size, n_tokens, 2)" + + feature_dim = tokens.size(-1) // 2 + + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, max_position, tokens.device, tokens.dtype + ) + + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + vertical_features = self._apply_1d_rope( + vertical_features, positions[..., 0], cos_comp, sin_comp + ) + horizontal_features = self._apply_1d_rope( + horizontal_features, positions[..., 1], cos_comp, sin_comp + ) + + return torch.cat((vertical_features, horizontal_features), dim=-1) + + +class RotaryPositionEmbedding3D(nn.Module): + """3D Rotary Position Embedding implementation. + + This module extends 2D RoPE to handle 3D positions (spatial + temporal). + It applies rotary position embeddings based on y, x, and t coordinates, + splitting the feature dimension into three parts. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 3D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exponents) + + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 3D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 6 (for even distribution). + positions: Position tensor of shape (batch_size, n_tokens, 3) containing + the y, x, and t coordinates for each token. + + Returns: + Tensor of same shape as input with applied 3D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert ( + positions.ndim == 3 and positions.shape[-1] == 3 + ), "Positions must have shape (batch_size, n_tokens, 3)" + + total_dim = tokens.size(-1) + dim_per_axis = total_dim // 3 + + if dim_per_axis % 2 != 0: + dim_per_axis = dim_per_axis - 1 + + y_dim = dim_per_axis + x_dim = dim_per_axis + t_dim = total_dim - y_dim - x_dim + + if t_dim % 2 != 0: + + x_dim = x_dim - 1 + t_dim = total_dim - y_dim - x_dim + + y_features = tokens[..., :y_dim] + x_features = tokens[..., y_dim : y_dim + x_dim] + t_features = tokens[..., y_dim + x_dim :] + + max_position = int(positions.max()) + 1 + + cos_comp_y, sin_comp_y = self._compute_frequency_components( + y_dim, max_position, tokens.device, tokens.dtype + ) + y_features = self._apply_1d_rope( + y_features, positions[..., 0], cos_comp_y, sin_comp_y + ) + + cos_comp_x, sin_comp_x = self._compute_frequency_components( + x_dim, max_position, tokens.device, tokens.dtype + ) + x_features = self._apply_1d_rope( + x_features, positions[..., 1], cos_comp_x, sin_comp_x + ) + + cos_comp_t, sin_comp_t = self._compute_frequency_components( + t_dim, max_position, tokens.device, tokens.dtype + ) + t_features = self._apply_1d_rope( + t_features, positions[..., 2], cos_comp_t, sin_comp_t + ) + + return torch.cat((y_features, x_features, t_features), dim=-1) diff --git a/longstream/utils/vendor/models/components/layers/swiglu_ffn.py b/longstream/utils/vendor/models/components/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..d58ce5a94bdabf9621d6cfd09e77b9151be71650 --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/swiglu_ffn.py @@ -0,0 +1,56 @@ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None + +SwiGLU = SwiGLUFFN +XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/longstream/utils/vendor/models/components/layers/vision_transformer.py b/longstream/utils/vendor/models/components/layers/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7ac5e151c5e12671120f4a72d050776b4bb165 --- /dev/null +++ b/longstream/utils/vendor/models/components/layers/vision_transformer.py @@ -0,0 +1,437 @@ +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . import ( + Mlp, + PatchEmbed, + SwiGLUFFNFused, + MemEffAttention, + NestedTensorBlock as Block, +) + +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.use_checkpoint = False + + self.num_features = self.embed_dim = embed_dim + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + + chunked_blocks.append( + [nn.Identity()] * i + blocks_list[i : i + chunksize] + ) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) + assert N == M * M + kwargs = {} + if self.interpolate_offset: + + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [ + self.prepare_tokens_with_masks(x, masks) + for x, masks in zip(x_list, masks_list) + ] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + + output, total_block_len = [], len(self.blocks) + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + + blocks_to_take = ( + range(total_block_len - n, total_block_len) if isinstance(n, int) else n + ) + for block_chunk in self.blocks: + for blk in block_chunk[i:]: + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len( + blocks_to_take + ), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/longstream/utils/vendor/models/components/utils/geometry.py b/longstream/utils/vendor/models/components/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb95f415048bd4ee47b290013ebaa383c874ca2 --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/geometry.py @@ -0,0 +1,246 @@ +import os +import torch +import numpy as np + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), + extrinsics_cam[frame_idx], + intrinsics_cam[frame_idx], + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). + """ + if depth_map is None: + return None, None, None + + point_mask = depth_map > eps + + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points( + depth_map: np.ndarray, intrinsic: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert ( + intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0 + ), "Intrinsic matrix must have zero skew" + + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False): + """ + Normalize the depth by the average depth of non-zero depth pixels. + Compatible with MapAnything's implementation. + + Args: + depth (torch.Tensor): Depth tensor of size [B, H, W, 1]. + return_norm_factor (bool): Whether to return the normalization factor. + + Returns: + normalized_depth (torch.Tensor): Normalized depth tensor. + norm_factor (torch.Tensor): Norm factor tensor of size B (if return_norm_factor=True). + """ + assert depth.ndim == 4 and depth.shape[3] == 1 + + valid_depth_mask = depth > 0 + valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3)) + valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3)) + + norm_factor = valid_sum / (valid_count + 1e-8) + while norm_factor.ndim < depth.ndim: + norm_factor.unsqueeze_(-1) + + norm_factor = norm_factor.clip(min=1e-8) + normalized_depth = depth / norm_factor + + output = ( + (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1)) + if return_norm_factor + else normalized_depth + ) + + return output + + +def normalize_pose_translations(pose_translations, return_norm_factor=False): + """ + Normalize the pose translations by the average norm of the non-zero pose translations. + Compatible with MapAnything's implementation. + + Args: + pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. + B is the batch size, V is the number of views. + return_norm_factor (bool): Whether to return the normalization factor. + + Returns: + normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3]. + norm_factor (torch.Tensor): Norm factor tensor of size B (if return_norm_factor=True). + """ + assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3 + + pose_translations_dis = pose_translations.norm(dim=-1) + non_zero_pose_translations_dis = pose_translations_dis > 0 + + sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) + count_of_all_views_with_non_zero_pose_translations = ( + non_zero_pose_translations_dis.sum(dim=1) + ) + norm_factor = sum_of_all_views_pose_translations / ( + count_of_all_views_with_non_zero_pose_translations + 1e-8 + ) + + norm_factor = norm_factor.clip(min=1e-8) + normalized_pose_translations = pose_translations / norm_factor.unsqueeze( + -1 + ).unsqueeze(-1) + + output = ( + (normalized_pose_translations, norm_factor) + if return_norm_factor + else normalized_pose_translations + ) + + return output + + +def apply_log_to_norm(input_data): + """ + Normalize the input data and apply a logarithmic transformation based on the normalization factor. + Compatible with MapAnything's implementation. + + Args: + input_data (torch.Tensor): The input tensor to be normalized and transformed. + + Returns: + torch.Tensor: The transformed tensor after normalization and logarithmic scaling. + """ + org_d = input_data.norm(dim=-1, keepdim=True) + input_data = input_data / org_d.clip(min=1e-8) + input_data = input_data * torch.log1p(org_d) + return input_data + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. + + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + + is_numpy = isinstance(se3, np.ndarray) + + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + if R is None: + R = se3[:, :3, :3] + if T is None: + T = se3[:, :3, 3:] + + if is_numpy: + + R_transposed = np.transpose(R, (0, 2, 1)) + + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) + top_right = -torch.bmm(R_transposed, T) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix diff --git a/longstream/utils/vendor/models/components/utils/load_fn.py b/longstream/utils/vendor/models/components/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..6138f3408550592e8dc55c612932991d0afd3d8a --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/load_fn.py @@ -0,0 +1,133 @@ +import torch +from PIL import Image +from torchvision import transforms as TF + + +def load_and_preprocess_images(image_path_list, mode="crop"): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = 518 + + for image_path in image_path_list: + + img = Image.open(image_path) + + if img.mode == "RGBA": + + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + + img = Image.alpha_composite(background, img) + + img = img.convert("RGB") + + width, height = img.size + + if mode == "pad": + + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / 14) * 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / 14) * 14 + else: + + new_width = target_size + + new_height = round(height * (new_width / width) / 14) * 14 + + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) + + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + + if mode == "pad": + h_padding = target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, + (pad_left, pad_right, pad_top, pad_bottom), + mode="constant", + value=1.0, + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) + + if len(image_path_list) == 1: + + if images.dim() == 3: + images = images.unsqueeze(0) + + return images diff --git a/longstream/utils/vendor/models/components/utils/pose_enc.py b/longstream/utils/vendor/models/components/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba32a329197d6ff140668e284e84839ea80de24 --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/pose_enc.py @@ -0,0 +1,151 @@ +import torch +from .rotation import quat_to_mat, mat_to_quat + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, + pose_encoding_type="absT_quaR_FoV", + gt_pts3d_scale=None, +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] + T = extrinsics[:, :, :3, 3] + + quat = mat_to_quat(R) + + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat( + [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1 + ).float() + elif pose_encoding_type == "relT_quaR_FoV": + R = extrinsics[:, :, :3, :3] + T = extrinsics[:, :, :3, 3] + + T = T / gt_pts3d_scale.view(-1, 1, 1) + + quat = mat_to_quat(R) + + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat( + [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1 + ).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, + pose_encoding_type="absT_quaR_FoV", + build_intrinsics=True, +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + f_default = max(float(H), float(W)) / 2.0 + default_fov_h = 2 * torch.atan( + torch.tensor(float(H) / (2.0 * f_default), device=pose_encoding.device) + ) + default_fov_w = 2 * torch.atan( + torch.tensor(float(W) / (2.0 * f_default), device=pose_encoding.device) + ) + invalid_h = (~torch.isfinite(fov_h)) | (torch.abs(fov_h) < 1e-6) + invalid_w = (~torch.isfinite(fov_w)) | (torch.abs(fov_w) < 1e-6) + fov_h = torch.where(invalid_h, default_fov_h, fov_h) + fov_w = torch.where(invalid_w, default_fov_w, fov_w) + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros( + pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device + ) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/longstream/utils/vendor/models/components/utils/prope.py b/longstream/utils/vendor/models/components/utils/prope.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4a3e0375b1b264f24842783aa3d5daab10c600 --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/prope.py @@ -0,0 +1,392 @@ +from functools import partial +from typing import Callable, Optional, Tuple, List + +import torch +import torch.nn.functional as F + + +class PropeDotProductAttention(torch.nn.Module): + """PRoPE attention with precomputed RoPE coefficients.""" + + coeffs_x_0: torch.Tensor + coeffs_x_1: torch.Tensor + coeffs_y_0: torch.Tensor + coeffs_y_1: torch.Tensor + + def __init__( + self, + head_dim: int, + patches_x: int, + patches_y: int, + image_width: int, + image_height: int, + freq_base: float = 100.0, + freq_scale: float = 1.0, + ): + super().__init__() + self.head_dim = head_dim + self.patches_x = patches_x + self.patches_y = patches_y + self.image_width = image_width + self.image_height = image_height + + coeffs_x: Tuple[torch.Tensor, torch.Tensor] = _rope_precompute_coeffs( + torch.tile(torch.arange(patches_x), (patches_y,)), + freq_base=freq_base, + freq_scale=freq_scale, + feat_dim=head_dim // 4, + ) + coeffs_y: Tuple[torch.Tensor, torch.Tensor] = _rope_precompute_coeffs( + torch.repeat_interleave(torch.arange(patches_y), patches_x), + freq_base=freq_base, + freq_scale=freq_scale, + feat_dim=head_dim // 4, + ) + + self.register_buffer("coeffs_x_0", coeffs_x[0], persistent=False) + self.register_buffer("coeffs_x_1", coeffs_x[1], persistent=False) + self.register_buffer("coeffs_y_0", coeffs_y[0], persistent=False) + self.register_buffer("coeffs_y_1", coeffs_y[1], persistent=False) + + def load_state_dict(self, state_dict, strict=True): + + state_dict.pop("coeffs_x_0", None) + state_dict.pop("coeffs_x_1", None) + state_dict.pop("coeffs_y_0", None) + state_dict.pop("coeffs_y_1", None) + super().load_state_dict(state_dict, strict) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + viewmats: torch.Tensor, + Ks: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + return prope_dot_product_attention( + q, + k, + v, + viewmats=viewmats, + Ks=Ks, + patches_x=self.patches_x, + patches_y=self.patches_y, + image_width=self.image_width, + image_height=self.image_height, + coeffs_x=(self.coeffs_x_0, self.coeffs_x_1), + coeffs_y=(self.coeffs_y_0, self.coeffs_y_1), + **kwargs, + ) + + def _precompute_and_cache_apply_fns( + self, viewmats: torch.Tensor, Ks: Optional[torch.Tensor] + ): + (batch, cameras, _, _) = viewmats.shape + assert viewmats.shape == (batch, cameras, 4, 4) + assert Ks is None or Ks.shape == (batch, cameras, 3, 3) + self.cameras = cameras + + self.apply_fn_q, self.apply_fn_kv, self.apply_fn_o = _prepare_apply_fns( + head_dim=self.head_dim, + viewmats=viewmats, + Ks=Ks, + patches_x=self.patches_x, + patches_y=self.patches_y, + image_width=self.image_width, + image_height=self.image_height, + coeffs_x=(self.coeffs_x_0, self.coeffs_x_1), + coeffs_y=(self.coeffs_y_0, self.coeffs_y_1), + ) + + def _apply_to_q(self, q: torch.Tensor) -> torch.Tensor: + (batch, num_heads, seqlen, head_dim) = q.shape + assert seqlen == self.cameras * self.patches_x * self.patches_y + assert head_dim == self.head_dim + assert q.shape == (batch, num_heads, seqlen, head_dim) + assert self.apply_fn_q is not None + return self.apply_fn_q(q) + + def _apply_to_kv(self, kv: torch.Tensor) -> torch.Tensor: + (batch, num_heads, seqlen, head_dim) = kv.shape + assert seqlen == self.cameras * self.patches_x * self.patches_y + assert head_dim == self.head_dim + assert kv.shape == (batch, num_heads, seqlen, head_dim) + assert self.apply_fn_kv is not None + return self.apply_fn_kv(kv) + + def _apply_to_o(self, o: torch.Tensor) -> torch.Tensor: + (batch, num_heads, seqlen, head_dim) = o.shape + assert seqlen == self.cameras * self.patches_x * self.patches_y + assert head_dim == self.head_dim + assert o.shape == (batch, num_heads, seqlen, head_dim) + assert self.apply_fn_o is not None + return self.apply_fn_o(o) + + +def prope_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + viewmats: torch.Tensor, + Ks: Optional[torch.Tensor], + patches_x: int, + patches_y: int, + image_width: int, + image_height: int, + coeffs_x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + coeffs_y: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, +) -> torch.Tensor: + """Similar to torch.nn.functional.scaled_dot_product_attention, but applies PRoPE-style + positional encoding. + + Currently, we assume that the sequence length is equal to: + + cameras * patches_x * patches_y + + And token ordering allows the `(seqlen,)` axis to be reshaped into + `(cameras, patches_x, patches_y)`. + """ + + (batch, num_heads, seqlen, head_dim) = q.shape + cameras = viewmats.shape[1] + assert q.shape == k.shape == v.shape + assert viewmats.shape == (batch, cameras, 4, 4) + assert Ks is None or Ks.shape == (batch, cameras, 3, 3) + assert seqlen == cameras * patches_x * patches_y + + apply_fn_q, apply_fn_kv, apply_fn_o = _prepare_apply_fns( + head_dim=head_dim, + viewmats=viewmats, + Ks=Ks, + patches_x=patches_x, + patches_y=patches_y, + image_width=image_width, + image_height=image_height, + coeffs_x=coeffs_x, + coeffs_y=coeffs_y, + ) + + out = F.scaled_dot_product_attention( + query=apply_fn_q(q), + key=apply_fn_kv(k), + value=apply_fn_kv(v), + **kwargs, + ) + out = apply_fn_o(out) + assert out.shape == (batch, num_heads, seqlen, head_dim) + return out + + +def _prepare_apply_fns( + head_dim: int, + viewmats: torch.Tensor, + Ks: Optional[torch.Tensor], + patches_x: int, + patches_y: int, + image_width: int, + image_height: int, + coeffs_x: Optional[torch.Tensor] = None, + coeffs_y: Optional[torch.Tensor] = None, +) -> Tuple[ + Callable[[torch.Tensor], torch.Tensor], + Callable[[torch.Tensor], torch.Tensor], + Callable[[torch.Tensor], torch.Tensor], +]: + """Prepare transforms for PRoPE-style positional encoding.""" + device = viewmats.device + (batch, cameras, _, _) = viewmats.shape + + if Ks is not None: + Ks_norm = torch.zeros_like(Ks) + Ks_norm[..., 0, 0] = Ks[..., 0, 0] / image_width + Ks_norm[..., 1, 1] = Ks[..., 1, 1] / image_height + Ks_norm[..., 0, 2] = Ks[..., 0, 2] / image_width - 0.5 + Ks_norm[..., 1, 2] = Ks[..., 1, 2] / image_height - 0.5 + Ks_norm[..., 2, 2] = 1.0 + del Ks + + P = torch.einsum("...ij,...jk->...ik", _lift_K(Ks_norm), viewmats) + P_T = P.transpose(-1, -2) + P_inv = torch.einsum( + "...ij,...jk->...ik", + _invert_SE3(viewmats), + _lift_K(_invert_K(Ks_norm)), + ) + + else: + + P = viewmats + P_T = P.transpose(-1, -2) + P_inv = _invert_SE3(viewmats) + + assert P.shape == P_inv.shape == (batch, cameras, 4, 4) + + if coeffs_x is None: + coeffs_x = _rope_precompute_coeffs( + torch.tile(torch.arange(patches_x, device=device), (patches_y * cameras,)), + freq_base=100.0, + freq_scale=1.0, + feat_dim=head_dim // 4, + ) + if coeffs_y is None: + coeffs_y = _rope_precompute_coeffs( + torch.tile( + torch.repeat_interleave( + torch.arange(patches_y, device=device), patches_x + ), + (cameras,), + ), + freq_base=100.0, + freq_scale=1.0, + feat_dim=head_dim // 4, + ) + + assert head_dim % 4 == 0 + transforms_q = [ + (partial(_apply_tiled_projmat, matrix=P_T), head_dim // 2), + (partial(_rope_apply_coeffs, coeffs=coeffs_x), head_dim // 4), + (partial(_rope_apply_coeffs, coeffs=coeffs_y), head_dim // 4), + ] + transforms_kv = [ + (partial(_apply_tiled_projmat, matrix=P_inv), head_dim // 2), + (partial(_rope_apply_coeffs, coeffs=coeffs_x), head_dim // 4), + (partial(_rope_apply_coeffs, coeffs=coeffs_y), head_dim // 4), + ] + transforms_o = [ + (partial(_apply_tiled_projmat, matrix=P), head_dim // 2), + (partial(_rope_apply_coeffs, coeffs=coeffs_x, inverse=True), head_dim // 4), + (partial(_rope_apply_coeffs, coeffs=coeffs_y, inverse=True), head_dim // 4), + ] + + apply_fn_q = partial(_apply_block_diagonal, func_size_pairs=transforms_q) + apply_fn_kv = partial(_apply_block_diagonal, func_size_pairs=transforms_kv) + apply_fn_o = partial(_apply_block_diagonal, func_size_pairs=transforms_o) + return apply_fn_q, apply_fn_kv, apply_fn_o + + +def _apply_tiled_projmat( + feats: torch.Tensor, + matrix: torch.Tensor, +) -> torch.Tensor: + """Apply projection matrix to features.""" + + (batch, num_heads, seqlen, feat_dim) = feats.shape + cameras = matrix.shape[1] + assert seqlen > cameras and seqlen % cameras == 0 + D = matrix.shape[-1] + assert matrix.shape == (batch, cameras, D, D) + assert feat_dim % D == 0 + return torch.einsum( + "bcij,bncpkj->bncpki", + matrix, + feats.reshape((batch, num_heads, cameras, -1, feat_dim // D, D)), + ).reshape(feats.shape) + + +def _rope_precompute_coeffs( + positions: torch.Tensor, + freq_base: float, + freq_scale: float, + feat_dim: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Precompute RoPE coefficients.""" + assert len(positions.shape) == 1 + assert feat_dim % 2 == 0 + num_freqs = feat_dim // 2 + freqs = freq_scale * ( + freq_base + ** ( + -torch.arange(num_freqs, device=positions.device)[None, None, None, :] + / num_freqs + ) + ) + angles = positions[None, None, :, None] * freqs + + assert angles.shape == (1, 1, positions.shape[0], num_freqs) + return torch.cos(angles), torch.sin(angles) + + +def _rope_apply_coeffs( + feats: torch.Tensor, + coeffs: Tuple[torch.Tensor, torch.Tensor], + inverse: bool = False, +) -> torch.Tensor: + """Apply RoPE coefficients to features. We adopt a 'split' ordering + convention. (in contrast to 'interleaved')""" + cos, sin = coeffs + + if cos.shape[2] != feats.shape[2]: + n_repeats = feats.shape[2] // cos.shape[2] + cos = cos.repeat(1, 1, n_repeats, 1) + sin = sin.repeat(1, 1, n_repeats, 1) + assert len(feats.shape) == len(cos.shape) == len(sin.shape) == 4 + assert cos.shape[-1] == sin.shape[-1] == feats.shape[-1] // 2 + x_in = feats[..., : feats.shape[-1] // 2] + y_in = feats[..., feats.shape[-1] // 2 :] + return torch.cat( + ( + [cos * x_in + sin * y_in, -sin * x_in + cos * y_in] + if not inverse + else [cos * x_in - sin * y_in, sin * x_in + cos * y_in] + ), + dim=-1, + ) + + +def _apply_block_diagonal( + feats: torch.Tensor, + func_size_pairs: List[Tuple[Callable[[torch.Tensor], torch.Tensor], int]], +) -> torch.Tensor: + """Apply a block-diagonal function to an input array. + + Each function is specified as a tuple with form: + + ((Tensor) -> Tensor, int) + + Where the integer is the size of the input to the function. + """ + funcs, block_sizes = zip(*func_size_pairs) + assert feats.shape[-1] == sum(block_sizes) + x_blocks = torch.split(feats, block_sizes, dim=-1) + out = torch.cat( + [f(x_block) for f, x_block in zip(funcs, x_blocks)], + dim=-1, + ) + assert out.shape == feats.shape, "Input/output shapes should match." + return out + + +def _invert_SE3(transforms: torch.Tensor) -> torch.Tensor: + """Invert a 4x4 SE(3) matrix.""" + assert transforms.shape[-2:] == (4, 4) + Rinv = transforms[..., :3, :3].transpose(-1, -2) + out = torch.zeros_like(transforms) + out[..., :3, :3] = Rinv + out[..., :3, 3] = -torch.einsum("...ij,...j->...i", Rinv, transforms[..., :3, 3]) + out[..., 3, 3] = 1.0 + return out + + +def _lift_K(Ks: torch.Tensor) -> torch.Tensor: + """Lift 3x3 matrices to homogeneous 4x4 matrices.""" + assert Ks.shape[-2:] == (3, 3) + out = torch.zeros(Ks.shape[:-2] + (4, 4), device=Ks.device) + out[..., :3, :3] = Ks + out[..., 3, 3] = 1.0 + return out + + +def _invert_K(Ks: torch.Tensor) -> torch.Tensor: + """Invert 3x3 intrinsics matrices. Assumes no skew.""" + assert Ks.shape[-2:] == (3, 3) + out = torch.zeros_like(Ks) + out[..., 0, 0] = 1.0 / Ks[..., 0, 0] + out[..., 1, 1] = 1.0 / Ks[..., 1, 1] + out[..., 0, 2] = -Ks[..., 0, 2] / Ks[..., 0, 0] + out[..., 1, 2] = -Ks[..., 1, 2] / Ks[..., 1, 1] + out[..., 2, 2] = 1.0 + return out diff --git a/longstream/utils/vendor/models/components/utils/rotation.py b/longstream/utils/vendor/models/components/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1770430d3e859389c0832021f53eb8292b54a0 --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/rotation.py @@ -0,0 +1,120 @@ +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/longstream/utils/vendor/models/components/utils/se3.py b/longstream/utils/vendor/models/components/utils/se3.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4f419aad7db5e235e939df36acce0d4d9859b6 --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/se3.py @@ -0,0 +1,138 @@ +from typing import Tuple + +import torch + + +def _batch_eye( + batch_shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + eye = torch.eye(3, dtype=dtype, device=device) + view_shape = (1,) * len(batch_shape) + (3, 3) + return eye.view(view_shape).expand(batch_shape + (3, 3)).clone() + + +def _skew(w: torch.Tensor) -> torch.Tensor: + wx, wy, wz = w.unbind(dim=-1) + zeros = torch.zeros_like(wx) + return torch.stack( + ( + torch.stack((zeros, -wz, wy), dim=-1), + torch.stack((wz, zeros, -wx), dim=-1), + torch.stack((-wy, wx, zeros), dim=-1), + ), + dim=-2, + ) + + +def _taylor_A(theta: torch.Tensor) -> torch.Tensor: + theta2 = theta * theta + return torch.where( + theta2 > 1e-12, + torch.sin(theta) / theta, + 1.0 - theta2 / 6.0 + theta2 * theta2 / 120.0, + ) + + +def _taylor_B(theta: torch.Tensor) -> torch.Tensor: + theta2 = theta * theta + return torch.where( + theta2 > 1e-12, + (1.0 - torch.cos(theta)) / theta2, + 0.5 - theta2 / 24.0 + theta2 * theta2 / 720.0, + ) + + +def _taylor_C(theta: torch.Tensor) -> torch.Tensor: + theta2 = theta * theta + return torch.where( + theta2 > 1e-12, + (theta - torch.sin(theta)) / (theta2 * theta), + 1.0 / 6.0 - theta2 / 120.0 + theta2 * theta2 / 5040.0, + ) + + +def se3_exp(xi: torch.Tensor) -> torch.Tensor: + v = xi[..., :3] + w = xi[..., 3:] + theta = torch.linalg.norm(w, dim=-1, keepdim=True) + batch_shape = xi.shape[:-1] + + theta_safe = torch.where(theta > 1e-9, theta, torch.ones_like(theta)) + w_hat = torch.where(theta > 1e-9, w / theta_safe, torch.zeros_like(w)) + W = _skew(w_hat) + + A = _taylor_A(theta)[..., None] + B = _taylor_B(theta)[..., None] + C = _taylor_C(theta)[..., None] + + eye = _batch_eye(batch_shape, dtype=xi.dtype, device=xi.device) + W2 = W @ W + R = eye + A * W + B * W2 + V = eye + B * W + C * W2 + t = (V @ v.unsqueeze(-1)).squeeze(-1) + + T = torch.zeros(*batch_shape, 4, 4, dtype=xi.dtype, device=xi.device) + T[..., :3, :3] = R + T[..., :3, 3] = t + T[..., 3, 3] = 1.0 + return T + + +def _rotation_log(R: torch.Tensor) -> torch.Tensor: + from longstream.utils.vendor.dust3r.utils.camera import matrix_to_quaternion + + q = matrix_to_quaternion(R) + qw = q[..., 0] + q_xyz = q[..., 1:] + sin_half = torch.linalg.norm(q_xyz, dim=-1, keepdim=True) + theta = 2.0 * torch.atan2(sin_half.squeeze(-1), qw.clamp(min=1e-9)).unsqueeze(-1) + + mask = sin_half > 1e-7 + axis = torch.zeros_like(q_xyz) + axis[mask] = q_xyz[mask] / sin_half[mask] + + w = torch.zeros_like(q_xyz) + w[mask] = axis[mask] * theta[mask] + w[~mask] = 2.0 * q_xyz[~mask] + return w + + +def se3_log(T: torch.Tensor) -> torch.Tensor: + R = T[..., :3, :3] + t = T[..., :3, 3] + w = _rotation_log(R) + theta = torch.linalg.norm(w, dim=-1, keepdim=True) + batch_shape = w.shape[:-1] + + theta_safe = torch.where(theta > 1e-9, theta, torch.ones_like(theta)) + w_hat = torch.where(theta > 1e-9, w / theta_safe, torch.zeros_like(w)) + W = _skew(w_hat) + + B = _taylor_B(theta)[..., None] + C = _taylor_C(theta)[..., None] + eye = _batch_eye(batch_shape, dtype=T.dtype, device=T.device) + V = eye + B * W + C * (W @ W) + + v = torch.linalg.solve(V, t.unsqueeze(-1)).squeeze(-1) + return torch.cat((v, w), dim=-1) + + +def compose(T1: torch.Tensor, T2: torch.Tensor) -> torch.Tensor: + return T1 @ T2 + + +def inverse(T: torch.Tensor) -> torch.Tensor: + R = T[..., :3, :3] + t = T[..., :3, 3] + Rt = R.transpose(-1, -2) + out = torch.zeros_like(T) + out[..., :3, :3] = Rt + out[..., :3, 3] = -(Rt @ t.unsqueeze(-1)).squeeze(-1) + out[..., 3, 3] = 1.0 + return out + + +def identity(batch_shape: Tuple[int, ...], device=None, dtype=None) -> torch.Tensor: + eye4 = torch.eye(4, dtype=dtype or torch.float32, device=device) + view_shape = (1,) * len(batch_shape) + (4, 4) + return eye4.view(view_shape).expand(batch_shape + (4, 4)).clone() diff --git a/longstream/utils/vendor/models/components/utils/visual_track.py b/longstream/utils/vendor/models/components/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..371bfb440de692fe29295d64891a88d8f2abf0ac --- /dev/null +++ b/longstream/utils/vendor/models/components/utils/visual_track.py @@ -0,0 +1,217 @@ +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) + + +def get_track_colors_by_position( + tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv" +): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. + """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + + x, y = tracks_b[first_s, i].tolist() + + r, g, b = color_from_xy( + x, y, W=image_width, H=image_height, cmap_name=cmap_name + ) + + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", + normalize_mode="[0,1]", + cmap_name="hsv", + frames_per_row=4, + save_grid=True, +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. + + Returns: + None (saves images in out_dir). + """ + + if len(tracks.shape) == 4: + tracks = tracks.squeeze(0) + images = images.squeeze(0) + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.squeeze(0) + + import matplotlib + + matplotlib.use("Agg") + + os.makedirs(out_dir, exist_ok=True) + + S = images.shape[0] + _, N, _ = tracks.shape + + images = images.cpu().clone() + tracks = tracks.cpu().clone() + if track_vis_mask is not None: + track_vis_mask = track_vis_mask.cpu().clone() + + if image_format == "CHW": + + H, W = images.shape[2], images.shape[3] + else: + + H, W = images.shape[1], images.shape[2] + + track_colors_rgb = get_track_colors_by_position( + tracks, + vis_mask_b=track_vis_mask if track_vis_mask is not None else None, + image_width=W, + image_height=H, + cmap_name=cmap_name, + ) + + frame_images = [] + + for s in range(S): + + img = images[s] + + if image_format == "CHW": + img = img.permute(1, 2, 0) + + img = img.numpy().astype(np.float32) + + if normalize_mode == "[0,1]": + img = np.clip(img, 0, 1) * 255.0 + elif normalize_mode == "[-1,1]": + img = (img + 1.0) * 0.5 * 255.0 + img = np.clip(img, 0, 255.0) + + img = img.astype(np.uint8) + + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cur_tracks = tracks[s] + if track_vis_mask is not None: + valid_indices = torch.where(track_vis_mask[s])[0] + else: + valid_indices = range(N) + + cur_tracks_np = cur_tracks.numpy() + for i in valid_indices: + x, y = cur_tracks_np[i] + pt = (int(round(x)), int(round(y))) + + R, G, B = track_colors_rgb[i] + color_bgr = (int(B), int(G), int(R)) + cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) + + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") + + frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(frame_path, frame_bgr) + + frame_images.append(img_rgb) + + if save_grid: + + num_rows = (S + frames_per_row - 1) // frames_per_row + + grid_img = None + for row in range(num_rows): + start_idx = row * frames_per_row + end_idx = min(start_idx + frames_per_row, S) + + row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) + + if end_idx - start_idx < frames_per_row: + padding_width = (frames_per_row - (end_idx - start_idx)) * W + padding = np.zeros((H, padding_width, 3), dtype=np.uint8) + row_img = np.concatenate([row_img, padding], axis=1) + + if grid_img is None: + grid_img = row_img + else: + grid_img = np.concatenate([grid_img, row_img], axis=0) + + out_path = os.path.join(out_dir, "tracks_grid.png") + + grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, grid_img_bgr) + print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") + + print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png") diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..430a0cb3dfdb1c17d64f6f7a86d99b7fad100299 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +torch +torchvision +numpy +Pillow +PyYAML +huggingface_hub==0.34.4 +opencv-python-headless +onnxruntime +matplotlib +scipy +gradio==5.44.0 +plotly +trimesh