import numpy as np
import pycolmap
import torch
import torch.nn.functional as F

from lightglue import ALIKED, SuperPoint, SIFT

from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from src.utils.tensor_to_pycolmap import batch_matrix_to_pycolmap, pycolmap_to_batch_matrix

# ImageNet normalization statistics expected by DINOv2
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


def generate_rank_by_dino(
    images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=True
):
    """
    Generate a ranking of frames using DINO ViT features.

    Args:
        images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
        query_frame_num: Number of frames to select
        image_size: Side length to which images are resized before processing
            (must be divisible by the DINOv2 patch size of 14)
        model_name: Name of the DINO model to use
        device: Device to run the model on
        spatial_similarity: Whether to use spatial token similarity or CLS token similarity

    Returns:
        List of frame indices ranked by their representativeness
    """
    dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
    dino_v2_model.eval()
    dino_v2_model = dino_v2_model.to(device)

    # Resize to the DINO input resolution and normalize with ImageNet statistics
    images = F.interpolate(images, size=(image_size, image_size), mode="bilinear", align_corners=False)
    resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
    resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
    images_resnet_norm = (images - resnet_mean) / resnet_std

    with torch.no_grad():
        frame_feat = dino_v2_model(images_resnet_norm, is_training=True)

    if spatial_similarity:
        frame_feat = frame_feat["x_norm_patchtokens"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

        # Per-patch similarity, averaged over all spatial tokens
        frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
        similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
        similarity_matrix = similarity_matrix.mean(dim=0)
    else:
        frame_feat = frame_feat["x_norm_clstoken"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
        similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))

    distance_matrix = 100 - similarity_matrix.clone()

    # Mask out self-similarity, then find the frame most similar to all others
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)
    most_common_frame_index = torch.argmax(similarity_sum).item()

    # Select diverse frames by farthest point sampling, seeded with the most common frame
    fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)

    return fps_idx
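
# A minimal usage sketch (torch.rand stands in for real frames, purely for
# illustration; DINOv2 weights are fetched through torch.hub on first use):
#
#   images = torch.rand(8, 3, 518, 518, device="cuda")  # 8 synthetic frames
#   ranked = generate_rank_by_dino(images, query_frame_num=3, spatial_similarity=False)
#   # ranked: list of 3 frame indices, the most representative frame first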


def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
    """
    Farthest point sampling algorithm to select diverse frames.

    Args:
        distance_matrix: (N, N) matrix of pairwise distances between frames
        num_samples: Number of frames to select
        most_common_frame_index: Index of the first frame to select

    Returns:
        List of selected frame indices
    """
    distance_matrix = distance_matrix.clamp(min=0)
    N = distance_matrix.size(0)

    # Seed the sample set with the most representative frame
    selected_indices = [most_common_frame_index]
    check_distances = distance_matrix[selected_indices]

    while len(selected_indices) < num_samples:
        # Pick the frame farthest from the most recently selected one
        farthest_point = torch.argmax(check_distances)
        selected_indices.append(farthest_point.item())

        check_distances = distance_matrix[farthest_point]
        # Zero out already-selected frames so they cannot be picked again
        check_distances[selected_indices] = 0

        # Stop once every frame has been selected
        if len(selected_indices) == N:
            break

    return selected_indices
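
# A minimal sketch of the sampling behaviour on a toy 4-frame distance matrix
# (hypothetical values, purely for illustration):
#
#   dist = torch.tensor([
#       [0.0, 1.0, 2.0, 3.0],
#       [1.0, 0.0, 1.0, 2.0],
#       [2.0, 1.0, 0.0, 1.0],
#       [3.0, 2.0, 1.0, 0.0],
#   ])
#   farthest_point_sampling(dist, num_samples=3, most_common_frame_index=0)
#   # -> [0, 3, 1]: frame 3 is farthest from frame 0, then frame 1 is the
#   #    farthest not-yet-selected frame from frame 3.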


def calculate_index_mappings(query_index, S, device=None):
    """
    Construct an index order that swaps positions [query_index] and [0],
    so that the content at query_index is placed at position [0].

    Args:
        query_index: Index to swap with 0
        S: Total number of elements
        device: Device to place the tensor on

    Returns:
        Tensor of indices with the swapped order
    """
    new_order = torch.arange(S)
    new_order[0] = query_index
    new_order[query_index] = 0
    if device is not None:
        new_order = new_order.to(device)
    return new_order


def switch_tensor_order(tensors, order, dim=1):
    """
    Reorder tensors along a specific dimension according to the given order.

    Args:
        tensors: List of tensors to reorder (entries may be None)
        order: Tensor of indices specifying the new order
        dim: Dimension along which to reorder

    Returns:
        List of reordered tensors
    """
    return [
        torch.index_select(tensor, dim, order) if tensor is not None else None
        for tensor in tensors
    ]
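
# Because the mapping is a single swap, applying it twice restores the original
# order, so the same index tensor both moves the query frame to slot 0 and maps
# the predictions back afterwards. A small sketch (shapes chosen arbitrarily):
#
#   order = calculate_index_mappings(query_index=2, S=4)         # tensor([2, 1, 0, 3])
#   frames = torch.arange(4).view(4, 1)                          # fake (S, ...) tensor
#   reordered = switch_tensor_order([frames], order, dim=0)[0]   # frame 2 now first
#   restored = switch_tensor_order([reordered], order, dim=0)[0] # equals `frames`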


def predict_track(model, images, query_points, dtype=torch.bfloat16, use_tf32_for_track=True, iters=4):
    """
    Predict tracks for query points across frames.

    Args:
        model: VGGT model
        images: Tensor of images of shape (S, 3, H, W)
        query_points: Query points to track
        dtype: Data type for computation
        use_tf32_for_track: Whether to use TF32 precision for tracking
        iters: Number of iterations for tracking

    Returns:
        Predicted tracks, visibility scores, and confidence scores
    """
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add a batch dimension: (1, S, 3, H, W)
            aggregated_tokens_list, ps_idx = model.aggregator(images)

            if not use_tf32_for_track:
                track_list, vis_score, conf_score = model.track_head(
                    aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
                )

        if use_tf32_for_track:
            # Run the track head outside the low-precision autocast region
            with torch.cuda.amp.autocast(enabled=False):
                track_list, vis_score, conf_score = model.track_head(
                    aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters
                )

        pred_track = track_list[-1]  # take the final refinement iteration

    return pred_track.squeeze(0), vis_score.squeeze(0), conf_score.squeeze(0)
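
# A minimal usage sketch (assumes a loaded VGGT `model`, an (S, 3, H, W) image
# tensor on the GPU, and `query_points` as a (1, N, 2) tensor of (x, y) pixel
# coordinates defined in frame 0; output shapes are assumptions):
#
#   tracks, vis, conf = predict_track(model, images, query_points)
#   # tracks: (S, N, 2) point locations per frame; vis/conf: (S, N) scores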


def initialize_feature_extractors(max_query_num, det_thres, extractor_method="aliked", device="cuda"):
    """
    Initialize feature extractors that can be reused based on a method string.

    Args:
        max_query_num: Maximum number of keypoints to extract
        det_thres: Detection threshold for keypoint extraction
        extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
        device: Device to run extraction on

    Returns:
        Dictionary of initialized extractors
    """
    extractors = {}
    methods = extractor_method.lower().split("+")
    # Split the keypoint budget evenly across the requested extractors
    # (note: unrecognized method names still count toward the split)
    active_extractors = len(methods)

    for method in methods:
        method = method.strip()
        if method == "aliked":
            aliked_max_points = max_query_num // active_extractors
            aliked_extractor = ALIKED(max_num_keypoints=aliked_max_points, detection_threshold=det_thres)
            extractors["aliked"] = aliked_extractor.to(device).eval()
        elif method == "sp":
            sp_max_points = max_query_num // active_extractors
            sp_extractor = SuperPoint(max_num_keypoints=sp_max_points, detection_threshold=det_thres)
            extractors["sp"] = sp_extractor.to(device).eval()
        elif method == "sift":
            sift_max_points = max_query_num // active_extractors
            sift_extractor = SIFT(max_num_keypoints=sift_max_points)
            extractors["sift"] = sift_extractor.to(device).eval()
        else:
            print(f"Warning: Unknown feature extractor '{method}', ignoring.")

    if not extractors:
        print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.")
        aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        extractors["aliked"] = aliked_extractor.to(device).eval()

    return extractors


def extract_keypoints(query_image, extractors):
    """
    Extract keypoints using pre-initialized feature extractors.

    Args:
        query_image: Input image tensor (3xHxW, range [0, 1])
        extractors: Dictionary of initialized extractors

    Returns:
        Tensor of keypoint coordinates (1xNx2)
    """
    query_points_round = None

    with torch.no_grad():
        for extractor in extractors.values():
            query_points_data = extractor.extract(query_image)
            extractor_points = query_points_data["keypoints"].round()

            # Concatenate keypoints from all extractors along the point dimension
            if query_points_round is not None:
                query_points_round = torch.cat([query_points_round, extractor_points], dim=1)
            else:
                query_points_round = extractor_points

    return query_points_round
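
# A minimal usage sketch combining the two helpers above (a hypothetical budget
# of 2048 keypoints, split evenly across ALIKED and SuperPoint):
#
#   extractors = initialize_feature_extractors(
#       max_query_num=2048, det_thres=0.005, extractor_method="aliked+sp", device="cuda"
#   )
#   query_points = extract_keypoints(images[0], extractors)  # (1, N, 2), N <= 2048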


def run_vggt_with_ba(model, images, image_names=None, dtype=torch.bfloat16,
                     max_query_num=2048, det_thres=0.005, query_frame_num=3,
                     extractor_method="aliked+sp+sift",
                     max_reproj_error=4,
                     shared_camera=True,
                     ):
    """
    Run VGGT with bundle adjustment for pose estimation.

    Args:
        model: VGGT model
        images: Tensor of images of shape (S, 3, H, W)
        image_names: Optional list of image names
        dtype: Data type for computation
        max_query_num: Maximum number of query keypoints per query frame
        det_thres: Detection threshold for keypoint extraction
        query_frame_num: Number of query frames to select for tracking
        extractor_method: Feature extractor combination (e.g., "aliked+sp+sift")
        max_reproj_error: Maximum reprojection error when filtering tracks for BA
        shared_camera: Whether all frames share a single camera model

    Returns:
        Bundle-adjusted extrinsic camera parameters

    TODO:
        - [ ] Use VGGT's ViT instead of DINOv2 for rank generation
    """
    device = images.device
    frame_num = images.shape[0]

    # Select representative query frames using DINOv2 CLS-token similarity
    query_frame_indexes = generate_rank_by_dino(
        images, query_frame_num, image_size=518,
        model_name="dinov2_vitb14_reg", device=device,
        spatial_similarity=False
    )

    # Ensure frame 0 is always the first query frame
    if 0 in query_frame_indexes:
        query_frame_indexes.remove(0)
    query_frame_indexes = [0, *query_frame_indexes]

    # VGGT forward pass: camera poses and a dense depth map for every frame.
    # The aggregator runs in bfloat16 autocast; the camera head runs in fp32.
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.aggregator(images, intermediate_layer_idx=model.cfg.intermediate_layer_idx)
        with torch.cuda.amp.autocast(enabled=False):
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
        pred_extrinsic = pred_all_extrinsic[0]
        pred_intrinsic = pred_all_intrinsic[0]
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx)

    # Unproject the per-frame depth maps into a world-space point map
    world_points = unproject_depth_map_to_point_map(depth_map, pred_extrinsic, pred_intrinsic)
    world_points = torch.from_numpy(world_points).to(device)
    world_points_conf = depth_conf.to(device)

    torch.cuda.empty_cache()

    # Accumulators for tracking results from all query frames
    pred_tracks = []
    pred_vis_scores = []
    pred_conf_scores = []
    pred_world_points = []
    pred_world_points_conf = []

    extractors = initialize_feature_extractors(max_query_num, det_thres, extractor_method, device)

    for query_index in query_frame_indexes:
        query_image = images[query_index]
        query_points_round = extract_keypoints(query_image, extractors)

        # Reorder frames so the query frame sits at index 0, where the track head expects it
        reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
        reorder_images = switch_tensor_order([images], reorder_index, dim=0)[0]

        reorder_tracks, reorder_vis_score, reorder_conf_score = predict_track(
            model, reorder_images, query_points_round, dtype=dtype, use_tf32_for_track=True, iters=4
        )

        # Swap the predictions back to the original frame order (the mapping is its own inverse)
        pred_track, pred_vis, pred_score = switch_tensor_order(
            [reorder_tracks, reorder_vis_score, reorder_conf_score], reorder_index, dim=0
        )

        pred_tracks.append(pred_track)
        pred_vis_scores.append(pred_vis)
        pred_conf_scores.append(pred_score)

        # Look up the unprojected 3D point and its confidence at each keypoint location
        query_points_round_long = query_points_round.squeeze(0).long()
        query_world_points = world_points[query_index][
            query_points_round_long[:, 1], query_points_round_long[:, 0]
        ]
        query_world_points_conf = world_points_conf[query_index][
            query_points_round_long[:, 1], query_points_round_long[:, 0]
        ]

        pred_world_points.append(query_world_points)
        pred_world_points_conf.append(query_world_points_conf)

    # Concatenate results from all query frames along the point dimension
    pred_tracks = torch.cat(pred_tracks, dim=1)
    pred_vis_scores = torch.cat(pred_vis_scores, dim=1)
    pred_conf_scores = torch.cat(pred_conf_scores, dim=1)
    pred_world_points = torch.cat(pred_world_points, dim=0)
    pred_world_points_conf = torch.cat(pred_world_points_conf, dim=0)

    # Keep only confident 3D points, but apply the filter only when
    # enough confident points remain for bundle adjustment
    filtered_flag = pred_world_points_conf > 1.5

    if filtered_flag.sum() > 1024:
        pred_world_points = pred_world_points[filtered_flag]
        pred_world_points_conf = pred_world_points_conf[filtered_flag]

        pred_tracks = pred_tracks[:, filtered_flag]
        pred_vis_scores = pred_vis_scores[:, filtered_flag]
        pred_conf_scores = pred_conf_scores[:, filtered_flag]

    torch.cuda.empty_cache()

    S, _, H, W = images.shape
    image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device)

    # Build a pycolmap reconstruction from the predicted geometry and run bundle adjustment
    reconstruction = batch_matrix_to_pycolmap(
        pred_world_points,
        pred_extrinsic,
        pred_intrinsic,
        pred_tracks,
        image_size,
        max_reproj_error=max_reproj_error,
        shared_camera=shared_camera,
    )

    ba_options = pycolmap.BundleAdjustmentOptions()
    pycolmap.bundle_adjustment(reconstruction, ba_options)

    _, updated_extrinsic, _, _ = pycolmap_to_batch_matrix(
        reconstruction, device=device, camera_type="SIMPLE_PINHOLE"
    )

    return updated_extrinsic
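
# A minimal end-to-end sketch (assumes a loaded VGGT `model` whose config
# exposes `intermediate_layer_idx`, with frames already preprocessed to the
# resolution the model expects; `load_images` is a hypothetical helper):
#
#   images = load_images(image_paths).to("cuda")  # (S, 3, H, W), values in [0, 1]
#   extrinsics = run_vggt_with_ba(model, images)
#   # extrinsics: per-frame camera extrinsics refined by bundle adjustment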