Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| # Modified from https://github.com/facebookresearch/vggt | |
| import logging | |
| import warnings | |
| import torch | |
| import torch.nn.functional as F | |
| from lightglue import ALIKED, SIFT, SuperPoint | |
| from .vggsfm_tracker import TrackerPredictor | |
| # Suppress verbose logging from dependencies | |
| logging.getLogger("dinov2").setLevel(logging.WARNING) | |
| warnings.filterwarnings("ignore", message="xFormers is available") | |
| warnings.filterwarnings("ignore", message="dinov2") | |
| # Constants | |
| _RESNET_MEAN = [0.485, 0.456, 0.406] | |
| _RESNET_STD = [0.229, 0.224, 0.225] | |
def build_vggsfm_tracker(model_path=None):
    """
    Build and initialize the VGGSfM tracker.

    Args:
        model_path: Path to the model weights file. If None, weights are
            downloaded from HuggingFace.

    Returns:
        Initialized tracker model in eval mode (on CPU; caller moves it to
        the target device).
    """
    tracker = TrackerPredictor()
    if model_path is None:
        default_url = (
            "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
        )
        # map_location="cpu" so loading works on CPU-only machines even when
        # the checkpoint was serialized from CUDA tensors.
        state_dict = torch.hub.load_state_dict_from_url(default_url, map_location="cpu")
    else:
        state_dict = torch.load(model_path, map_location="cpu")
    tracker.load_state_dict(state_dict)
    tracker.eval()
    return tracker
def generate_rank_by_dino(
    images,
    query_frame_num,
    image_size=336,
    model_name="dinov2_vitb14_reg",
    device="cuda",
    spatial_similarity=False,
):
    """
    Generate a ranking of frames using DINO ViT features.

    Args:
        images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
        query_frame_num: Number of frames to select
        image_size: Size to resize images to before processing
        model_name: Name of the DINO model to use (fetched via torch.hub)
        device: Device to run the model on
        spatial_similarity: Whether to use spatial token similarity or CLS token similarity

    Returns:
        List of frame indices ranked by their representativeness
    """
    # Resize all frames to the square input size expected by the ViT
    images = F.interpolate(
        images, (image_size, image_size), mode="bilinear", align_corners=False
    )
    # Load DINO model (torch.hub downloads weights on first use)
    dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
    dino_v2_model.eval()
    dino_v2_model = dino_v2_model.to(device)
    # Normalize with ImageNet (ResNet) statistics, as DINOv2 expects
    resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
    resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
    images_resnet_norm = (images - resnet_mean) / resnet_std
    with torch.no_grad():
        # is_training=True makes DINOv2 return the raw token dict
        # (x_norm_clstoken / x_norm_patchtokens) rather than the head output.
        frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
    # Process features based on similarity type
    if spatial_similarity:
        # Patch-token path — presumably (S, P, C); TODO confirm shape.
        frame_feat = frame_feat["x_norm_patchtokens"]
        # NOTE(review): dim=1 normalizes along the patch axis here, not the
        # channel axis — verify against the upstream vggsfm implementation.
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
        # Per-patch frame-to-frame similarity, averaged over patches -> (S, S)
        frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
        similarity_matrix = torch.bmm(
            frame_feat_norm, frame_feat_norm.transpose(-1, -2)
        )
        similarity_matrix = similarity_matrix.mean(dim=0)
    else:
        # CLS-token path: (S, C) cosine similarity -> (S, S)
        frame_feat = frame_feat["x_norm_clstoken"]
        frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
        similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
    # Distance for FPS; the 100 offset keeps values positive (clone taken
    # before the in-place diagonal fill below mutates similarity_matrix)
    distance_matrix = 100 - similarity_matrix.clone()
    # Ignore self-pairing when scoring how "central" each frame is
    similarity_matrix.fill_diagonal_(-100)
    similarity_sum = similarity_matrix.sum(dim=1)
    # Find the most common frame: the one most similar to all the others
    most_common_frame_index = torch.argmax(similarity_sum).item()
    # Conduct FPS sampling starting from the most common frame to get a
    # diverse set of query frames
    fps_idx = farthest_point_sampling(
        distance_matrix, query_frame_num, most_common_frame_index
    )
    # Clean up all tensors and models to free memory
    del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
    del dino_v2_model
    torch.cuda.empty_cache()
    return fps_idx
def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
    """
    Greedy farthest-point sampling over a pairwise distance matrix.

    Starting from `most_common_frame_index`, repeatedly picks the point
    farthest from the most recently selected one, zeroing out already
    selected indices so they are not picked again.

    Args:
        distance_matrix: (N, N) tensor of pairwise distances.
        num_samples: Number of indices to select.
        most_common_frame_index: Seed index selected first.

    Returns:
        List of selected indices (Python ints).
    """
    # Work on a non-negative copy; clamp() allocates, so in-place edits
    # below never touch the caller's tensor.
    dist = distance_matrix.clamp(min=0)
    total = dist.size(0)
    chosen = [most_common_frame_index]
    # Advanced (list) indexing copies the seed row.
    frontier = dist[chosen]
    while len(chosen) < num_samples:
        # Farthest point from the latest selection.
        next_idx = torch.argmax(frontier).item()
        chosen.append(next_idx)
        # Switch the frontier to the new point's distances, masking out
        # everything already selected.
        frontier = dist[next_idx]
        frontier[chosen] = 0
        # Stop early once every point has been taken.
        if len(chosen) == total:
            break
    return chosen
def calculate_index_mappings(query_index, S, device=None):
    """
    Build an index order that exchanges positions 0 and `query_index`,
    so the content at `query_index` lands at position 0.

    Args:
        query_index: Index to exchange with position 0.
        S: Total number of elements.
        device: Optional device for the returned tensor.

    Returns:
        1-D LongTensor of length S with the two positions swapped.
    """
    order = list(range(S))
    # Plain Python swap before tensor construction.
    order[0], order[query_index] = order[query_index], order[0]
    new_order = torch.tensor(order)
    return new_order if device is None else new_order.to(device)
def switch_tensor_order(tensors, order, dim=1):
    """
    Reorder each tensor in a collection along `dim` following `order`.

    Args:
        tensors: List of tensors; entries may be None.
        order: 1-D tensor of indices giving the new ordering.
        dim: Dimension along which to reorder.

    Returns:
        List of reordered tensors; None entries pass through unchanged.
    """
    reordered = []
    for t in tensors:
        if t is None:
            reordered.append(None)
        else:
            reordered.append(torch.index_select(t, dim, order))
    return reordered
def initialize_feature_extractors(
    max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"
):
    """
    Initialize feature extractors that can be reused based on a method string.

    Args:
        max_query_num: Maximum number of keypoints to extract.
        det_thres: Detection threshold for keypoint extraction (ignored by SIFT).
        extractor_method: String specifying which extractors to use
            (e.g., "aliked", "sp+sift", "aliked+sp+sift").
        device: Device to run extraction on.

    Returns:
        Dictionary mapping method name -> initialized extractor in eval mode.
        Falls back to ALIKED when no valid method is recognized.
    """

    def _build(name):
        # Factory for supported extractors; returns None for unknown names.
        if name == "aliked":
            return ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        if name == "sp":
            return SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        if name == "sift":
            return SIFT(max_num_keypoints=max_query_num)
        return None

    extractors = {}
    for method in extractor_method.lower().split("+"):
        method = method.strip()
        extractor = _build(method)
        if extractor is None:
            print(f"Warning: Unknown feature extractor '{method}', ignoring.")
        else:
            extractors[method] = extractor.to(device).eval()
    if not extractors:
        # Nothing recognized in the method string: default to ALIKED so the
        # caller always gets a usable extractor.
        print(
            f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default."
        )
        extractors["aliked"] = _build("aliked").to(device).eval()
    return extractors
def extract_keypoints(query_image, extractors, round_keypoints=True):
    """
    Extract keypoints using pre-initialized feature extractors.

    Keypoints from all extractors are concatenated along the point dimension.

    Args:
        query_image: Input image tensor (3xHxW, range [0, 1]).
        extractors: Dict mapping extractor name -> initialized extractor
            (each must provide .extract(image, invalid_mask=None) returning a
            dict with a "keypoints" tensor of shape 1xNx2).
        round_keypoints: If True, round keypoint coordinates to whole pixels.

    Returns:
        Tensor of keypoint coordinates (1xNx2), or None if `extractors` is empty.
    """
    collected = []
    with torch.no_grad():
        for extractor in extractors.values():
            data = extractor.extract(query_image, invalid_mask=None)
            points = data["keypoints"]
            if round_keypoints:
                points = points.round()
            collected.append(points)
    if not collected:
        # Mirrors the original behavior: no extractors -> no points.
        return None
    # Single concatenation instead of repeated pairwise torch.cat calls.
    return collected[0] if len(collected) == 1 else torch.cat(collected, dim=1)
def predict_tracks_in_chunks(
    track_predictor,
    images_feed,
    query_points_list,
    fmaps_feed,
    fine_tracking,
    num_splits=None,
    fine_chunk=40960,
):
    """
    Process query points in chunks to avoid memory issues.

    Args:
        track_predictor (object): Track predictor; called as
            track_predictor(images, points, fmaps=..., fine_tracking=...,
            fine_chunk=...) and expected to return
            (fine_pred_track, _, pred_vis, pred_score).
        images_feed (torch.Tensor): (B, T, C, H, W) batch of images.
        query_points_list (list/tuple or torch.Tensor): Chunks of query points,
            each (B, Ni, 2); a single (B, N, 2) tensor is split into
            `num_splits` chunks for backward compatibility.
        fmaps_feed (torch.Tensor): Feature maps for the tracker.
        fine_tracking (bool): Whether to perform fine tracking.
        num_splits (int, optional): Number of chunks when a single tensor is
            passed; ignored when query_points_list is already a list/tuple.
        fine_chunk (int): Chunk size forwarded to the tracker's fine stage.

    Returns:
        tuple: (fine_pred_track, pred_vis, pred_score), each concatenated
        along dim=2; pred_score is None when the tracker produces no scores.

    Raises:
        ValueError: If no query points are provided.
    """
    # Backward compatibility: a bare tensor is chunked like the old API.
    if not isinstance(query_points_list, (list, tuple)):
        query_points_list = torch.chunk(
            query_points_list, 1 if num_splits is None else num_splits, dim=1
        )
    if len(query_points_list) == 0:
        # Previously this fell through to a NameError; fail explicitly.
        raise ValueError("query_points_list must contain at least one chunk")

    track_chunks, vis_chunks, score_chunks = [], [], []
    pred_score = None
    for split_points in query_points_list:
        # Feed each split through the track predictor.
        fine_pred_track, _, pred_vis, pred_score = track_predictor(
            images_feed,
            split_points,
            fmaps=fmaps_feed,
            fine_tracking=fine_tracking,
            fine_chunk=fine_chunk,
        )
        track_chunks.append(fine_pred_track)
        vis_chunks.append(pred_vis)
        score_chunks.append(pred_score)

    # Concatenate the per-chunk results along the point dimension.
    fine_pred_track = torch.cat(track_chunks, dim=2)
    pred_vis = torch.cat(vis_chunks, dim=2)
    # Scores are optional: concatenate only when the tracker produced them.
    pred_score = torch.cat(score_chunks, dim=2) if pred_score is not None else None
    return fine_pred_track, pred_vis, pred_score