Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| # Modified from https://github.com/facebookresearch/vggt | |
| import numpy as np | |
| import torch | |
| from .vggsfm_utils import ( | |
| build_vggsfm_tracker, | |
| calculate_index_mappings, | |
| extract_keypoints, | |
| generate_rank_by_dino, | |
| initialize_feature_extractors, | |
| predict_tracks_in_chunks, | |
| switch_tensor_order, | |
| ) | |
def predict_tracks(
    images,
    conf=None,
    points_3d=None,
    max_query_pts=2048,
    query_frame_num=5,
    keypoint_extractor="aliked+sp",
    max_points_num=163840,
    fine_tracking=True,
    complete_non_vis=True,
):
    """
    Predict point tracks for the given images.

    TODO: support non-square images
    TODO: support masks

    Query frames are selected by DINO-based ranking, keypoints are extracted
    on each query frame, and the VGG-SfM tracker propagates them across all
    frames. Optionally, frames with too few visible tracks are re-processed
    to densify their coverage.

    Args:
        images: Tensor of shape [S, 3, H, W] containing the input images.
        conf: Optional per-pixel confidence, indexed as conf[frame][y, x].
            Default is None.
        points_3d: Optional per-pixel 3D points aligned with `conf`.
            Default is None.
        max_query_pts: Maximum number of query points per frame. Default 2048.
        query_frame_num: Number of query frames to use. Default 5.
        keypoint_extractor: Method for keypoint extraction. Default "aliked+sp".
        max_points_num: Maximum number of points to process at once
            (GPU-memory chunking bound). Default 163840.
        fine_tracking: Whether to use fine (slower, more precise) tracking.
        complete_non_vis: Whether to augment non-visible frames. Default True.

    Returns:
        pred_tracks: Numpy array containing the predicted tracks.
        pred_vis_scores: Numpy array containing the visibility scores.
        pred_confs: Confidence per track, or None when `conf` was not given.
        pred_points_3d: 3D point per track, or None when `points_3d` was not
            given.
        pred_colors: Numpy array of point colors per track. (0, 255)
    """
    device = images.device
    dtype = images.dtype
    tracker = build_vggsfm_tracker().to(device, dtype)

    # Rank frames and pick the query frames.
    query_frame_indexes = generate_rank_by_dino(
        images, query_frame_num=query_frame_num, device=device
    )

    # Ensure the first image is always the leading query frame.
    if 0 in query_frame_indexes:
        query_frame_indexes.remove(0)
    query_frame_indexes = [0, *query_frame_indexes]

    # TODO: add the functionality to handle the masks
    keypoint_extractors = initialize_feature_extractors(
        max_query_pts, extractor_method=keypoint_extractor, device=device
    )

    pred_tracks = []
    pred_vis_scores = []
    pred_confs = []
    pred_points_3d = []
    pred_colors = []

    fmaps_for_tracker = tracker.process_images_to_fmaps(images)

    if fine_tracking:
        print("For faster inference, consider disabling fine_tracking")

    for query_index in query_frame_indexes:
        print(f"Predicting tracks for query frame {query_index}")
        pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
            query_index,
            images,
            conf,
            points_3d,
            fmaps_for_tracker,
            keypoint_extractors,
            tracker,
            max_points_num,
            fine_tracking,
            device,
        )
        pred_tracks.append(pred_track)
        pred_vis_scores.append(pred_vis)
        # pred_conf / pred_point_3d are None when conf/points_3d were not
        # supplied, or when too few confident points were found. Appending
        # None entries would break np.concatenate below, so skip them.
        if pred_conf is not None:
            pred_confs.append(pred_conf)
        if pred_point_3d is not None:
            pred_points_3d.append(pred_point_3d)
        pred_colors.append(pred_color)

    if complete_non_vis:
        pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = (
            _augment_non_visible_frames(
                pred_tracks,
                pred_vis_scores,
                pred_confs,
                pred_points_3d,
                pred_colors,
                images,
                conf,
                points_3d,
                fmaps_for_tracker,
                keypoint_extractors,
                tracker,
                max_points_num,
                fine_tracking,
                min_vis=500,
                non_vis_thresh=0.1,
                device=device,
            )
        )

    pred_tracks = np.concatenate(pred_tracks, axis=1)
    pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
    # Defensively drop any None entries before concatenation; an empty (or
    # all-None) list collapses to None so callers can test for absence.
    pred_confs = [c for c in pred_confs if c is not None]
    pred_points_3d = [p for p in pred_points_3d if p is not None]
    pred_colors = [c for c in pred_colors if c is not None]
    pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None
    pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None
    pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None

    # from vggt.utils.visual_track import visualize_tracks_on_images
    # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
    return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
def _forward_on_query(
    query_index,
    images,
    conf,
    points_3d,
    fmaps_for_tracker,
    keypoint_extractors,
    tracker,
    max_points_num,
    fine_tracking,
    device,
):
    """
    Process a single query frame for track prediction.

    Keypoints are extracted on the query frame, optionally filtered by the
    confidence map, and tracked through all frames. The frame order is
    temporarily permuted so the query frame comes first, then restored.

    Args:
        query_index: Index of the query frame.
        images: Tensor of shape [S, 3, H, W] containing the input images.
        conf: Per-pixel confidence indexed as conf[frame][y, x]; may be None.
        points_3d: Per-pixel 3D points aligned with `conf`; may be None.
        fmaps_for_tracker: Feature maps for the tracker.
        keypoint_extractors: Initialized feature extractors.
        tracker: VGG-SFM tracker.
        max_points_num: Maximum number of points to process at once.
        fine_tracking: Whether to use fine tracking.
        device: Device to use for computation.

    Returns:
        pred_track: Predicted tracks (numpy, frame dim first).
        pred_vis: Visibility scores for the tracks (numpy).
        pred_conf: Confidence per track, or None if conf/points_3d was absent
            or too few confident points were found.
        pred_point_3d: 3D point per track, or None (same condition as above).
        pred_color: uint8 point colors for the tracks (0, 255).
    """
    frame_num, _, height, width = images.shape
    query_image = images[query_index]

    # Keep sub-pixel keypoint coordinates; rounding happens only where
    # integer pixel indices are needed below.
    query_points = extract_keypoints(
        query_image, keypoint_extractors, round_keypoints=False
    )
    # Randomly permute the points — presumably to mix outputs of the
    # different extractors before the chunking below; confirm with authors.
    query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]

    # Extract the color at the keypoint locations.
    # NOTE(review): the *255 scaling assumes image values are in [0, 1] —
    # TODO confirm against the caller's preprocessing.
    query_points_long = query_points.squeeze(0).round().long()
    pred_color = images[query_index][
        :, query_points_long[:, 1], query_points_long[:, 0]
    ]
    pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)

    # Query the confidence and points_3d at the keypoint locations
    if (conf is not None) and (points_3d is not None):
        # Only square inputs are supported for the rescaling below.
        assert height == width
        assert conf.shape[-2] == conf.shape[-1]
        assert conf.shape[:3] == points_3d.shape[:3]
        # conf/points_3d may be at a different resolution than the images.
        scale = conf.shape[-1] / width
        query_points_scaled = (query_points.squeeze(0) * scale).round().long()
        query_points_scaled = query_points_scaled.cpu().numpy()
        pred_conf = conf[query_index][
            query_points_scaled[:, 1], query_points_scaled[:, 0]
        ]
        pred_point_3d = points_3d[query_index][
            query_points_scaled[:, 1], query_points_scaled[:, 0]
        ]

        # heuristic to remove low confidence points
        # should I export this as an input parameter?
        # NOTE(review): pred_color is numpy; this masking assumes valid_mask
        # is numpy-compatible (i.e. conf is a numpy array or CPU data) —
        # verify if conf can be a CUDA tensor.
        valid_mask = pred_conf > 1.2
        if valid_mask.sum() > 512:
            query_points = query_points[:, valid_mask]  # Make sure shape is compatible
            pred_conf = pred_conf[valid_mask]
            pred_point_3d = pred_point_3d[valid_mask]
            pred_color = pred_color[valid_mask]
    else:
        pred_conf = None
        pred_point_3d = None

    # Permute frames so the query frame leads; the same mapping is applied
    # again afterwards to restore the original order (presumably an
    # involution — see calculate_index_mappings).
    reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
    images_feed, fmaps_feed = switch_tensor_order(
        [images, fmaps_for_tracker], reorder_index, dim=0
    )
    images_feed = images_feed[None]  # add batch dimension
    fmaps_feed = fmaps_feed[None]  # add batch dimension

    all_points_num = images_feed.shape[1] * query_points.shape[1]
    # Don't need to be scared, this is just chunking to make GPU happy
    if all_points_num > max_points_num:
        num_splits = (all_points_num + max_points_num - 1) // max_points_num
        query_points = torch.chunk(query_points, num_splits, dim=1)
    else:
        query_points = [query_points]

    pred_track, pred_vis, _ = predict_tracks_in_chunks(
        tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
    )
    # Restore the original frame order along the frame dimension (dim=1).
    pred_track, pred_vis = switch_tensor_order(
        [pred_track, pred_vis], reorder_index, dim=1
    )
    pred_track = pred_track.squeeze(0).float().cpu().numpy()
    pred_vis = pred_vis.squeeze(0).float().cpu().numpy()

    return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
def _augment_non_visible_frames(
    pred_tracks: list,  # ← running list of np.ndarrays
    pred_vis_scores: list,  # ← running list of np.ndarrays
    pred_confs: list,  # ← running list of np.ndarrays for confidence scores
    pred_points_3d: list,  # ← running list of np.ndarrays for 3D points
    pred_colors: list,  # ← running list of np.ndarrays for colors
    images: torch.Tensor,
    conf,
    points_3d,
    fmaps_for_tracker,
    keypoint_extractors,
    tracker,
    max_points_num: int,
    fine_tracking: bool,
    *,
    min_vis: int = 500,
    non_vis_thresh: float = 0.1,
    device: torch.device = None,
):
    """
    Augment tracking for frames with insufficient visibility.

    Repeatedly finds frames with fewer than `min_vis` tracks whose visibility
    exceeds `non_vis_thresh` and re-runs track prediction with those frames
    as query frames. If the same frame is still under-covered after a retry,
    one final attempt processes all remaining frames at once with a stronger
    extractor combination, then the loop stops.

    Args:
        pred_tracks: List of numpy arrays containing predicted tracks.
        pred_vis_scores: List of numpy arrays containing visibility scores.
        pred_confs: List of numpy arrays containing confidence scores.
        pred_points_3d: List of numpy arrays containing 3D points.
        pred_colors: List of numpy arrays containing point colors.
        images: Tensor of shape [S, 3, H, W] containing the input images.
        conf: Per-pixel confidence (may be None).
        points_3d: Per-pixel 3D points (may be None).
        fmaps_for_tracker: Feature maps for the tracker.
        keypoint_extractors: Initialized feature extractors.
        tracker: VGG-SFM tracker.
        max_points_num: Maximum number of points to process at once.
        fine_tracking: Whether to use fine tracking.
        min_vis: Minimum visibility threshold (count of visible tracks).
        non_vis_thresh: Per-track visibility-score threshold.
        device: Device to use for computation.

    Returns:
        Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d,
        and pred_colors lists (mutated in place and returned).
    """
    last_query = -1
    final_trial = False
    cur_extractors = keypoint_extractors  # may be replaced on the final trial

    while True:
        # Count, per frame, how many tracks are sufficiently visible.
        vis_array = np.concatenate(pred_vis_scores, axis=1)
        sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
        non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
        if not non_vis_frames:
            break

        print("Processing non visible frames:", non_vis_frames)

        # Decide the frames & extractor for this round
        if non_vis_frames[0] == last_query:
            # Same frame failed twice - final "all-in" attempt
            final_trial = True
            cur_extractors = initialize_feature_extractors(
                2048, extractor_method="sp+sift+aliked", device=device
            )
            query_frame_list = non_vis_frames  # blast them all at once
        else:
            query_frame_list = [non_vis_frames[0]]  # Process one at a time

        last_query = non_vis_frames[0]

        # Run the tracker for every selected frame
        for query_index in query_frame_list:
            new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
                query_index,
                images,
                conf,
                points_3d,
                fmaps_for_tracker,
                cur_extractors,
                tracker,
                max_points_num,
                fine_tracking,
                device,
            )
            pred_tracks.append(new_track)
            pred_vis_scores.append(new_vis)
            # new_conf / new_point_3d are None when no confidence input was
            # given (or too few confident points survived filtering);
            # appending None would break the caller's np.concatenate.
            if new_conf is not None:
                pred_confs.append(new_conf)
            if new_point_3d is not None:
                pred_points_3d.append(new_point_3d)
            pred_colors.append(new_color)

        if final_trial:
            break  # Stop after final attempt

    return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors