Spaces: Running on Zero
import spaces  # NOTE: must stay the first import for HF Spaces ZeroGPU patching

from dust3r.models.blocks import PositionGetter
from dust3r.post_process import estimate_focal_knowing_depth
from must3r.model.blocks.attention import has_xformers, toggle_memory_efficient_attention

# Side effect kept immediately after its import, as in the original file.
toggle_memory_efficient_attention(enabled = has_xformers)

import copy
import json
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange, repeat
from hydra import compose
from hydra.utils import instantiate
from torch import nn
from torchvision.transforms import functional as TF
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.modeling.sam2_utils import LayerNorm2d
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
from training_utils import load_checkpoint, BatchedVideoDatapoint, positional_encoding, postprocess_must3r_output
# Input resolution (long side, in pixels) of the MUSt3R checkpoint used below.
MUST3R_SIZE = 512
def get_views(pil_imgs):
    """Turn a list of PIL images into MUSt3R view dicts.

    Returns (views, resize_funcs): the loaded views plus the resize
    transforms that map each original image to the MUSt3R input size.
    """
    from data import load_images  # project-local loader

    return load_images(pil_imgs, size=MUST3R_SIZE, patch_size=16)
def prepare_sam2_inputs(views, pil_imgs, resize_funcs):
    """Build the 1024x1024 ImageNet-normalized batch that SAM2 expects.

    Returns (sam2_input_images, images): the SAM2-ready batch and the
    intermediate batch after the first resize transform of resize_funcs[0].
    """
    to_sam2 = T.Compose([
        T.Resize((1024, 1024), interpolation=T.InterpolationMode.BILINEAR),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    # Stack PIL images into a (T, C, H, W) float batch on CPU.
    batch = torch.stack([TF.to_tensor(img) for img in pil_imgs], dim=0).cpu()
    # Reuse the same resize that was applied when the MUSt3R views were loaded.
    resized = resize_funcs[0].transforms[0](batch)
    return to_sam2(resized), resized
def must3r_features_and_output(views, device = 'cuda'):
    """Run MUSt3R SLAM-style video inference over `views` and collect features.

    Loads the MUSt3R-512 checkpoint, runs `must3r_inference_video`, then
    returns:
      - must3r_feats: list of 4 per-layer feature maps (decoder layers
        0, 4, 7, 11), each reshaped to (num_views, C, H/16, W/16), on CPU.
      - must3r_outputs: dict with 'pts3d' and 'ray_plucker' tensors on CPU.

    NOTE(review): relies on model[1] accumulating into `recorded_feats` /
    `all_feats` during inference — confirm against the patched MUSt3R decoder.
    """
    import functools
    from must3r.model import load_model, ActivationType
    from must3r.demo.gradio import get_args_parser
    from must3r.demo.inference import (
        must3r_inference_video, slam_is_keyframe, slam_update_scene_state,
        get_pointmaps_activation,
    )
    from must3r.slam.model import get_searcher

    cmd_params = ["--weights", "./private/MUSt3R_512.pth",
                  "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth",
                  "--image_size", "512", "--amp", "bf16", "--viser",
                  "--allow_local_files", "--device", device]
    parser = get_args_parser()
    args = parser.parse_args(cmd_params)
    model = load_model(args.weights, encoder=args.encoder, decoder=args.decoder,
                       device=args.device, img_size=args.image_size,
                       memory_mode=args.memory_mode, verbose=args.verbose)
    model = [m.eval() for m in model]
    assert model[0].patch_size == 16
    assert get_pointmaps_activation(model[1]) == ActivationType.NORM_EXP

    # SLAM keyframing configuration (unused demo knobs removed).
    image_size = 512
    amp = "bf16"
    max_bs = 1
    num_refinements_iterations = 0
    subsample = 2
    min_conf_keyframe = 1.5
    keyframe_overlap_thr = 0.05
    overlap_percentile = 85
    local_context_size = 0
    overlap_mode = "nn-norm"
    assert MUST3R_SIZE == 512

    # Buffers filled by the decoder hooks during inference.
    model[1].recorded_feats = []
    model[1].all_feats = []
    is_keyframe_function = functools.partial(
        slam_is_keyframe, subsample, min_conf_keyframe,
        keyframe_overlap_thr, overlap_percentile, overlap_mode)
    scene_state = get_searcher("kdtree-scipy-quadrant_x2")
    scene_state_update_function = functools.partial(
        slam_update_scene_state, subsample, min_conf_keyframe)

    must3r_inference_video(model, device, image_size, amp, filelist = None,
                           max_bs = max_bs, init_num_images = 2, batch_num_views = 1,
                           viser_server = None,
                           num_refinements_iterations = num_refinements_iterations,
                           local_context_size = local_context_size,
                           is_keyframe_function = is_keyframe_function,
                           scene_state = scene_state,
                           scene_state_update_function = scene_state_update_function,
                           verbose = True, views = views)

    must3r_feats = torch.cat(model[1].recorded_feats, dim = 0).to(device)
    must3r_outputs = model[1]._compute_prediction_head(
        torch.stack([torch.from_numpy(view['true_shape']).squeeze() for view in views]).to(device)[:, None],
        len(views),
        1,
        [must3r_feats],
        norm = False
    ).squeeze()

    # Keep decoder layers 0, 4, 7, 11 and reshape tokens to (B, C, H/16, W/16).
    must3r_feats = [[f[0], f[4], f[7], f[11]] for f in model[1].all_feats]
    must3r_feats = [torch.cat(f, dim = 0).to(device) for f in zip(*must3r_feats)]
    must3r_feats = [
        rearrange(f, 'b (h w) c -> b c h w',
                  h = views[0]['true_shape'][0] // 16,
                  w = views[0]['true_shape'][1] // 16).cpu()
        for f in must3r_feats
    ]

    # Post-process frame by frame to get point maps and Plücker ray embeddings.
    must3r_output_all = []
    for f in tqdm(must3r_outputs):
        must3r_output_all.append(postprocess_must3r_output(
            f.cpu()[None], pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True))
    must3r_outputs = {'pts3d': torch.cat([c['pts3d'] for c in must3r_output_all], dim = 0).squeeze(),
                      'ray_plucker': torch.cat([c['ray_plucker'] for c in must3r_output_all], dim = 0).squeeze()}
    must3r_outputs = {k: v.cpu() for k, v in must3r_outputs.items()}
    return must3r_feats, must3r_outputs
class FeatureFusion(nn.Module):
    """Fuse SAM2 2D backbone features with MUSt3R 3D decoder features.

    The 3D features are refined through copies of the MUSt3R cross-attention
    blocks (one stack per MUSt3R input resolution), projected to the SAM2
    feature width, and merged with the 2D features via a 1x1 conv.
    """

    def __init__(self, cross_attn_blocks_3d, in_channels_2d = 256, in_channels_3d = 768):
        """cross_attn_blocks_3d: MUSt3R decoder blocks to deep-copy.

        `in_channels_3d` is kept for interface compatibility but is unused
        (the 3D width of 768 is hard-coded below).
        """
        super().__init__()
        self.freqs = 6  # number of sinusoidal frequency bands for xyz encoding
        self.position_getter = PositionGetter()
        # One refinement stack per MUSt3R resolution: copied cross-attention
        # blocks followed by a 1024 -> 768 projection stored at index -1.
        self.feat_conv_3d_224 = nn.ModuleList([
            copy.deepcopy(block) for block in cross_attn_blocks_3d
        ] + [nn.Linear(in_features = 1024, out_features = 768)])
        self.feat_conv_3d_512 = nn.ModuleList([
            copy.deepcopy(block) for block in cross_attn_blocks_3d
        ] + [nn.Linear(in_features = 1024, out_features = 768)])
        self.out = nn.Conv2d(in_channels = 768, out_channels = in_channels_2d, kernel_size = 3, padding = 1)
        self.merge = nn.Conv2d(in_channels = in_channels_2d * 2, out_channels = in_channels_2d, kernel_size = 1, padding = 0)
        # Patch-embeds the explicit 3D input: positional-encoded xyz
        # (3 * (2*freqs + 1) channels) concatenated with 6 ray channels.
        self.explicit_3d_embedding = nn.Conv2d(in_channels = 3 * (2 * self.freqs + 1) + 6, out_channels = 768, kernel_size = 16, padding = 0, stride = 16)

    def forward(self, feat_2d, feat_3d, explicit_3d = None, must3r_size = 512):
        """Merge feat_2d (B, C2d, H, W) with 4 levels of 3D features.

        feat_3d: list of 4 maps (B, C3d, h, w); explicit_3d: (B, 9, H', W')
        with xyz in the first 3 channels — assumed shapes, TODO confirm
        against the caller in Tracker.step.
        """
        refinenets_3d = self.feat_conv_3d_224 if must3r_size == 224 else self.feat_conv_3d_512
        assert len(feat_3d) == 4, f'Expected 4 levels of 3D features, got {len(feat_3d)}'
        # Sinusoidally encode the xyz channels; keep the ray channels as-is.
        explicit_3d = torch.cat((positional_encoding(explicit_3d[:, :3], self.freqs, dim = 1), explicit_3d[:, 3:]), dim = 1)
        explicit_3d = self.explicit_3d_embedding(explicit_3d)
        pe_3d = rearrange(explicit_3d, 'b c h w -> b (h w) c')
        B = pe_3d.shape[0]
        assert B == 1
        pe_2d = self.position_getter(B, explicit_3d.shape[2], explicit_3d.shape[3], device = explicit_3d.device)
        feat_3d = [rearrange(f, 'b c h w -> b (h w) c') for f in feat_3d]
        N = feat_3d[0].shape[1]
        # Cross-attention mask (True = masked): batch element i may attend
        # only to the first (i+1)*N keys. With B == 1 everything is visible.
        ca_attn_mask = torch.ones((B, 1, N, N * B), dtype = torch.bool, device = feat_3d[0].device)
        for i in range(B):
            ca_attn_mask[i, :, :, :(i + 1) * N] = False
        feat_3d_post = feat_3d[0]
        for i in range(len(feat_3d)):
            if i == 0:
                # Project level 0 to 768 (the Linear at index -1), add the 3D
                # positional embedding, then self-attend.
                feat_3d_post = refinenets_3d[-1](feat_3d_post) + pe_3d
                feat_3d_post = refinenets_3d[i](x = feat_3d_post, y = feat_3d_post, xpos = pe_2d)
            else:
                # Cross-attend the running features against level i.
                feat_3d_post = refinenets_3d[i](x = feat_3d_post + pe_3d, y = repeat(feat_3d[i] + pe_3d, 'b n c -> k (b n) c', k = B), xpos = pe_2d, ca_attn_mask = ca_attn_mask)
        # Upsample the refined 3D features to the 2D grid and project to C2d.
        feat_3d_post = self.out(F.interpolate(rearrange(feat_3d_post, 'b (h w) c -> b c h w', b = B, h = explicit_3d.shape[2], w = explicit_3d.shape[3]), size = feat_2d.shape[-2:], mode = 'bilinear', align_corners = False))
        feat_merged = self.merge(torch.cat([feat_3d_post, feat_2d], dim = 1))
        return feat_merged
def get_must3r_cross_attn_layers(device = 'cuda'):
    """Load the MUSt3R-512 model and return its decoder attention blocks."""
    from must3r.model import load_model
    from must3r.demo.gradio import get_args_parser

    argv = ["--weights", "./private/MUSt3R_512.pth",
            "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth",
            "--image_size", "512", "--amp", "bf16", "--viser",
            "--allow_local_files", "--device", device]
    args = get_args_parser().parse_args(argv)
    model = load_model(args.weights, encoder=args.encoder, decoder=args.decoder,
                       device=args.device, img_size=args.image_size,
                       memory_mode=args.memory_mode, verbose=args.verbose)
    # model[1] is the decoder; expose its block list.
    return model[1].blocks_dec
def get_predictors(device = 'cuda',
                   sam2_checkpoint = "./sam2-src/checkpoints/sam2.1_hiera_large.pt",
                   fusion_checkpoint = './private/sam2.1-must3r-fixed-vision-v1-decomp-standalone-regional-best-2.7851.pt'):
    """Build the stock SAM2 video predictor and the 3D-fused variant.

    Checkpoint paths are now parameters (defaults preserve the original
    hard-coded behavior). The fused predictor gets a `fusion_3d` module built
    from MUSt3R decoder blocks 0, 4, 7 and 11 and loads `fusion_checkpoint`.

    Returns (predictor_original, predictor), both moved back to CPU.
    """
    predictor_original = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l.yaml", sam2_checkpoint).to(device).eval()
    # The 3D variant is built without weights; they come from fusion_checkpoint.
    predictor = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l_3d.yaml").to(device).eval()
    cross_attn_blocks_3d = get_must3r_cross_attn_layers(device = device)
    predictor.fusion_3d = FeatureFusion(cross_attn_blocks_3d = [copy.deepcopy(cross_attn_blocks_3d[i]) for i in [0, 4, 7, 11]])
    predictor = load_checkpoint(predictor, torch.load(fusion_checkpoint, map_location = 'cpu'))
    return predictor_original.cpu(), predictor.cpu()
def get_image_feature(
    predictor: SAM2Base,
    images: torch.Tensor,
    device = 'cuda'
):
    """Run the SAM2 image encoder and prepare backbone features.

    Returns (backbone_out, vision_feats, vision_pos_embeds, feat_sizes) as
    produced by `predictor._prepare_backbone_features`. The FPN and pos-enc
    lists are shallow-copied so downstream mutation does not touch the raw
    encoder output. `device` is unused (kept for interface compatibility).
    """
    raw = predictor.forward_image(images)
    copied = {
        "backbone_fpn": list(raw["backbone_fpn"]),
        "vision_pos_enc": list(raw["vision_pos_enc"]),
    }
    return predictor._prepare_backbone_features(copied)
class Tracker(nn.Module):
    """Stateful SAM2 video tracker with optional MUSt3R 3D feature fusion.

    Usage: call `init(...)` once per tracking pass, then for each frame in
    `processing_order` call `step()` followed by `postprocess(current_out)`.
    """

    def __init__(self, predictor, predictor_original = None, device = 'cuda'):
        super().__init__()
        self.predictor = predictor.to(device)
        # Optional stock SAM2 predictor (only used by the disabled code below).
        self.predictor_original = predictor_original.to(device) if predictor_original is not None else None
        self.device = device

    def init(self, images, processing_order, points = None, labels = None, mask_inputs = None, must3r_feats = None, explicit_3d = None, image_features = None):
        """Reset all per-pass tracking state.

        Prompts for the first processed frame are either points+labels or
        mask_inputs. `must3r_feats`/`explicit_3d` enable 3D fusion in step().
        """
        self.images = images
        self.point_inputs = {"point_coords": points.to(self.device), "point_labels": labels.to(self.device)} if points is not None and labels is not None else None
        self.mask_inputs = mask_inputs
        self.processing_order = processing_order
        self.output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
        self.pred_maskses = []
        self.num_frames = len(processing_order)
        self.current_idx = 0
        self.image_features = image_features  # unused here; kept for callers
        self.must3r_feats = must3r_feats
        self.explicit_3d = explicit_3d

    def step(self, mask_inputs = None, point_inputs = None):
        """Track one frame and return the raw `track_step` output dict.

        Extra prompts are only allowed after frame 0 (frame 0 always uses the
        prompts passed to `init`).
        """
        assert (mask_inputs is None or self.current_idx > 0) and (point_inputs is None or self.current_idx > 0), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
        frame_idx = self.processing_order[self.current_idx]
        (
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = get_image_feature(self.predictor, images = self.images[:, frame_idx, :3].to(self.device))
        if self.must3r_feats is not None:
            # Fuse the last (64x64-token) SAM2 feature map with this frame's
            # MUSt3R features and explicit 3D input.
            feat_2d_original = rearrange(current_vision_feats[-1], '(x y) b c -> b c x y', x = 64, y = 64)
            feat_2d = self.predictor.fusion_3d(
                feat_2d = feat_2d_original,
                feat_3d = [f[frame_idx].to(self.device).squeeze()[None] for f in self.must3r_feats],
                explicit_3d = self.explicit_3d[frame_idx].to(self.device).squeeze()[None],
                # 14x14 feature maps come from the 224 model, otherwise 512.
                must3r_size = 224 if self.must3r_feats[0][frame_idx].shape[-1] == 14 and self.must3r_feats[0][frame_idx].shape[-2] == 14 else 512
            )
            current_vision_feats[-1] = rearrange(feat_2d, 'b c h w -> (h w) b c')
            # Sanity check: fusion must actually modify the features.
            assert not torch.allclose(feat_2d.float(), feat_2d_original.float(), 1e-4), 'Feature fusion did not change features'
        self.current_vision_feats = current_vision_feats
        self.feat_sizes = feat_sizes
        # Memory bank: all conditioning frames, plus non-conditioning frames
        # whose predicted mask was non-empty, plus (if present and truthy) the
        # immediately previous frame's output.
        memory_dict = {
            'cond_frame_outputs': self.output_dict['cond_frame_outputs'],
            'non_cond_frame_outputs': {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()} | ({self.current_idx - 1: d} if (d := self.output_dict['non_cond_frame_outputs'].get(self.current_idx - 1)) else {})
        }
        # Cap each bank at the 32 entries closest to the current frame,
        # re-keyed to consecutive indices just below current_idx.
        if len(memory_dict['non_cond_frame_outputs']) > 32:
            memory_dict['non_cond_frame_outputs'] = {self.current_idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['non_cond_frame_outputs'].items(), key = lambda x: abs(x[0] - self.current_idx))[:32])}
        if len(memory_dict['cond_frame_outputs']) > 32:
            memory_dict['cond_frame_outputs'] = {self.current_idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['cond_frame_outputs'].items(), key = lambda x: abs(x[0] - self.current_idx))[:32])}
        current_out = self.predictor.track_step(
            frame_idx = self.current_idx,
            is_init_cond_frame = self.current_idx == 0,
            current_vision_feats = current_vision_feats,
            current_vision_pos_embeds = current_vision_pos_embeds,
            feat_sizes = feat_sizes,
            point_inputs = self.point_inputs if self.current_idx == 0 else point_inputs,
            mask_inputs = self.mask_inputs.to(self.device) if self.current_idx == 0 else mask_inputs,
            output_dict = memory_dict,
            num_frames = self.num_frames,
            track_in_reverse = False,
            run_mem_encoder = False,  # memory is encoded later in postprocess()
            prev_sam_mask_logits = None,
        )
        current_out["pred_masks"] = fill_holes_in_mask_scores(
            current_out["pred_masks"], self.predictor.fill_hole_area
        )
        current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
            current_out["pred_masks"],
            size = (self.predictor.image_size, self.predictor.image_size),
            mode = "bilinear",
            align_corners = False,
        )
        # Disabled experiment: re-run the stock predictor with the fused
        # prediction as a mask prompt to refine the output. Kept for reference.
        # if self.predictor_original is not None and self.current_idx != 0:
        #     current_out['pred_masks_high_res_lq'] = current_out['pred_masks_high_res'].clone()
        #     self.predictor_original.use_mask_input_as_output_without_sam = False
        #     current_vision_feats_original = current_vision_feats.copy()
        #     current_vision_feats_original[-1] = rearrange(feat_2d_original, 'b c h w -> (h w) b c')
        #     current_out_original = self.predictor_original.track_step(
        #         frame_idx = 0,
        #         is_init_cond_frame = True,
        #         current_vision_feats = current_vision_feats_original,
        #         current_vision_pos_embeds = current_vision_pos_embeds,
        #         feat_sizes = feat_sizes,
        #         point_inputs = None,
        #         mask_inputs = current_out["pred_masks_high_res"].to(self.device).squeeze()[None, None],
        #         output_dict = {},
        #         num_frames = self.num_frames,
        #         track_in_reverse = False,
        #         run_mem_encoder = False,
        #         prev_sam_mask_logits = None,
        #     )
        #     # if (current_out['pred_masks_high_res'] > 0).sum() > 0: assert (current_out_original['pred_masks'] > 0).sum() > 0, 'Original predictor produced empty mask'
        #     current_out["pred_masks"] = fill_holes_in_mask_scores(
        #         current_out_original["pred_masks"], self.predictor.fill_hole_area
        #     )
        #     current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
        #         current_out["pred_masks"],
        #         size = (self.predictor.image_size, self.predictor.image_size),
        #         mode = "bilinear",
        #         align_corners = False,
        #     )
        return current_out

    def postprocess(self, current_out):
        """Encode the new frame's memory and commit it to the tracker state."""
        maskmem_features, maskmem_pos_enc = self.predictor._encode_new_memory(
            current_vision_feats = self.current_vision_feats,
            feat_sizes = self.feat_sizes,
            pred_masks_high_res = current_out["pred_masks_high_res"],
            object_score_logits = current_out['object_score_logits'],
            is_mask_from_pts = False
        )
        # bf16 storage to reduce memory footprint of the history.
        current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
        current_out["maskmem_pos_enc"] = maskmem_pos_enc
        self.pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        # Frame 0 is the conditioning frame; all others are non-conditioning.
        self.output_dict['cond_frame_outputs'if self.current_idx == 0 else 'non_cond_frame_outputs'][self.current_idx] = current_out
        # Bound stored history to the most recent 256 frames per bank.
        if len(self.output_dict['non_cond_frame_outputs']) > 256:
            self.output_dict['non_cond_frame_outputs'] = {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if k >= self.current_idx - 256}
        if len(self.output_dict['cond_frame_outputs']) > 256:
            self.output_dict['cond_frame_outputs'] = {k: v for k, v in self.output_dict['cond_frame_outputs'].items() if k >= self.current_idx - 256}
        self.current_idx += 1
def forward_original(predictor: SAM2Base, images, points = None, labels = None, mask_inputs = None, processing_order = None, device = 'cuda'):
    """Run the stock SAM2 tracking loop over a video.

    Args:
        predictor: SAM2 video predictor.
        images: (B, T, C, H, W) frames; only the first 3 channels are encoded.
        points / labels: optional point prompts for the first processed frame.
        mask_inputs: optional mask prompt for the first processed frame.
            Exactly one of (points+labels) or mask_inputs must be provided.
        processing_order: frame indices to process; defaults to 0..T-1.

    Returns:
        (B, len(processing_order), H, W) tensor of high-res mask logits.
    """
    B, T, _, H, W = images.shape
    point_inputs = {"point_coords": points, "point_labels": labels} if points is not None and labels is not None else None
    assert (mask_inputs is None) ^ (point_inputs is None), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
    processing_order = list(range(images.shape[1])) if processing_order is None else processing_order
    num_frames = len(processing_order)
    pred_maskses = []
    output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
    for idx, frame_idx in enumerate(tqdm(processing_order)):
        (
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = get_image_feature(predictor, images = images[:, frame_idx, :3].to(device))
        # Memory bank: all conditioning frames plus non-conditioning frames
        # whose predicted mask was non-empty, capped at the 24 frames closest
        # to `idx` and re-keyed to consecutive indices just below it.
        memory_dict = {'cond_frame_outputs': output_dict['cond_frame_outputs'], 'non_cond_frame_outputs': {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()}}
        if len(memory_dict['non_cond_frame_outputs']) > 0:
            memory_dict['non_cond_frame_outputs'] = {idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['non_cond_frame_outputs'].items(), key = lambda x: abs(x[0] - idx))[:24])}
        current_out = predictor.track_step(
            frame_idx = idx,
            is_init_cond_frame = idx == 0,
            current_vision_feats = current_vision_feats,
            current_vision_pos_embeds = current_vision_pos_embeds,
            feat_sizes = feat_sizes,
            point_inputs = point_inputs if idx == 0 else None,
            mask_inputs = mask_inputs if idx == 0 else None,
            output_dict = memory_dict,
            num_frames = num_frames,
            track_in_reverse = False,
            run_mem_encoder = False,
            prev_sam_mask_logits = None,
        )
        # Keep the pre-hole-filling prediction for inspection.
        # Fix: key was misspelled 'ppred_masks_high_res_lq' (cf. the matching
        # key used in Tracker.step's disabled refinement code).
        current_out['pred_masks_high_res_lq'] = current_out['pred_masks_high_res']
        current_out["pred_masks"] = fill_holes_in_mask_scores(
            current_out["pred_masks"], predictor.fill_hole_area
        )
        current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
            current_out["pred_masks"],
            size = (predictor.image_size, predictor.image_size),
            mode = "bilinear",
            align_corners = False,
        )
        maskmem_features, maskmem_pos_enc = predictor._encode_new_memory(
            current_vision_feats = current_vision_feats,
            feat_sizes = feat_sizes,
            pred_masks_high_res = current_out["pred_masks_high_res"],
            object_score_logits = current_out['object_score_logits'],
            is_mask_from_pts = True
        )
        current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
        current_out["maskmem_pos_enc"] = maskmem_pos_enc
        pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        output_dict['cond_frame_outputs'if idx == 0 else 'non_cond_frame_outputs'][idx] = current_out
        # Bound stored history to the most recent 256 frames.
        if len(output_dict['non_cond_frame_outputs']) > 256:
            output_dict['non_cond_frame_outputs'] = {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if k >= idx - 256}
    pred_maskses = torch.stack(pred_maskses, dim = 1).squeeze(2)  # (B, T', H, W)
    assert pred_maskses.shape == (B, len(processing_order), H, W)
    return pred_maskses
def get_single_frame_mask(image: torch.Tensor, predictor_original, points, labels, device = 'cuda'):
    """Segment one frame from point/box prompts using the stock predictor.

    points: 1 x N x 2
    labels: 1 x N (positive 1, negative 0, box (top left 2, low right 3))
    """
    # Wrap the single image as a (1, 1, C, H, W) video batch.
    frames = image.squeeze()[None, None]
    return forward_original(
        predictor_original.to(device),
        images = frames,
        points = points,
        labels = labels,
        processing_order = [0],
        device = device
    )
def _run_tracking_pass(tracker, sam2_input_images, first_frame_mask, must3r_feats, explicit_3d, processing_order, output_masks):
    """Run one tracking pass over `processing_order`, writing boolean masks into `output_masks` (resets the tracker state via init)."""
    tracker.init(
        images = sam2_input_images.squeeze()[None],
        processing_order = processing_order,
        mask_inputs = first_frame_mask.squeeze()[None, None] > 0,
        must3r_feats = must3r_feats,
        explicit_3d = explicit_3d
    )
    for frame_idx in tqdm(tracker.processing_order):
        current_out = tracker.step()
        output_masks[frame_idx] = current_out['pred_masks_high_res'].squeeze().cpu().numpy() > 0
        tracker.postprocess(current_out)


def get_tracked_masks(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask, predictor, predictor_original, device = 'cuda'):
    """Track the first-frame mask forward and backward from `start_idx`.

    Two passes with independent tracker state, both seeded by
    `first_frame_mask`: forward over [start_idx, T) and backward over
    [start_idx, 0]. Returns {frame_idx: bool mask array}. Frame `start_idx`
    is tracked in both passes; the backward pass result wins (identical seed).
    """
    tracker = Tracker(predictor, predictor_original = predictor_original, device = device)
    # Explicit 3D input: per-pixel 3D points concatenated with Plücker rays,
    # permuted to (T, C, H, W).
    explicit_3d = torch.cat((must3r_outputs['pts3d'], must3r_outputs['ray_plucker']), dim = -1).permute(0, 3, 1, 2)
    output_masks = {}
    # Forward pass: start_idx .. last frame.
    _run_tracking_pass(tracker, sam2_input_images, first_frame_mask, must3r_feats, explicit_3d,
                       range(start_idx, sam2_input_images.shape[0]), output_masks)
    # Backward pass: start_idx .. 0.
    _run_tracking_pass(tracker, sam2_input_images, first_frame_mask, must3r_feats, explicit_3d,
                       range(start_idx, -1, -1), output_masks)
    return output_masks