# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pycolmap
import cv2 as cv
import numpy as np
from PIL import Image

from ...specs import Prediction
from ...utils.logger import logger
from .glb import _depths_to_world_points_with_colors


def export_to_colmap(
    prediction: Prediction,
    export_dir: str,
    image_paths: list[str],
    conf_thresh_percentile: float = 40.0,
    process_res_method: str = "upper_bound_resize",
) -> None:
    """Write a prediction to ``export_dir`` as a COLMAP reconstruction.

    Every pixel whose confidence reaches the chosen percentile becomes a 3D
    point with a one-element track pointing at the pixel it came from. Each
    input frame gets its own PINHOLE camera, rig, frame and image, with
    intrinsics and 2D observations rescaled from the processed resolution to
    the resolution of the original image file.

    Args:
        prediction: Source of ``depth``, ``intrinsics``, ``extrinsics``
            (world-to-camera), ``processed_images`` and ``conf``. ``conf`` is
            used as a boolean mask over the ``(num_frames, height, width)``
            pixel grid, where height/width come from
            ``processed_images.shape[1:3]``. The prediction is not modified.
        export_dir: Directory handed to ``Reconstruction.write``. It is
            created if it does not exist yet.
        image_paths: One path per frame, in frame order. Each file is opened
            only to read its original size; its basename becomes the COLMAP
            image name.
        conf_thresh_percentile: Percentile of ``prediction.conf`` used as the
            confidence threshold; pixels below the threshold are dropped.
        process_res_method: How the images were resized for inference. Any
            value ending in ``"resize"`` is supported.

    Raises:
        NotImplementedError: If ``process_res_method`` is ``"crop"``.
        ValueError: If ``process_res_method`` is anything else that does not
            end in ``"resize"``.
    """
    # Reject unsupported modes before doing any of the expensive work below.
    if process_res_method == "crop":
        raise NotImplementedError("COLMAP export for crop method is not implemented")
    if not process_res_method.endswith("resize"):
        raise ValueError(f"Unknown process_res_method: {process_res_method}")

    # 1. Data preparation
    conf_thresh = np.percentile(prediction.conf, conf_thresh_percentile)
    points, colors = _depths_to_world_points_with_colors(
        prediction.depth,
        prediction.intrinsics,
        prediction.extrinsics,  # w2c
        prediction.processed_images,
        prediction.conf,
        conf_thresh,
    )
    num_points = len(points)
    logger.info(f"Exporting to COLMAP with {num_points} points")

    num_frames = len(prediction.processed_images)
    h, w = prediction.processed_images.shape[1:3]
    points_xyf = _create_xyf(num_frames, h, w)
    # NOTE(review): row i of `points_xyf` is paired with `points[i]` below.
    # This assumes the helper keeps exactly the pixels with
    # `conf >= conf_thresh`, in the same frame-major / row-major order --
    # TODO confirm against the helper in glb.py.
    points_xyf = points_xyf[prediction.conf >= conf_thresh]

    # Pixel coordinates are promoted to float so the rescaling below keeps
    # its sub-pixel part instead of being truncated back to int32.
    pixel_xy = points_xyf[:, :2].astype(np.float64)
    frame_of_point = points_xyf[:, 2]

    # 2. Set Reconstruction
    reconstruction = pycolmap.Reconstruction()

    point3d_ids = []
    for vidx in range(num_points):
        point3d_id = reconstruction.add_point3D(points[vidx], pycolmap.Track(), colors[vidx])
        point3d_ids.append(point3d_id)

    for fidx in range(num_frames):
        # Only the size is needed; close the file handle right away.
        with Image.open(image_paths[fidx]) as source_image:
            orig_w, orig_h = source_image.size
        scale_x = orig_w / w
        scale_y = orig_h / h

        # Work on a copy: scaling a view would rewrite `prediction.intrinsics`
        # and compound on every subsequent export of the same prediction.
        intrinsic = np.array(prediction.intrinsics[fidx], dtype=np.float64)
        intrinsic[:1] *= scale_x  # first row: fx, skew, cx
        intrinsic[1:2] *= scale_y  # second row: fy, cy
        pycolmap_intri = np.array(
            [intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]]
        )

        extrinsic = prediction.extrinsics[fidx]
        cam_from_world = pycolmap.Rigid3d(
            pycolmap.Rotation3d(extrinsic[:3, :3]), extrinsic[:3, 3]
        )

        # set and add camera
        camera = pycolmap.Camera()
        camera.camera_id = fidx + 1
        camera.model = pycolmap.CameraModelId.PINHOLE
        camera.width = orig_w
        camera.height = orig_h
        camera.params = pycolmap_intri
        reconstruction.add_camera(camera)

        # set and add rig (from camera)
        rig = pycolmap.Rig()
        rig.rig_id = camera.camera_id
        rig.add_ref_sensor(camera.sensor_id)
        reconstruction.add_rig(rig)

        # set image
        image = pycolmap.Image()
        image.image_id = fidx + 1
        image.camera_id = camera.camera_id

        # set and add frame (from image)
        frame = pycolmap.Frame()
        frame.frame_id = image.image_id
        frame.rig_id = camera.camera_id
        frame.add_data_id(image.data_id)
        frame.rig_from_world = cam_from_world
        reconstruction.add_frame(frame)

        # set point2d and update track
        frame_point_indices = np.flatnonzero(frame_of_point == fidx)
        frame_xy = pixel_xy[frame_point_indices] * np.array([scale_x, scale_y])
        point2d_list = []
        for point2d_idx, vidx in enumerate(frame_point_indices):
            point3d_id = point3d_ids[vidx]
            point2d_list.append(pycolmap.Point2D(frame_xy[point2d_idx], point3d_id))
            # The track element refers to the observation by its position in
            # this image's 2D point list.
            reconstruction.point3D(point3d_id).track.add_element(
                image.image_id, point2d_idx
            )

        # set and add image
        image.frame_id = image.image_id
        image.name = os.path.basename(image_paths[fidx])
        image.points2D = pycolmap.Point2DList(point2d_list)
        reconstruction.add_image(image)

    # 3. Export
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    reconstruction.write(export_dir)


def _create_xyf(num_frames: int, height: int, width: int) -> np.ndarray:
    """Build a per-pixel grid of ``(x, y, frame_index)`` triples.

    Args:
        num_frames: Number of frames in the sequence.
        height: Frame height in pixels.
        width: Frame width in pixels.

    Returns:
        An int32 array of shape ``(num_frames, height, width, 3)`` where the
        last axis holds the column index, the row index and the frame index
        of each pixel.
    """
    # Create coordinate grids for a single frame
    y_grid, x_grid = np.indices((height, width), dtype=np.int32)
    x_grid = x_grid[np.newaxis, :, :]
    y_grid = y_grid[np.newaxis, :, :]

    # Broadcast to all frames (views only; np.stack below materialises them)
    x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
    y_coords = np.broadcast_to(y_grid, (num_frames, height, width))

    # Create frame indices and broadcast
    f_idx = np.arange(num_frames, dtype=np.int32)[:, np.newaxis, np.newaxis]
    f_coords = np.broadcast_to(f_idx, (num_frames, height, width))

    # Stack coordinates and frame indices
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
    return points_xyf