from torch import nn

from .cotracker2_utils.predictor import CoTrackerPredictor


class CoTracker2(nn.Module):
    """Thin ``nn.Module`` wrapper that delegates to ``CoTrackerPredictor``.

    The wrapper owns a single predictor instance (``self.model``) and
    forwards every call to it unchanged, so the predictor can be used
    wherever an ``nn.Module`` is expected.
    """

    def __init__(self, args):
        """Build the wrapped predictor from a configuration object.

        Args:
            args: Object exposing ``patch_size`` and ``wind_size``
                attributes. Both are passed positionally, in that order,
                to ``CoTrackerPredictor``.
                NOTE(review): their exact meaning is defined by the
                predictor, not here -- confirm against its constructor.
        """
        super().__init__()
        patch_size, wind_size = args.patch_size, args.wind_size
        self.model = CoTrackerPredictor(patch_size, wind_size)

    def forward(self, video, queries, backward_tracking, cache_features=False):
        """Run the wrapped predictor and return its result as-is.

        Args:
            video: Input forwarded positionally to the predictor.
                (presumably a video tensor -- TODO confirm expected shape)
            queries: Forwarded as the ``queries`` keyword argument.
            backward_tracking: Forwarded as the ``backward_tracking``
                keyword argument.
            cache_features: Forwarded as the ``cache_features`` keyword
                argument. Defaults to ``False``.

        Returns:
            Whatever ``CoTrackerPredictor`` returns for these inputs; the
            value is not post-processed here.
        """
        # Everything except ``video`` is handed over by keyword, mirroring
        # the predictor's call signature.
        predictor_kwargs = {
            "queries": queries,
            "backward_tracking": backward_tracking,
            "cache_features": cache_features,
        }
        return self.model(video, **predictor_kwargs)