File size: 4,450 Bytes
b74998d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Track head that uses a DPT head to turn backbone tokens into feature maps
    and a BaseTrackerPredictor to track points across those feature maps.

    Tracking is performed iteratively: the tracker refines its predictions over
    ``iters`` iterations (default set at construction, overridable per call).
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
                Forwarded to the DPT feature extractor and stored as
                ``self.patch_size``.
            features (int): Number of feature channels produced by the feature
                extractor; also passed to the tracker as ``latent_dim``.
            iters (int): Default number of refinement iterations for tracking.
            predict_conf (bool): Whether the tracker predicts confidence scores.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels.
            corr_radius (int): Radius for correlation computation, controlling
                the search area.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size

        # DPT-based feature extractor: converts backbone tokens into dense
        # feature maps that the tracker consumes.
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # request feature maps only, not a final prediction
            down_ratio=2,  # presumably halves spatial resolution — confirm in DPTHead
            pos_embed=False,
        )

        # Tracker that predicts point trajectories (coordinates, visibility,
        # confidence) from the feature maps.
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # must match the feature extractor's channel count
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(
        self,
        aggregated_tokens_list,
        images,
        patch_start_idx,
        query_points=None,
        iters=None,
    ):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W), where
                B = batch size and S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track. Passed
                to the tracker unchanged. NOTE(review): whether ``None`` is
                accepted depends on BaseTrackerPredictor — confirm.
            iters (int, optional): Number of refinement iterations. If None,
                uses ``self.iters``.

        Returns:
            tuple: ``(coord_preds, vis_scores, conf_scores)`` exactly as returned
            by the tracker:
                - coord_preds: predicted coordinates for tracked points.
                - vis_scores: visibility scores for tracked points.
                - conf_scores: confidence scores for tracked points
                  (presumably only meaningful when ``predict_conf=True``).

        Raises:
            ValueError: If ``images`` is not 5-dimensional.
        """
        # Validate the documented (B, S, C, H, W) layout up front with a clear
        # message instead of relying on an otherwise-unused tuple unpack.
        if len(images.shape) != 5:
            raise ValueError(
                "images must have shape (B, S, C, H, W); "
                f"got shape {tuple(images.shape)}"
            )

        # Convert tokens into dense feature maps. With down_ratio=2 these are
        # expected to be (B, S, features, H // 2, W // 2) — confirm in DPTHead.
        feature_maps = self.feature_extractor(
            aggregated_tokens_list, images, patch_start_idx
        )

        # Fall back to the default iteration count when none is given.
        if iters is None:
            iters = self.iters

        # Track the query points through the extracted feature maps.
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores