# NOTE: HF Spaces page residue ("Spaces: Running on Zero") kept here as a
# comment so the file remains valid Python.
| import argparse | |
| import os | |
| import yaml | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from longstream.core.model import LongStreamModel | |
| from longstream.data.dataloader import LongStreamDataLoader | |
| from longstream.streaming.keyframe_selector import KeyframeSelector | |
| from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh | |
| from longstream.utils.vendor.models.components.utils.pose_enc import ( | |
| pose_encoding_to_extri_intri, | |
| ) | |
| from longstream.utils.camera import compose_abs_from_rel | |
| from longstream.utils.depth import colorize_depth, unproject_depth_to_points | |
| from longstream.utils.sky_mask import compute_sky_mask | |
| from longstream.io.save_points import save_pointcloud | |
| from longstream.io.save_poses_txt import save_w2c_txt, save_intri_txt, save_rel_pose_txt | |
| from longstream.io.save_images import save_image_sequence, save_video | |
| def _to_uint8_rgb(images): | |
| imgs = images.detach().cpu().numpy() | |
| imgs = np.clip(imgs, 0.0, 1.0) | |
| imgs = (imgs * 255.0).astype(np.uint8) | |
| return imgs | |
def _ensure_dir(path):
    """Create *path* (including missing parents); no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
| def _apply_sky_mask(depth, mask): | |
| if mask is None: | |
| return depth | |
| m = (mask > 0).astype(np.float32) | |
| return depth * m | |
| def _camera_points_to_world(points, extri): | |
| pts = np.asarray(points, dtype=np.float64).reshape(-1, 3) | |
| R = np.asarray(extri[:3, :3], dtype=np.float64) | |
| t = np.asarray(extri[:3, 3], dtype=np.float64) | |
| world = (R.T @ (pts.T - t[:, None])).T | |
| return world.astype(np.float32, copy=False) | |
| def _mask_points_and_colors(points, colors, mask): | |
| pts = points.reshape(-1, 3) | |
| cols = None if colors is None else colors.reshape(-1, 3) | |
| if mask is None: | |
| return pts, cols | |
| valid = mask.reshape(-1) > 0 | |
| pts = pts[valid] | |
| if cols is not None: | |
| cols = cols[valid] | |
| return pts, cols | |
def _resize_long_edge(arr, long_edge_size, interpolation):
    """Resize *arr* so its longer side equals ``long_edge_size``.

    Aspect ratio is preserved.  Both output sides are clamped to at least
    1 pixel: for very elongated inputs the short side could otherwise
    round down to 0, which makes ``cv2.resize`` raise.

    Args:
        arr: HxW or HxWxC image array.
        long_edge_size: Target length of the longer side, in pixels.
        interpolation: OpenCV interpolation flag (e.g. ``cv2.INTER_NEAREST``).
    """
    h, w = arr.shape[:2]
    scale = float(long_edge_size) / float(max(h, w))
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    # cv2.resize takes (width, height), i.e. the reverse of shape order.
    return cv2.resize(arr, (new_w, new_h), interpolation=interpolation)
def _prepare_mask_for_model(
    mask, size, crop, patch_size, target_shape, square_ok=False
):
    """Resize/crop a sky mask so it aligns with the model's input frames.

    Mirrors the image preprocessing implied by ``size``/``crop``/
    ``patch_size`` (dust3r-style 224 square vs. patch-aligned long-edge
    resizing — NOTE(review): assumed to match the dataloader's image
    pipeline; confirm against ``LongStreamDataLoader``), then forces the
    result to ``target_shape`` (H, W).  All resizes use nearest-neighbor
    so mask values stay binary.  Returns ``None`` for a ``None`` input.
    """
    if mask is None:
        return None
    # For the 224 path, resize so the SHORT edge lands at `size` (long edge
    # is scaled up by the aspect ratio); otherwise the long edge is `size`.
    long_edge = (
        round(size * max(mask.shape[1] / mask.shape[0], mask.shape[0] / mask.shape[1]))
        if size == 224
        else size
    )
    mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)
    h, w = mask.shape[:2]
    cx, cy = w // 2, h // 2
    if size == 224:
        # Square output: center-crop (or squash-resize) to 2*half x 2*half.
        half = min(cx, cy)
        target_w = 2 * half
        target_h = 2 * half
        if crop:
            mask = mask[cy - half : cy + half, cx - half : cx + half]
        else:
            mask = cv2.resize(
                mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
            )
    else:
        # Snap each side down to a multiple of `patch_size`.
        halfw = ((2 * cx) // patch_size) * (patch_size // 2)
        halfh = ((2 * cy) // patch_size) * (patch_size // 2)
        if not square_ok and w == h:
            # Square inputs are forced to a 4:3 shape unless explicitly allowed.
            halfh = int(3 * halfw / 4)
        target_w = 2 * halfw
        target_h = 2 * halfh
        if crop:
            mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
        else:
            mask = cv2.resize(
                mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
            )
    # Final safety net: force exact agreement with the model frame shape.
    if mask.shape[:2] != tuple(target_shape):
        mask = cv2.resize(
            mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
        )
    return mask
def _save_full_pointcloud(path, point_chunks, color_chunks, max_points=None, seed=0):
    """Concatenate per-frame point chunks and save one combined point cloud.

    Writes two artifacts: a raw ``.npy`` dump of the (possibly subsampled)
    float32 points next to ``path``, and the point-cloud file produced by
    :func:`save_pointcloud` at ``path`` itself.

    Args:
        path: Destination point-cloud path; its extension is swapped for
            ``.npy`` for the raw array dump.
        point_chunks: List of ``(Ni, 3)`` arrays; nothing is written when empty.
        color_chunks: Matching list of ``(Ni, 3)`` color arrays.  Colors are
            only used when the list length matches ``point_chunks`` and no
            entry is ``None``.
        max_points: Optional cap; larger clouds are randomly subsampled.
        seed: Seed for the deterministic subsampling draw.
    """
    if not point_chunks:
        return
    points = np.concatenate(point_chunks, axis=0)
    colors = None
    # Only attach colors when every point chunk has a real color chunk; a
    # stray None entry would make np.concatenate raise.
    if (
        color_chunks
        and len(color_chunks) == len(point_chunks)
        and all(c is not None for c in color_chunks)
    ):
        colors = np.concatenate(color_chunks, axis=0)
    if max_points is not None and len(points) > max_points:
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(points), size=max_points, replace=False)
        points = points[keep]
        if colors is not None:
            colors = colors[keep]
    np.save(os.path.splitext(path)[0] + ".npy", points.astype(np.float32, copy=False))
    # Subsampling already happened above, so disable it in save_pointcloud.
    save_pointcloud(path, points, colors=colors, max_points=None, seed=seed)
def run_inference_cfg(cfg: dict):
    """Run LongStream inference and export artifacts for every sequence.

    ``cfg`` holds up to four sub-dicts — ``model``, ``data``, ``inference``
    and ``output`` — plus an optional top-level ``device``.  Each sequence
    yielded by the data loader is processed with either batch-refresh or
    streaming-refresh inference, then the requested artifacts (pose text
    files, RGB images/videos, depth maps, point clouds) are written under
    ``<output.root>/<sequence name>/``.
    """
    # Pick the compute device; default to CUDA when available.
    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    device_type = torch.device(device).type
    model_cfg = cfg.get("model", {})
    data_cfg = cfg.get("data", {})
    infer_cfg = cfg.get("inference", {})
    output_cfg = cfg.get("output", {})
    print(f"[longstream] device={device}", flush=True)
    model = LongStreamModel(model_cfg).to(device)
    model.eval()
    print("[longstream] model ready", flush=True)
    loader = LongStreamDataLoader(data_cfg)
    # --- Keyframe / refresh scheduling -----------------------------------
    keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
    keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
    # `refresh` counts keyframes per refresh chunk INCLUDING both endpoint
    # keyframes (hence the >= 2 check).  The legacy `keyframes_per_batch`
    # key (interior keyframes only) is honored as a fallback default.
    refresh = int(
        infer_cfg.get("refresh", int(infer_cfg.get("keyframes_per_batch", 3)) + 1)
    )
    if refresh < 2:
        raise ValueError(
            "refresh must be >= 2 because it counts both keyframe endpoints"
        )
    mode = infer_cfg.get("mode", "streaming_refresh")
    # Back-compat alias: plain "streaming" means "streaming_refresh".
    if mode == "streaming":
        mode = "streaming_refresh"
    streaming_mode = infer_cfg.get("streaming_mode", "causal")
    window_size = int(infer_cfg.get("window_size", 5))
    # min_interval == max_interval yields evenly spaced keyframes at
    # `keyframe_stride`, unless "random" selection is requested.
    selector = KeyframeSelector(
        min_interval=keyframe_stride,
        max_interval=keyframe_stride,
        force_first=True,
        mode="random" if keyframe_mode == "random" else "fixed",
    )
    out_root = output_cfg.get("root", "outputs")
    _ensure_dir(out_root)
    # --- Output toggles (everything is saved by default) -----------------
    save_videos = bool(output_cfg.get("save_videos", True))
    save_points = bool(output_cfg.get("save_points", True))
    save_frame_points = bool(output_cfg.get("save_frame_points", True))
    save_depth = bool(output_cfg.get("save_depth", True))
    save_images = bool(output_cfg.get("save_images", True))
    mask_sky = bool(output_cfg.get("mask_sky", True))
    max_full_pointcloud_points = output_cfg.get("max_full_pointcloud_points", None)
    if max_full_pointcloud_points is not None:
        max_full_pointcloud_points = int(max_full_pointcloud_points)
    max_frame_pointcloud_points = output_cfg.get("max_frame_pointcloud_points", None)
    if max_frame_pointcloud_points is not None:
        max_frame_pointcloud_points = int(max_frame_pointcloud_points)
    # ONNX sky-segmentation model; the default assumes skyseg.onnx sits two
    # directories above this file — TODO confirm against the repo layout.
    skyseg_path = output_cfg.get(
        "skyseg_path",
        os.path.join(os.path.dirname(__file__), "..", "..", "skyseg.onnx"),
    )
    with torch.no_grad():
        for seq in loader:
            images = seq.images
            # Assumed layout: (batch, frames, channels, height, width) —
            # only batch element 0 is exported below.
            B, S, C, H, W = images.shape
            print(
                f"[longstream] sequence {seq.name}: inference start ({S} frames)",
                flush=True,
            )
            is_keyframe, keyframe_indices = selector.select_keyframes(
                S, B, images.device
            )
            rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})
            # --- Inference -------------------------------------------------
            if mode == "batch_refresh":
                outputs = run_batch_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    keyframe_stride,
                    refresh,
                    rel_pose_cfg,
                )
            elif mode == "streaming_refresh":
                outputs = run_streaming_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    window_size,
                    refresh,
                    rel_pose_cfg,
                )
            else:
                raise ValueError(f"Unsupported inference mode: {mode}")
            print(f"[longstream] sequence {seq.name}: inference done", flush=True)
            if device_type == "cuda":
                torch.cuda.empty_cache()
            seq_dir = os.path.join(out_root, seq.name)
            _ensure_dir(seq_dir)
            frame_ids = list(range(S))
            # HWC uint8 RGB frames of batch element 0; reused for both the
            # image export and the point-cloud colors.
            rgb = _to_uint8_rgb(images[0].permute(0, 2, 3, 1))
            # --- Pose export ----------------------------------------------
            # Prefer relative pose encodings (composed into absolute poses
            # along the keyframe chain); fall back to absolute encodings.
            if "rel_pose_enc" in outputs:
                rel_pose_enc = outputs["rel_pose_enc"][0]
                abs_pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
                extri, intri = pose_encoding_to_extri_intri(
                    abs_pose_enc[None], image_size_hw=(H, W)
                )
                extri_np = extri[0].detach().cpu().numpy()
                intri_np = intri[0].detach().cpu().numpy()
                pose_dir = os.path.join(seq_dir, "poses")
                _ensure_dir(pose_dir)
                save_w2c_txt(
                    os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
                )
                save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)
                save_rel_pose_txt(
                    os.path.join(pose_dir, "rel_pose.txt"), rel_pose_enc, frame_ids
                )
            elif "pose_enc" in outputs:
                pose_enc = outputs["pose_enc"][0]
                extri, intri = pose_encoding_to_extri_intri(
                    pose_enc[None], image_size_hw=(H, W)
                )
                extri_np = extri[0].detach().cpu().numpy()
                intri_np = intri[0].detach().cpu().numpy()
                pose_dir = os.path.join(seq_dir, "poses")
                _ensure_dir(pose_dir)
                save_w2c_txt(
                    os.path.join(pose_dir, "abs_pose.txt"), extri_np, frame_ids
                )
                save_intri_txt(os.path.join(pose_dir, "intri.txt"), intri_np, frame_ids)
            # --- RGB export -----------------------------------------------
            if save_images:
                print(f"[longstream] sequence {seq.name}: saving rgb", flush=True)
                rgb_dir = os.path.join(seq_dir, "images", "rgb")
                save_image_sequence(rgb_dir, list(rgb))
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "images", "rgb.mp4"),
                        os.path.join(rgb_dir, "frame_*.png"),
                    )
            # --- Sky masks ------------------------------------------------
            # Raw masks come from the original image paths, so they must be
            # re-aligned to the model's frame geometry before use.
            sky_masks = None
            if mask_sky:
                raw_sky_masks = compute_sky_mask(
                    seq.image_paths, skyseg_path, os.path.join(seq_dir, "sky_masks")
                )
                if raw_sky_masks is not None:
                    sky_masks = [
                        _prepare_mask_for_model(
                            mask,
                            size=int(data_cfg.get("size", 518)),
                            crop=bool(data_cfg.get("crop", False)),
                            patch_size=int(data_cfg.get("patch_size", 14)),
                            target_shape=(H, W),
                        )
                        for mask in raw_sky_masks
                    ]
            # --- Depth export ---------------------------------------------
            if save_depth and "depth" in outputs:
                print(f"[longstream] sequence {seq.name}: saving depth", flush=True)
                # Drop the trailing singleton channel: (S, H, W).
                depth = outputs["depth"][0, :, :, :, 0].detach().cpu().numpy()
                depth_dir = os.path.join(seq_dir, "depth", "dpt")
                _ensure_dir(depth_dir)
                color_dir = os.path.join(seq_dir, "depth", "dpt_plasma")
                _ensure_dir(color_dir)
                color_frames = []
                for i in range(S):
                    d = depth[i]
                    # Sky pixels are zeroed in both the raw and colorized dumps.
                    if sky_masks is not None and sky_masks[i] is not None:
                        d = _apply_sky_mask(d, sky_masks[i])
                    np.save(os.path.join(depth_dir, f"frame_{i:06d}.npy"), d)
                    colored = colorize_depth(d, cmap="plasma")
                    Image.fromarray(colored).save(
                        os.path.join(color_dir, f"frame_{i:06d}.png")
                    )
                    color_frames.append(colored)
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "depth", "dpt_plasma.mp4"),
                        os.path.join(color_dir, "frame_*.png"),
                    )
            # --- Point-cloud export ---------------------------------------
            if save_points:
                print(
                    f"[longstream] sequence {seq.name}: saving point clouds", flush=True
                )
                # (a) Clouds from the dedicated point head.
                if "world_points" in outputs:
                    # Recompute extrinsics/intrinsics with the same priority
                    # as the pose export above.
                    if "rel_pose_enc" in outputs:
                        abs_pose_enc = compose_abs_from_rel(
                            outputs["rel_pose_enc"][0], keyframe_indices[0]
                        )
                        extri, intri = pose_encoding_to_extri_intri(
                            abs_pose_enc[None], image_size_hw=(H, W)
                        )
                    else:
                        extri, intri = pose_encoding_to_extri_intri(
                            outputs["pose_enc"][0][None], image_size_hw=(H, W)
                        )
                    extri = extri[0]
                    intri = intri[0]
                    pts_dir = os.path.join(seq_dir, "points", "point_head")
                    _ensure_dir(pts_dir)
                    pts = outputs["world_points"][0].detach().cpu().numpy()
                    full_pts = []
                    full_cols = []
                    for i in range(S):
                        # NOTE(review): despite the "world_points" key, the
                        # head output is treated as camera-frame here and
                        # lifted to world with the per-frame extrinsic —
                        # confirm against the point-head definition.
                        pts_world = _camera_points_to_world(
                            pts[i], extri[i].detach().cpu().numpy()
                        )
                        pts_world = pts_world.reshape(pts[i].shape)
                        pts_i, cols_i = _mask_points_and_colors(
                            pts_world,
                            rgb[i],
                            None if sky_masks is None else sky_masks[i],
                        )
                        if save_frame_points:
                            save_pointcloud(
                                os.path.join(pts_dir, f"frame_{i:06d}.ply"),
                                pts_i,
                                colors=cols_i,
                                max_points=max_frame_pointcloud_points,
                                seed=i,
                            )
                        if len(pts_i):
                            full_pts.append(pts_i)
                            full_cols.append(cols_i)
                    _save_full_pointcloud(
                        os.path.join(seq_dir, "points", "point_head_full.ply"),
                        full_pts,
                        full_cols,
                        max_points=max_full_pointcloud_points,
                        seed=0,
                    )
                # (b) Clouds from unprojecting the depth map with the
                # predicted intrinsics/extrinsics.
                if "depth" in outputs and (
                    "rel_pose_enc" in outputs or "pose_enc" in outputs
                ):
                    depth = outputs["depth"][0, :, :, :, 0]
                    if "rel_pose_enc" in outputs:
                        abs_pose_enc = compose_abs_from_rel(
                            outputs["rel_pose_enc"][0], keyframe_indices[0]
                        )
                        extri, intri = pose_encoding_to_extri_intri(
                            abs_pose_enc[None], image_size_hw=(H, W)
                        )
                    else:
                        extri, intri = pose_encoding_to_extri_intri(
                            outputs["pose_enc"][0][None], image_size_hw=(H, W)
                        )
                    extri = extri[0]
                    intri = intri[0]
                    dpt_pts_dir = os.path.join(seq_dir, "points", "dpt_unproj")
                    _ensure_dir(dpt_pts_dir)
                    full_pts = []
                    full_cols = []
                    for i in range(S):
                        d = depth[i]
                        pts_cam = unproject_depth_to_points(d[None], intri[i : i + 1])[
                            0
                        ]
                        # Torch version of the camera->world transform used
                        # by _camera_points_to_world: R^T (p - t).
                        R = extri[i, :3, :3]
                        t = extri[i, :3, 3]
                        pts_world = (
                            R.t() @ (pts_cam.reshape(-1, 3).t() - t[:, None])
                        ).t()
                        pts_world = pts_world.cpu().numpy().reshape(-1, 3)
                        pts_i, cols_i = _mask_points_and_colors(
                            pts_world,
                            rgb[i],
                            None if sky_masks is None else sky_masks[i],
                        )
                        if save_frame_points:
                            save_pointcloud(
                                os.path.join(dpt_pts_dir, f"frame_{i:06d}.ply"),
                                pts_i,
                                colors=cols_i,
                                max_points=max_frame_pointcloud_points,
                                seed=i,
                            )
                        if len(pts_i):
                            full_pts.append(pts_i)
                            full_cols.append(cols_i)
                    _save_full_pointcloud(
                        os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
                        full_pts,
                        full_cols,
                        max_points=max_full_pointcloud_points,
                        seed=1,
                    )
            # Free per-sequence tensors before loading the next sequence.
            del outputs
            if device_type == "cuda":
                torch.cuda.empty_cache()
def run_inference(config_path: str):
    """Load the YAML config at *config_path* and run inference with it."""
    with open(config_path, "r") as handle:
        config = yaml.safe_load(handle)
    run_inference_cfg(config)
def main():
    """CLI entry point: ``--config`` points at a YAML inference config."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True)
    run_inference(cli.parse_args().config)


if __name__ == "__main__":
    main()