Spaces:
Running on Zero
Running on Zero
| import os | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| from .common import ( | |
| branch_key, | |
| c2w_in_view_space, | |
| load_metadata, | |
| selected_frame_indices, | |
| session_file, | |
| world_to_view, | |
| ) | |
def _origin_shift(w2c_all) -> np.ndarray:
    """Return the translation column of the first pose converted to view space.

    ``w2c_all[0]`` is passed through ``c2w_in_view_space`` and the ``[:3, 3]``
    column of the resulting matrix is returned as an independent copy, so the
    caller can keep and subtract it without aliasing the converted matrix.
    Presumably this is the first camera's center in view space — confirm
    against ``c2w_in_view_space``.
    """
    first_pose = c2w_in_view_space(w2c_all[0])
    position = first_pose[:3, 3]
    return position.copy()
def _sample_flat_indices(
    valid_indices: np.ndarray, budget: Optional[int], rng: np.random.Generator
) -> np.ndarray:
    """Randomly keep at most ``budget`` entries of ``valid_indices``.

    When ``budget`` is ``None``, non-positive, or not smaller than the number
    of candidates, the input array itself is returned untouched. Otherwise
    ``budget`` distinct positions are drawn with ``rng.choice`` (without
    replacement) and the corresponding entries are returned in draw order.
    """
    total = valid_indices.size
    must_subsample = budget is not None and budget > 0 and total > budget
    if not must_subsample:
        return valid_indices
    picked_positions = rng.choice(total, size=int(budget), replace=False)
    return valid_indices[picked_positions]
def _depth_points_from_flat(depth, intri, w2c, flat_indices):
    """Back-project selected pixels of a depth map into world coordinates.

    Each flat index is decoded row-major into an integer pixel (column ``u``,
    row ``v``) of ``depth``, with no half-pixel offset. It is lifted to camera
    space with the pinhole model ``x = (u - cx) * z / fx``,
    ``y = (v - cy) * z / fy`` and mapped to world space via
    ``X_world = R^T (X_cam - t)``, where ``R = w2c[:3, :3]`` and
    ``t = w2c[:3, 3]``. That formula is the inverse transform only when ``R``
    is orthonormal.

    Args:
        depth: Depth map of shape (H, W). An (H, W, 1) map with a trailing
            singleton channel is also accepted.
        intri: Intrinsics matrix; reads fx=[0, 0], fy=[1, 1], cx=[0, 2],
            cy=[1, 2].
        w2c: World-to-camera matrix; only its top 3x4 part is read.
        flat_indices: Integer indices into the row-major flattened (H, W) grid.

    Returns:
        float32 array of shape (len(flat_indices), 3) with world-space points.

    Raises:
        ValueError: if ``depth`` is not (H, W) or (H, W, 1).
    """
    depth = np.asarray(depth)
    if depth.ndim == 3 and depth.shape[-1] == 1:
        # Tolerate a trailing singleton channel; the flat pixel order is the same.
        depth = depth[..., 0]
    if depth.ndim != 2:
        raise ValueError(f"expected a (H, W) depth map, got shape {depth.shape}")
    width = depth.shape[1]
    flat_indices = np.asarray(flat_indices)
    ys = flat_indices // width
    xs = flat_indices % width
    z = depth.reshape(-1)[flat_indices].astype(np.float64)
    fx = float(intri[0, 0])
    fy = float(intri[1, 1])
    cx = float(intri[0, 2])
    cy = float(intri[1, 2])
    # Clamp the focal lengths so a zero entry cannot divide by zero.
    x = (xs.astype(np.float64) - cx) * z / max(fx, 1e-12)
    y = (ys.astype(np.float64) - cy) * z / max(fy, 1e-12)
    pts_cam = np.stack([x, y, z], axis=1)
    R = w2c[:3, :3].astype(np.float64)
    t = w2c[:3, 3].astype(np.float64)
    return (R.T @ (pts_cam.T - t[:, None])).T.astype(np.float32, copy=False)
def _camera_points_to_world(points, w2c):
    """Map camera-frame points into world coordinates.

    ``points`` may be any array-like whose size is a multiple of 3. It is
    flattened to (x, y, z) rows, and each row is mapped with
    ``X_world = R^T (X_cam - t)``, where ``R = w2c[:3, :3]`` and
    ``t = w2c[:3, 3]``. This is the inverse transform only when ``R`` is
    orthonormal. The math runs in float64 and the result is returned as a
    float32 array of shape (N, 3).
    """
    cam_rows = np.asarray(points, dtype=np.float64).reshape(-1, 3)
    rotation = w2c[:3, :3].astype(np.float64)
    translation = w2c[:3, 3].astype(np.float64)
    centered_cols = cam_rows.T - translation[:, None]
    world_cols = rotation.T @ centered_cols
    return world_cols.T.astype(np.float32, copy=False)
def _masked_frame_sample(valid, sky_frame, budget, rng):
    """Apply an optional sky mask to a flat validity mask, then subsample.

    Args:
        valid: Flat boolean array with one entry per pixel; it is updated in
            place when ``sky_frame`` is given.
        sky_frame: Per-frame sky mask, or ``None`` to skip sky filtering.
            Pixels are kept where the mask is ``> 0``. This presumably marks
            non-sky pixels — confirm against the writer of ``sky_masks.npy``.
        budget: Maximum number of indices to keep (``None`` or ``<= 0`` keeps all).
        rng: Generator passed to ``_sample_flat_indices``.

    Returns:
        Flat pixel indices surviving masking and sampling (possibly empty).
    """
    if sky_frame is not None:
        valid &= sky_frame.reshape(-1) > 0
    flat = np.flatnonzero(valid)
    if flat.size == 0:
        # Nothing survived: skip sampling so the RNG is not consumed.
        return flat
    return _sample_flat_indices(flat, budget, rng)


def collect_points(
    session_dir: str,
    branch: str,
    display_mode: str,
    frame_index: int,
    mask_sky: bool,
    max_points: Optional[int],
    seed: int = 0,
):
    """Gather a colored point cloud in view space for the selected frames.

    Points come from one of two sources:

    * ``point_head.npy``, when ``branch`` normalizes to ``"point_head"``. Its
      maps are passed through ``_camera_points_to_world``, so they are treated
      as camera-frame points. NOTE(review): confirm they are not already stored
      in world coordinates.
    * Otherwise, ``depth.npy`` back-projected with ``intri.npy``.

    Non-finite points (and, for depth, non-positive depths) are dropped. When
    ``mask_sky`` is set and ``sky_masks.npy`` exists, pixels whose mask is not
    ``> 0`` are dropped too. Each kept point is converted with ``world_to_view``
    and shifted by ``-origin_shift``.

    Args:
        session_dir: Session directory holding the ``.npy`` arrays.
        branch: Source selector, normalized with ``branch_key``.
        display_mode: Forwarded to ``selected_frame_indices``.
        frame_index: Forwarded to ``selected_frame_indices``.
        mask_sky: Whether to apply ``sky_masks.npy`` when it exists.
        max_points: Approximate total point budget. It is split evenly across
            frames with at least one point per frame, so the total can exceed
            ``max_points`` when there are more frames than points.
            ``None`` or ``<= 0`` disables sampling.
        seed: Seed for the subsampling RNG (deterministic output per seed).

    Returns:
        ``(points, colors, origin_shift)``.

        * ``points``: float32 positions, concatenated across frames; ``(0, 3)``
          when empty.
        * ``colors``: uint8 RGB rows from ``images.npy`` at the same pixels.
          NOTE(review): this is a plain ``astype(np.uint8)``; if images are
          stored as floats in [0, 1] the colors collapse to 0/1 — confirm the
          stored dtype.
        * ``origin_shift``: the value of ``_origin_shift(w2c)``, or
          ``np.zeros(3)`` when no frames are selected.
    """
    branch = branch_key(branch)
    meta = load_metadata(session_dir)
    frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
    if not frame_ids:
        return (
            np.empty((0, 3), dtype=np.float32),
            np.empty((0, 3), dtype=np.uint8),
            np.zeros(3, dtype=np.float64),
        )

    images = np.load(session_file(session_dir, "images.npy"), mmap_mode="r")
    w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
    origin_shift = _origin_shift(w2c)

    # Sky masking is best-effort: a missing mask file silently disables it.
    sky = None
    if mask_sky:
        sky_path = session_file(session_dir, "sky_masks.npy")
        if os.path.exists(sky_path):
            sky = np.load(sky_path, mmap_mode="r")

    use_point_head = branch == "point_head"
    point_head = depth = intri = None
    if use_point_head:
        point_head = np.load(session_file(session_dir, "point_head.npy"), mmap_mode="r")
    else:
        depth = np.load(session_file(session_dir, "depth.npy"), mmap_mode="r")
        intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")

    per_frame_budget = None
    if max_points is not None and max_points > 0:
        # frame_ids is non-empty here, so the division is safe.
        per_frame_budget = max(int(max_points) // len(frame_ids), 1)

    rng = np.random.default_rng(seed)
    points = []
    colors = []
    for idx in frame_ids:
        rgb_flat = images[idx].reshape(-1, 3)
        if use_point_head:
            frame_points = point_head[idx]
            valid = np.isfinite(frame_points).all(axis=-1).reshape(-1)
        else:
            depth_i = depth[idx]
            valid = (np.isfinite(depth_i) & (depth_i > 0)).reshape(-1)
        flat = _masked_frame_sample(
            valid, None if sky is None else sky[idx], per_frame_budget, rng
        )
        if flat.size == 0:
            continue
        if use_point_head:
            pts_world = _camera_points_to_world(
                frame_points.reshape(-1, 3)[flat], w2c[idx]
            )
        else:
            pts_world = _depth_points_from_flat(depth_i, intri[idx], w2c[idx], flat)
        pts_view = world_to_view(pts_world) - origin_shift[None]
        points.append(pts_view.astype(np.float32, copy=False))
        colors.append(rgb_flat[flat].astype(np.uint8, copy=False))

    if not points:
        return (
            np.empty((0, 3), dtype=np.float32),
            np.empty((0, 3), dtype=np.uint8),
            origin_shift,
        )
    return np.concatenate(points, axis=0), np.concatenate(colors, axis=0), origin_shift
def _frustum_corners_camera(intri, image_hw, depth_scale):
    """Return the four image-corner rays of a pinhole camera at a fixed depth.

    The corner pixels are taken in the order top-left, top-right,
    bottom-right, bottom-left, using integer pixel coordinates
    ``(0, 0)`` .. ``(W - 1, H - 1)``. Each is unprojected to camera space at
    ``z = depth_scale``. Focal lengths are clamped to at least ``1e-12`` to
    avoid division by zero.

    Args:
        intri: Intrinsics matrix; reads fx=[0, 0], fy=[1, 1], cx=[0, 2],
            cy=[1, 2].
        image_hw: ``(height, width)`` of the image in pixels.
        depth_scale: Camera-space z of every returned corner.

    Returns:
        float64 array of shape (4, 3).
    """
    height, width = image_hw
    fx_safe = max(float(intri[0, 0]), 1e-12)
    fy_safe = max(float(intri[1, 1]), 1e-12)
    cx = float(intri[0, 2])
    cy = float(intri[1, 2])
    corner_pixels = (
        (0.0, 0.0),
        (width - 1.0, 0.0),
        (width - 1.0, height - 1.0),
        (0.0, height - 1.0),
    )
    rows = [
        [
            (u - cx) * depth_scale / fx_safe,
            (v - cy) * depth_scale / fy_safe,
            depth_scale,
        ]
        for u, v in corner_pixels
    ]
    return np.array(rows, dtype=np.float64)
def camera_geometry(
    session_dir: str,
    display_mode: str,
    frame_index: int,
    camera_scale_ratio: float,
    points_hint=None,
):
    """Build camera centers and frustum outlines for the selected frames.

    Each selected ``w2c.npy`` pose is converted once with
    ``c2w_in_view_space(w2c[idx], origin_shift)``. The result is used both for
    the scene extent and for the frustum. The frustum depth is
    ``extent * camera_scale_ratio``, where ``extent`` is the largest of:

    * the bounding-box diagonal of the camera centers (when more than one
      frame is selected),
    * the norm of the 5th-to-95th-percentile span of ``points_hint``,
    * ``1.0``.

    Args:
        session_dir: Session directory holding ``w2c.npy`` / ``intri.npy``.
        display_mode: Forwarded to ``selected_frame_indices``.
        frame_index: Forwarded to ``selected_frame_indices``.
        camera_scale_ratio: Multiplier applied to the scene extent.
        points_hint: Optional (N, 3) array-like of points. Presumably these are
            in the same view space as the cameras (e.g. from
            ``collect_points``) — confirm.

    Returns:
        ``(centers, frustums, origin_shift)``.

        * ``centers``: float64 array with one center per selected frame
          (shape ``(0,)`` when none are selected).
        * ``frustums``: list of ``(center, corners)`` pairs, where ``corners``
          is the (4, 3) array of frustum corners in the same space.
        * ``origin_shift``: the value of ``_origin_shift(w2c)``.
    """
    meta = load_metadata(session_dir)
    frame_ids = selected_frame_indices(meta["num_frames"], frame_index, display_mode)
    w2c = np.load(session_file(session_dir, "w2c.npy"), mmap_mode="r")
    intri = np.load(session_file(session_dir, "intri.npy"), mmap_mode="r")
    origin_shift = _origin_shift(w2c)

    # Convert every selected pose once; reused for both the extent and the frustums.
    c2w_views = [c2w_in_view_space(w2c[idx], origin_shift) for idx in frame_ids]
    center_points = np.array([c2w[:3, 3] for c2w in c2w_views], dtype=np.float64)

    center_extent = 1.0
    if len(center_points) > 1:
        center_extent = float(
            np.linalg.norm(center_points.max(axis=0) - center_points.min(axis=0))
        )
    point_extent = 0.0
    if points_hint is not None and len(points_hint) > 0:
        # Percentiles keep a few outlier points from inflating the extent.
        lo = np.percentile(points_hint, 5, axis=0)
        hi = np.percentile(points_hint, 95, axis=0)
        point_extent = float(np.linalg.norm(hi - lo))
    extent = max(center_extent, point_extent, 1.0)
    depth_scale = extent * float(camera_scale_ratio)

    image_hw = (meta["height"], meta["width"])
    frustums = []
    for idx, c2w_view in zip(frame_ids, c2w_views):
        center = c2w_view[:3, 3]
        corners_cam = _frustum_corners_camera(intri[idx], image_hw, depth_scale)
        corners_world = (c2w_view[:3, :3] @ corners_cam.T).T + center[None]
        frustums.append((center, corners_world))
    return center_points, frustums, origin_shift