Spaces:
Sleeping
Sleeping
File size: 4,967 Bytes
b74998d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Modified from https://github.com/facebookresearch/vggt
import torch
import torch.nn as nn
import torch.nn.functional as F
from .track_modules.base_track_predictor import BaseTrackerPredictor
from .track_modules.blocks import BasicEncoder, ShallowEncoder
from .track_modules.track_refine import refine_track
class TrackerPredictor(nn.Module):
    """
    Two-stage 2D point tracker: a coarse predictor run on downsampled
    feature maps, optionally followed by a fine refinement stage.

    Both coarse_predictor and fine_predictor are BaseTrackerPredictor
    instances (see track_modules/base_track_predictor.py). coarse_fnet
    and fine_fnet are 2D CNN feature extractors (see
    track_modules/blocks.py for BasicEncoder and ShallowEncoder).
    """

    def __init__(self, **extra_args):
        # NOTE: extra_args is accepted for config-compatibility and ignored,
        # matching the original construction path.
        super().__init__()

        # Coarse stage: features at stride 4, computed on images that are
        # additionally downscaled by coarse_down_ratio to save memory.
        coarse_stride = 4
        self.coarse_down_ratio = 2

        # Create networks directly instead of using instantiate.
        self.coarse_fnet = BasicEncoder(stride=coarse_stride)
        self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)

        # Fine stage: shallow features at stride 1 for sub-pixel refinement.
        self.fine_fnet = ShallowEncoder(stride=1)
        self.fine_predictor = BaseTrackerPredictor(
            stride=1,
            depth=4,
            corr_levels=3,
            corr_radius=3,
            latent_dim=32,
            hidden_size=256,
            fine=True,
            use_spaceatt=False,
        )

    def forward(
        self,
        images,
        query_points,
        fmaps=None,
        coarse_iters=6,
        inference=True,
        fine_tracking=True,
        fine_chunk=40960,
    ):
        """
        Track query points across a sequence of images.

        Args:
            images (torch.Tensor): Images as RGB, in the range of [0, 1],
                with a shape of B x S x 3 x H x W.
            query_points (torch.Tensor): 2D xy of query points, relative to
                top left, with a shape of B x N x 2.
            fmaps (torch.Tensor, optional): Precomputed coarse feature maps.
                If None, they are computed from `images`. Defaults to None.
            coarse_iters (int, optional): Number of iterations for coarse
                prediction. Defaults to 6.
            inference (bool, optional): If True, calls
                torch.cuda.empty_cache() between stages to reduce peak
                memory. Defaults to True.
            fine_tracking (bool, optional): Whether to run the fine
                refinement stage. Defaults to True.
            fine_chunk (int, optional): Chunk size forwarded to
                refine_track to bound memory use. Defaults to 40960.

        Returns:
            tuple: (fine_pred_track, coarse_pred_track, pred_vis, pred_score).
                When fine_tracking is False, fine_pred_track equals
                coarse_pred_track and pred_score is all ones.
                NOTE(review): when fine_tracking is True, refine_track is
                called with compute_score=False, so pred_score may be None
                in that branch — confirm against track_modules/track_refine.
        """
        if fmaps is None:
            # Flatten the batch and sequence dims so the 2D CNN sees a
            # standard N x C x H x W input, then restore them afterwards.
            batch_num, frame_num, image_dim, height, width = images.shape
            reshaped_image = images.reshape(
                batch_num * frame_num, image_dim, height, width
            )
            fmaps = self.process_images_to_fmaps(reshaped_image)
            fmaps = fmaps.reshape(
                batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
            )
            if inference:
                torch.cuda.empty_cache()

        # Coarse prediction: iterative track updates on downsampled features.
        coarse_pred_track_lists, pred_vis = self.coarse_predictor(
            query_points=query_points,
            fmaps=fmaps,
            iters=coarse_iters,
            down_ratio=self.coarse_down_ratio,
        )
        coarse_pred_track = coarse_pred_track_lists[-1]

        if inference:
            torch.cuda.empty_cache()

        if fine_tracking:
            # Refine the coarse prediction at full resolution.
            fine_pred_track, pred_score = refine_track(
                images,
                self.fine_fnet,
                self.fine_predictor,
                coarse_pred_track,
                compute_score=False,
                chunk=fine_chunk,
            )
            if inference:
                torch.cuda.empty_cache()
        else:
            # Skip refinement: reuse the coarse track and report full
            # confidence for every point.
            fine_pred_track = coarse_pred_track
            pred_score = torch.ones_like(pred_vis)

        return fine_pred_track, coarse_pred_track, pred_vis, pred_score

    def process_images_to_fmaps(self, images):
        """
        Compute coarse feature maps for a flattened image batch.

        Args:
            images (torch.Tensor): Images to be processed, with shape
                (B*S) x 3 x H x W (batch and sequence dims already merged
                by the caller).

        Returns:
            torch.Tensor: The coarse feature maps.
        """
        if self.coarse_down_ratio > 1:
            # Scale down the input images before the encoder to save memory.
            fmaps = self.coarse_fnet(
                F.interpolate(
                    images,
                    scale_factor=1 / self.coarse_down_ratio,
                    mode="bilinear",
                    align_corners=True,
                )
            )
        else:
            fmaps = self.coarse_fnet(images)

        return fmaps
|