File size: 457 Bytes
ef296aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | from torch import nn
from .cotracker2_utils.predictor import CoTrackerPredictor
class CoTracker2(nn.Module):
    """Thin ``nn.Module`` wrapper around ``CoTrackerPredictor``.

    The wrapped predictor is built from two fields of ``args``. Every
    ``forward`` call is passed straight through to it, and its result is
    returned unchanged.
    """

    def __init__(self, args):
        """Construct the wrapped predictor.

        Args:
            args: Config object. Only ``args.patch_size`` and ``args.wind_size``
                are read. They are handed to ``CoTrackerPredictor``
                positionally, in that order.
        """
        super().__init__()
        patch_size = args.patch_size
        window_size = args.wind_size
        # Assigned after ``super().__init__()`` so that nn.Module's attribute
        # hook can register the predictor as a submodule (this applies if
        # CoTrackerPredictor is itself an nn.Module — TODO confirm).
        self.model = CoTrackerPredictor(patch_size, window_size)

    def forward(self, video, queries, backward_tracking, cache_features=False):
        """Run the wrapped predictor on one input.

        Args:
            video: Passed as the predictor's first positional argument
                (presumably a video tensor; its shape is not visible here).
            queries: Passed through as ``queries=``.
            backward_tracking: Passed through as ``backward_tracking=``.
            cache_features: Passed through as ``cache_features=``.
                Defaults to ``False``.

        Returns:
            Whatever ``self.model`` returns, unmodified.
        """
        predictor_kwargs = dict(
            queries=queries,
            backward_tracking=backward_tracking,
            cache_features=cache_features,
        )
        return self.model(video, **predictor_kwargs)
|