File size: 4,967 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Modified from https://github.com/facebookresearch/vggt

import torch
import torch.nn as nn
import torch.nn.functional as F

from .track_modules.base_track_predictor import BaseTrackerPredictor
from .track_modules.blocks import BasicEncoder, ShallowEncoder
from .track_modules.track_refine import refine_track


class TrackerPredictor(nn.Module):
    """Coarse-to-fine point tracker.

    A coarse stage runs ``coarse_predictor`` on feature maps extracted from
    images downsampled by ``coarse_down_ratio``; an optional fine stage then
    refines the coarse tracks with ``refine_track`` on the full-resolution
    images.
    """

    def __init__(self, **extra_args):
        """
        Initializes the tracker predictor.

        Both coarse_predictor and fine_predictor are constructed as a
        BaseTrackerPredictor, check track_modules/base_track_predictor.py

        Both coarse_fnet and fine_fnet are constructed as a 2D CNN network,
        check track_modules/blocks.py for BasicEncoder and ShallowEncoder

        Args:
            **extra_args: Accepted for call-site compatibility; currently ignored.
        """
        super().__init__()

        # Define coarse predictor configuration
        coarse_stride = 4
        # Input images are downscaled by this factor before the coarse
        # feature encoder (see process_images_to_fmaps) to save memory.
        self.coarse_down_ratio = 2

        # Create networks directly instead of using instantiate
        self.coarse_fnet = BasicEncoder(stride=coarse_stride)
        self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)

        # Create fine predictor with stride = 1
        self.fine_fnet = ShallowEncoder(stride=1)
        self.fine_predictor = BaseTrackerPredictor(
            stride=1,
            depth=4,
            corr_levels=3,
            corr_radius=3,
            latent_dim=32,
            hidden_size=256,
            fine=True,
            use_spaceatt=False,
        )

    def forward(
        self,
        images,
        query_points,
        fmaps=None,
        coarse_iters=6,
        inference=True,
        fine_tracking=True,
        fine_chunk=40960,
    ):
        """
        Track query points through a sequence of images (coarse, then optionally fine).

        Args:
            images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
            query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
            fmaps (torch.Tensor, optional): Precomputed coarse feature maps. If None, they are
                computed from ``images`` via ``process_images_to_fmaps`` and reshaped to
                B x S x C x H' x W'; a precomputed tensor is presumably expected in the same
                layout (TODO confirm against coarse_predictor). Defaults to None.
            coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
            inference (bool, optional): If True, call ``torch.cuda.empty_cache()`` after each
                stage to release cached GPU memory. It does not change the computation. Defaults to True.
            fine_tracking (bool, optional): Whether to refine the coarse tracks with ``refine_track``.
                Defaults to True.
            fine_chunk (int, optional): Passed to ``refine_track`` as ``chunk``. Defaults to 40960.

        Returns:
            tuple: (fine_pred_track, coarse_pred_track, pred_vis, pred_score).
                coarse_pred_track is the last element of the track list returned by
                coarse_predictor. When ``fine_tracking`` is False, fine_pred_track is
                coarse_pred_track and pred_score is ``torch.ones_like(pred_vis)``; otherwise
                pred_score is whatever ``refine_track`` returns with ``compute_score=False``
                (NOTE(review): possibly None — verify in track_refine).
        """
        if fmaps is None:
            batch_num, frame_num, image_dim, height, width = images.shape
            # Fold batch and sequence dims so the 2D encoder sees a flat image batch.
            flat_images = images.reshape(batch_num * frame_num, image_dim, height, width)
            fmaps = self.process_images_to_fmaps(flat_images)
            fmaps = fmaps.reshape(
                batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
            )

            if inference:
                torch.cuda.empty_cache()

        # Coarse prediction
        coarse_pred_track_lists, pred_vis = self.coarse_predictor(
            query_points=query_points,
            fmaps=fmaps,
            iters=coarse_iters,
            down_ratio=self.coarse_down_ratio,
        )
        coarse_pred_track = coarse_pred_track_lists[-1]

        if inference:
            torch.cuda.empty_cache()

        if fine_tracking:
            # Refine the coarse prediction on the full-resolution images
            fine_pred_track, pred_score = refine_track(
                images,
                self.fine_fnet,
                self.fine_predictor,
                coarse_pred_track,
                compute_score=False,
                chunk=fine_chunk,
            )

            if inference:
                torch.cuda.empty_cache()
        else:
            fine_pred_track = coarse_pred_track
            pred_score = torch.ones_like(pred_vis)

        return fine_pred_track, coarse_pred_track, pred_vis, pred_score

    def process_images_to_fmaps(self, images):
        """
        Compute coarse feature maps for a flat batch of images.

        When ``coarse_down_ratio > 1`` the images are first downscaled by that
        factor (bilinear, ``align_corners=True``) to save memory; otherwise
        they are fed to ``coarse_fnet`` unchanged.

        Args:
            images (torch.Tensor): The images to be processed with shape S x 3 x H x W.

        Returns:
            torch.Tensor: The output of ``coarse_fnet``.
        """
        if self.coarse_down_ratio > 1:
            images = F.interpolate(
                images,
                scale_factor=1 / self.coarse_down_ratio,
                mode="bilinear",
                align_corners=True,
            )
        return self.coarse_fnet(images)