# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Modified from https://github.com/facebookresearch/vggt

import logging
import warnings

import torch
import torch.nn.functional as F
from lightglue import ALIKED, SIFT, SuperPoint

from .vggsfm_tracker import TrackerPredictor

# Suppress verbose logging from dependencies
logging.getLogger("dinov2").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", message="xFormers is available")
warnings.filterwarnings("ignore", message="dinov2")

# Constants
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


def build_vggsfm_tracker(model_path=None):
    """
    Build and initialize the VGGSfM tracker.

    Args:
        model_path: Path to the model weights file. If None, weights are
            downloaded from HuggingFace.

    Returns:
        Initialized tracker model in eval mode.
    """
    tracker = TrackerPredictor()

    if model_path is None:
        default_url = (
            "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
        )
        tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
    else:
        tracker.load_state_dict(torch.load(model_path))

    tracker.eval()
    return tracker


def generate_rank_by_dino(
    images,
    query_frame_num,
    image_size=336,
    model_name="dinov2_vitb14_reg",
    device="cuda",
    spatial_similarity=False,
):
    """
    Generate a ranking of frames using DINO ViT features.

    Args:
        images: Tensor of shape (S, 3, H, W) with values in range [0, 1].
        query_frame_num: Number of frames to select.
        image_size: Size to resize images to before processing.
        model_name: Name of the DINO model to use.
        device: Device to run the model on.
        spatial_similarity: Whether to use spatial (patch) token similarity
            instead of CLS token similarity.

    Returns:
        List of frame indices ranked by their representativeness.
    """
    # Resize images to the target size
    images = F.interpolate(
        images, (image_size, image_size), mode="bilinear", align_corners=False
    )

    # Load the DINOv2 model from torch hub
    dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
    dino_v2_model.eval()
    dino_v2_model = dino_v2_model.to(device)

    # Normalize images with ImageNet (ResNet) statistics
    resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
    resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
    images_resnet_norm = (images - resnet_mean) / resnet_std

    with torch.no_grad():
        frame_feat = dino_v2_model(images_resnet_norm, is_training=True)

    # Build the frame-to-frame similarity matrix from either patch tokens
    # or the CLS token
    if spatial_similarity:
        frame_feat = frame_feat["x_norm_patchtokens"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

        # Compute per-patch similarity, then average over patches
        frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
        similarity_matrix = torch.bmm(
            frame_feat_norm, frame_feat_norm.transpose(-1, -2)
        )
        similarity_matrix = similarity_matrix.mean(dim=0)
    else:
        frame_feat = frame_feat["x_norm_clstoken"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
        similarity_matrix = torch.mm(
            frame_feat_norm, frame_feat_norm.transpose(-1, -2)
        )

    distance_matrix = 100 - similarity_matrix.clone()

    # Ignore self-pairing
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)

    # The "most common" frame is the one most similar to all others
    most_common_frame_index = torch.argmax(similarity_sum).item()

    # Conduct farthest point sampling starting from the most common frame
    fps_idx = farthest_point_sampling(
        distance_matrix, query_frame_num, most_common_frame_index
    )

    # Clean up tensors and the model to free memory
    del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
    del dino_v2_model
    torch.cuda.empty_cache()

    return fps_idx
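

# A minimal usage sketch for the ranking step above (the frame tensor and
# its sizes are illustrative assumptions, not values required by this module):
#
#     frames = torch.rand(16, 3, 518, 518, device="cuda")  # hypothetical clip
#     query_idx = generate_rank_by_dino(frames, query_frame_num=4)
#     # query_idx[0] is the frame most similar to all others; the remaining
#     # indices are spread out via farthest point sampling (defined below).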


def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
    """
    Farthest point sampling algorithm to select diverse frames.

    Args:
        distance_matrix: Matrix of pairwise distances between frames.
        num_samples: Number of frames to select.
        most_common_frame_index: Index of the first frame to select.

    Returns:
        List of selected frame indices.
    """
    # clamp returns a new tensor, so the in-place edits below stay local
    distance_matrix = distance_matrix.clamp(min=0)
    N = distance_matrix.size(0)

    # Initialize with the most common frame
    selected_indices = [most_common_frame_index]
    check_distances = distance_matrix[selected_indices]

    while len(selected_indices) < num_samples:
        # Greedily select the point farthest from the most recently
        # selected point
        farthest_point = torch.argmax(check_distances)
        selected_indices.append(farthest_point.item())

        check_distances = distance_matrix[farthest_point]
        # Mark already selected points to avoid selecting them again
        check_distances[selected_indices] = 0

        # Break if all points have been selected
        if len(selected_indices) == N:
            break

    return selected_indices


def calculate_index_mappings(query_index, S, device=None):
    """
    Construct an index order that swaps positions [query_index] and [0], so
    that the content at query_index is placed at position 0.

    Args:
        query_index: Index to swap with 0.
        S: Total number of elements.
        device: Device to place the tensor on.

    Returns:
        Tensor of indices with the swapped order.
    """
    new_order = torch.arange(S)
    new_order[0] = query_index
    new_order[query_index] = 0
    if device is not None:
        new_order = new_order.to(device)
    return new_order


def switch_tensor_order(tensors, order, dim=1):
    """
    Reorder tensors along a specific dimension according to the given order.

    Args:
        tensors: List of tensors to reorder.
        order: Tensor of indices specifying the new order.
        dim: Dimension along which to reorder.

    Returns:
        List of reordered tensors (None entries are passed through).
    """
    return [
        torch.index_select(tensor, dim, order) if tensor is not None else None
        for tensor in tensors
    ]


def initialize_feature_extractors(
    max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"
):
    """
    Initialize reusable feature extractors based on a method string.

    Args:
        max_query_num: Maximum number of keypoints to extract.
        det_thres: Detection threshold for keypoint extraction.
        extractor_method: String specifying which extractors to use
            (e.g., "aliked", "sp+sift", "aliked+sp+sift").
        device: Device to run extraction on.

    Returns:
        Dictionary of initialized extractors keyed by method name.
    """
    extractors = {}
    methods = extractor_method.lower().split("+")

    for method in methods:
        method = method.strip()
        if method == "aliked":
            aliked_extractor = ALIKED(
                max_num_keypoints=max_query_num, detection_threshold=det_thres
            )
            extractors["aliked"] = aliked_extractor.to(device).eval()
        elif method == "sp":
            sp_extractor = SuperPoint(
                max_num_keypoints=max_query_num, detection_threshold=det_thres
            )
            extractors["sp"] = sp_extractor.to(device).eval()
        elif method == "sift":
            sift_extractor = SIFT(max_num_keypoints=max_query_num)
            extractors["sift"] = sift_extractor.to(device).eval()
        else:
            print(f"Warning: Unknown feature extractor '{method}', ignoring.")

    if not extractors:
        print(
            f"Warning: No valid extractors found in '{extractor_method}'. "
            "Using ALIKED by default."
        )
        aliked_extractor = ALIKED(
            max_num_keypoints=max_query_num, detection_threshold=det_thres
        )
        extractors["aliked"] = aliked_extractor.to(device).eval()

    return extractors
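

# Sketch of how the extractor helpers are meant to be combined (the image
# shape is an assumption following extract_keypoints' docstring below):
#
#     extractors = initialize_feature_extractors(
#         max_query_num=2048, extractor_method="aliked+sp", device="cuda"
#     )
#     image = torch.rand(3, 512, 512, device="cuda")  # hypothetical frame
#     points = extract_keypoints(image, extractors)   # (1, N, 2) xy keypoints
#
# Building the extractors once and reusing them across frames avoids paying
# model construction and device-transfer costs per image.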


def extract_keypoints(query_image, extractors, round_keypoints=True):
    """
    Extract keypoints using pre-initialized feature extractors.

    Args:
        query_image: Input image tensor (3xHxW, range [0, 1]).
        extractors: Dictionary of initialized extractors.
        round_keypoints: Whether to round keypoint coordinates to integers.

    Returns:
        Tensor of keypoint coordinates (1xNx2), concatenated over extractors.
    """
    query_points = None

    with torch.no_grad():
        for extractor_name, extractor in extractors.items():
            query_points_data = extractor.extract(query_image, invalid_mask=None)
            extractor_points = query_points_data["keypoints"]

            if round_keypoints:
                extractor_points = extractor_points.round()

            if query_points is not None:
                query_points = torch.cat([query_points, extractor_points], dim=1)
            else:
                query_points = extractor_points

    return query_points


def predict_tracks_in_chunks(
    track_predictor,
    images_feed,
    query_points_list,
    fmaps_feed,
    fine_tracking,
    num_splits=None,
    fine_chunk=40960,
):
    """
    Predict tracks for query points in chunks to avoid memory issues.

    Args:
        track_predictor (object): The track predictor used for predicting tracks.
        images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W)
            representing a batch of images.
        query_points_list (list or tuple): A list/tuple of tensors, each of
            shape (B, Ni, 2), representing chunks of query points. A single
            (B, N, 2) tensor is also accepted for backward compatibility.
        fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
        fine_tracking (bool): Whether to perform fine tracking.
        num_splits (int, optional): Number of chunks to split a single query
            point tensor into. Ignored when query_points_list is already a
            list or tuple; kept for backward compatibility.
        fine_chunk (int): Chunk size used during fine tracking.

    Returns:
        tuple: The concatenated predicted tracks, visibility, and scores.
    """
    # Backward compatibility: if a single tensor is passed instead of a
    # list/tuple of chunks, split it into num_splits chunks along the
    # point dimension.
    if not isinstance(query_points_list, (list, tuple)):
        query_points = query_points_list
        if num_splits is None:
            num_splits = 1
        query_points_list = torch.chunk(query_points, num_splits, dim=1)

    # Ensure query_points_list is a list for iteration
    # (torch.chunk returns a tuple)
    if isinstance(query_points_list, tuple):
        query_points_list = list(query_points_list)

    fine_pred_track_list = []
    pred_vis_list = []
    pred_score_list = []

    for split_points in query_points_list:
        # Feed each chunk of query points into the track predictor
        fine_pred_track, _, pred_vis, pred_score = track_predictor(
            images_feed,
            split_points,
            fmaps=fmaps_feed,
            fine_tracking=fine_tracking,
            fine_chunk=fine_chunk,
        )
        fine_pred_track_list.append(fine_pred_track)
        pred_vis_list.append(pred_vis)
        pred_score_list.append(pred_score)

    # Concatenate the results from all chunks along the point dimension
    fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
    pred_vis = torch.cat(pred_vis_list, dim=2)

    if pred_score is not None:
        pred_score = torch.cat(pred_score_list, dim=2)

    return fine_pred_track, pred_vis, pred_score
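

# A hedged end-to-end sketch of chunked tracking. `tracker`, `images_feed`,
# `fmaps_feed`, and `extractors` are assumed to come from build_vggsfm_tracker
# and the surrounding pipeline (not shown here); shapes follow the docstring
# above:
#
#     query_points = extract_keypoints(images_feed[0, 0], extractors)  # (1, N, 2)
#     chunks = torch.chunk(query_points, 4, dim=1)  # split the N points 4 ways
#     tracks, vis, scores = predict_tracks_in_chunks(
#         tracker, images_feed, chunks, fmaps_feed, fine_tracking=True
#     )
#     # tracks/vis/scores are concatenated back along dim=2 (the point dim).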