| | |
| | |
| |
|
| | |
| | |
| |
|
| | import torch |
| |
|
| | from cotracker.models.core.cotracker.cotracker import CoTracker2 |
| |
|
| |
|
def build_cotracker(
    checkpoint=None,
    *,
    stride=4,
    window_len=8,
    add_space_attn=True,
):
    """Build a ``CoTracker2`` model, optionally loading weights from a file.

    Args:
        checkpoint: Path to a checkpoint readable by ``torch.load``, or
            ``None`` to return the model with its freshly constructed weights.
            If the loaded object contains a ``"model"`` key, the state dict is
            taken from that entry; otherwise the loaded object itself is used
            as the state dict.
        stride: Forwarded to ``CoTracker2``. Defaults to 4.
        window_len: Forwarded to ``CoTracker2``. Defaults to 8.
        add_space_attn: Forwarded to ``CoTracker2``. Defaults to ``True``.
            The loaded checkpoint must match the architecture these three
            values describe.

    Returns:
        The constructed ``CoTracker2`` instance.

    Raises:
        Any error raised by ``torch.load``, e.g. for a missing file, and any
        error raised by ``load_state_dict``, e.g. when the state dict's keys
        or shapes do not match the model (strict loading is the default).
    """
    cotracker = CoTracker2(
        stride=stride,
        window_len=window_len,
        add_space_attn=add_space_attn,
    )

    if checkpoint is not None:
        # map_location="cpu" makes loading independent of the device the
        # checkpoint was saved from.
        # SECURITY NOTE: torch.load unpickles the file, and unpickling can run
        # arbitrary code. Only load checkpoints from trusted sources; consider
        # weights_only=True where the installed torch version supports it.
        state_dict = torch.load(checkpoint, map_location="cpu")
        # Some checkpoints nest the weights under a "model" key; unwrap if so.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        cotracker.load_state_dict(state_dict)
    return cotracker
| |
|