# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Modified from https://github.com/facebookresearch/vggt

import torch
import torch.nn as nn
import torch.nn.functional as F

from .track_modules.base_track_predictor import BaseTrackerPredictor
from .track_modules.blocks import BasicEncoder, ShallowEncoder
from .track_modules.track_refine import refine_track


class TrackerPredictor(nn.Module):
    """Two-stage point tracker: a coarse predictor followed by optional refinement.

    The coarse stage runs ``coarse_predictor`` on feature maps produced by
    ``coarse_fnet``. The optional fine stage hands the coarse tracks to
    ``refine_track`` together with ``fine_fnet`` and ``fine_predictor``.
    """

    def __init__(self, **extra_args):
        """
        Initializes the tracker predictor.

        Both coarse_predictor and fine_predictor are constructed as a
        BaseTrackerPredictor, check track_modules/base_track_predictor.py

        Both coarse_fnet and fine_fnet are constructed as a 2D CNN network,
        check track_modules/blocks.py for BasicEncoder and ShallowEncoder

        Args:
            **extra_args: Accepted for call-site compatibility; not read by
                this constructor.
        """
        super().__init__()

        # Define coarse predictor configuration
        coarse_stride = 4
        # Input images are shrunk by this factor before the coarse feature
        # network runs (see process_images_to_fmaps).
        self.coarse_down_ratio = 2

        # Create networks directly instead of using instantiate
        self.coarse_fnet = BasicEncoder(stride=coarse_stride)
        self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)

        # Create fine predictor with stride = 1
        self.fine_fnet = ShallowEncoder(stride=1)
        self.fine_predictor = BaseTrackerPredictor(
            stride=1,
            depth=4,
            corr_levels=3,
            corr_radius=3,
            latent_dim=32,
            hidden_size=256,
            fine=True,
            use_spaceatt=False,
        )

    def forward(
        self,
        images,
        query_points,
        fmaps=None,
        coarse_iters=6,
        inference=True,
        fine_tracking=True,
        fine_chunk=40960,
    ):
        """
        Track query points through a sequence of frames.

        Args:
            images (torch.Tensor): Images as RGB, in the range of [0, 1],
                with a shape of B x S x 3 x H x W.
            query_points (torch.Tensor): 2D xy of query points, relative to
                top left, with a shape of B x N x 2.
            fmaps (torch.Tensor, optional): Precomputed feature maps. When
                None, they are computed from ``images`` with
                ``process_images_to_fmaps`` and reshaped to
                B x S x C x H' x W'. Defaults to None.
            coarse_iters (int, optional): Number of iterations for coarse
                prediction. Defaults to 6.
            inference (bool, optional): Whether to perform inference. When
                True, ``torch.cuda.empty_cache()`` is called after each
                stage. Defaults to True.
            fine_tracking (bool, optional): Whether to perform fine tracking.
                Defaults to True.
            fine_chunk (int, optional): Forwarded to ``refine_track`` as its
                ``chunk`` argument (presumably the chunk size used during
                refinement; confirm in track_modules/track_refine.py).
                Defaults to 40960.

        Returns:
            tuple: A tuple containing fine_pred_track, coarse_pred_track,
            pred_vis, and pred_score. With ``fine_tracking=False``,
            fine_pred_track is coarse_pred_track and pred_score is a tensor
            of ones shaped like pred_vis. With ``fine_tracking=True``,
            pred_score is whatever ``refine_track`` returns when called with
            ``compute_score=False``.
        """
        if fmaps is None:
            batch_num, frame_num, image_dim, height, width = images.shape
            # The feature network takes a 4D batch, so fold the frame axis
            # into the batch axis and unfold it again afterwards.
            reshaped_image = images.reshape(
                batch_num * frame_num, image_dim, height, width
            )
            fmaps = self.process_images_to_fmaps(reshaped_image)
            fmaps = fmaps.reshape(
                batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
            )

            if inference:
                torch.cuda.empty_cache()

        # Coarse prediction
        coarse_pred_track_lists, pred_vis = self.coarse_predictor(
            query_points=query_points,
            fmaps=fmaps,
            iters=coarse_iters,
            down_ratio=self.coarse_down_ratio,
        )
        # Keep only the last entry of the returned list (presumably the
        # final iteration's estimate; confirm in BaseTrackerPredictor).
        coarse_pred_track = coarse_pred_track_lists[-1]

        if inference:
            torch.cuda.empty_cache()

        if fine_tracking:
            # Refine the coarse prediction
            fine_pred_track, pred_score = refine_track(
                images,
                self.fine_fnet,
                self.fine_predictor,
                coarse_pred_track,
                compute_score=False,
                chunk=fine_chunk,
            )

            if inference:
                torch.cuda.empty_cache()
        else:
            fine_pred_track = coarse_pred_track
            pred_score = torch.ones_like(pred_vis)

        return fine_pred_track, coarse_pred_track, pred_vis, pred_score

    def process_images_to_fmaps(self, images):
        """
        This function processes images for inference.

        Args:
            images (torch.Tensor): The images to be processed with shape
                S x 3 x H x W (``forward`` passes the B * S flattened frames).

        Returns:
            torch.Tensor: The processed feature maps.
        """
        if self.coarse_down_ratio > 1:
            # whether or not scale down the input images to save memory
            fmaps = self.coarse_fnet(
                F.interpolate(
                    images,
                    scale_factor=1 / self.coarse_down_ratio,
                    mode="bilinear",
                    align_corners=True,
                )
            )
        else:
            fmaps = self.coarse_fnet(images)

        return fmaps