"""Glue code fusing MUSt3R 3D features into a SAM2 video predictor for mask tracking.

Pipeline (as visible in this file):
  1. `get_views` / `prepare_sam2_inputs` turn PIL images into MUSt3R views and
     SAM2-normalized 1024x1024 tensors.
  2. `must3r_features_and_output` runs MUSt3R video-SLAM inference and collects
     per-frame decoder features plus pointmaps / Pluecker rays.
  3. `FeatureFusion` cross-attends those 3D features and merges them into the
     SAM2 backbone's lowest-resolution feature map.
  4. `Tracker` (stepwise) and `forward_original` (batch loop) run SAM2's
     `track_step` per frame with a bounded memory bank.
  5. `get_tracked_masks` tracks forward then backward from a start frame.
"""
import spaces
from dust3r.models.blocks import PositionGetter
from dust3r.post_process import estimate_focal_knowing_depth
from must3r.model.blocks.attention import has_xformers, toggle_memory_efficient_attention
# Module-level side effect: enable memory-efficient attention iff xformers is available.
toggle_memory_efficient_attention(enabled = has_xformers)
from hydra import compose
from hydra.utils import instantiate
from sam2.build_sam import build_sam2_video_predictor
from einops import rearrange, repeat
from collections import OrderedDict
import copy
import torch
from tqdm import tqdm
import json
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
import torch.nn.functional as F
from torchvision.transforms import functional as TF
import torchvision.transforms as T
import numpy as np
from torch import nn
from training_utils import load_checkpoint, BatchedVideoDatapoint, positional_encoding, postprocess_must3r_output
from sam2.modeling.sam2_utils import LayerNorm2d
import os

# Input resolution (long side) used for the MUSt3R 512 model throughout this file.
MUST3R_SIZE = 512


def get_views(pil_imgs):
    """Load PIL images into MUSt3R 'views' at MUST3R_SIZE with 16px patches.

    Returns (views, resize_funcs) exactly as produced by the project-local
    `data.load_images` helper.
    """
    ## pil_imgs = a list of PIL Image
    from data import load_images
    views, resize_funcs = load_images(pil_imgs, size = MUST3R_SIZE, patch_size = 16)
    return views, resize_funcs


def prepare_sam2_inputs(views, pil_imgs, resize_funcs):
    """Build SAM2 input tensors from the original PIL images.

    Applies the first transform of `resize_funcs[0]` to the stacked raw images,
    then resizes to 1024x1024 and normalizes with ImageNet statistics.
    Returns (sam2_input_images, images) where `images` is the pre-normalization
    resized stack.  NOTE(review): `views` is unused here — presumably kept for
    signature symmetry; confirm before removing.
    """
    image_transform = T.Compose([
        T.Resize((1024, 1024), interpolation = T.InterpolationMode.BILINEAR),
        T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
    ])
    images = resize_funcs[0].transforms[0](torch.stack([TF.to_tensor(p) for p in pil_imgs], dim = 0).cpu())
    sam2_input_images = image_transform(images)  # normalize to [0, 1] range and then normalize with ImageNet stats
    return sam2_input_images, images


@torch.no_grad()
def must3r_features_and_output(views, device = 'cuda'):
    """Run MUSt3R video-SLAM inference over `views` and collect features/outputs.

    Loads the MUSt3R 512 checkpoint, runs `must3r_inference_video` in SLAM
    keyframe mode, then:
      - `must3r_feats`: decoder features from layers [0, 4, 7, 11] of each
        frame, reshaped to (b, c, h, w) patch grids (patch size 16), on CPU.
      - `must3r_outputs`: dict with 'pts3d' pointmaps and 'ray_plucker' rays,
        postprocessed frame-by-frame with NORM_EXP activation, on CPU.

    Side effects: attaches `recorded_feats` / `all_feats` lists to the decoder
    module (`model[1]`), which the inference call is expected to populate —
    TODO confirm that hook exists in the must3r decoder.
    """
    import functools
    from must3r.model import load_model, get_pointmaps_activation
    from must3r.demo.gradio import get_args_parser, main_demo, get_reconstructed_scene
    from must3r.demo.inference import must3r_inference_video, slam_is_keyframe, slam_update_scene_state, must3r_inference
    from must3r.slam.model import get_searcher
    from must3r.model import ActivationType
    # NOTE(review): this shadows `get_pointmaps_activation` imported from
    # must3r.model above — presumably the same function re-exported; confirm.
    from must3r.demo.inference import get_pointmaps_activation
    from must3r.tools.geometry import apply_exp_to_norm
    # Reuse the demo's CLI parser to build a fully-populated args namespace.
    cmd_params = ["--weights", "./private/MUSt3R_512.pth", "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth", "--image_size", "512", "--amp", "bf16", "--viser", "--allow_local_files", "--device", device]
    parser = get_args_parser()
    args = parser.parse_args(cmd_params)
    weights_path = args.weights
    # model is a (encoder, decoder) pair; both switched to eval mode.
    model = load_model(weights_path, encoder=args.encoder, decoder=args.decoder, device=args.device, img_size=args.image_size, memory_mode=args.memory_mode, verbose=args.verbose)
    model = [m.eval() for m in model]
    assert model[0].patch_size == 16
    assert get_pointmaps_activation(model[1]) == ActivationType.NORM_EXP
    # model_224 = load_model("./private/MUSt3R_224.pth", encoder=args.encoder, decoder=args.decoder, device=args.device,
    #                        img_size = 224, memory_mode=args.memory_mode, verbose=args.verbose)
    # model_224 = [m.eval() for m in model_224]
    # assert get_pointmaps_activation(model_224[1]) == ActivationType.NORM_EXP
    # Inference configuration (mirrors the demo defaults); several of these
    # locals are unused below — kept as written.
    retrieval = "./private/MUSt3R_512_retrieval_trainingfree.pth"
    retrieval_224 = "./private/MUSt3R_224_retrieval_trainingfree.pth"
    verbose = False
    image_size = 512
    image_size_224 = 224
    amp = "bf16"
    amp_224 = "fp16"
    max_bs = 1
    num_refinements_iterations = 0
    execution_mode = "vidslam"
    num_mem_images = 0
    render_once = False
    vidseq_local_context_size = 0
    keyframe_interval = 0
    slam_local_context_size = 0
    subsample = 2
    min_conf_keyframe = 1.5
    keyframe_overlap_thr = 0.05
    overlap_percentile = 85
    min_conf_thr = 3
    as_pointcloud = True
    transparent_cams = False
    local_pointmaps = False
    cam_size = 0.05
    camera_conf_thr = 1.5
    local_context_size = slam_local_context_size
    overlap_mode = "nn-norm"
    assert MUST3R_SIZE == 512
    # Feature-recording hooks: inference appends per-frame decoder features here.
    model[1].recorded_feats = []
    model[1].all_feats = []
    is_keyframe_function = functools.partial(slam_is_keyframe, subsample, min_conf_keyframe, keyframe_overlap_thr, overlap_percentile, overlap_mode)
    scene_state = get_searcher("kdtree-scipy-quadrant_x2")
    scene_state_update_function = functools.partial(slam_update_scene_state, subsample, min_conf_keyframe)
    must3r_inference_video((model), device, image_size, amp, filelist = None, max_bs = max_bs, init_num_images = 2, batch_num_views = 1, viser_server = None, num_refinements_iterations = num_refinements_iterations, local_context_size = local_context_size, is_keyframe_function = is_keyframe_function, scene_state = scene_state, scene_state_update_function = scene_state_update_function, verbose = True, views = views)
    must3r_feats = torch.cat(model[1].recorded_feats, dim = 0).to(device)
    # Re-run the prediction head on the recorded features (norm=False) to get
    # raw per-frame outputs; true_shape supplies each view's (H, W).
    must3r_outputs = model[1]._compute_prediction_head(
        torch.stack([torch.from_numpy(view['true_shape']).squeeze() for view in views]).to(device)[:, None],
        len(views), 1, [must3r_feats], norm = False
    ).squeeze()
    # Keep decoder layers 0, 4, 7, 11 per frame, then concat per layer across frames.
    must3r_feats = [[f[0], f[4], f[7], f[11]] for f in model[1].all_feats]
    must3r_feats = [torch.cat(f, dim = 0).to(device) for f in zip(*must3r_feats)]
    from einops import rearrange
    # Token sequence -> (b, c, h, w) patch grid; h/w derived from true_shape / patch 16.
    must3r_feats = [
        rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16).cpu()
        for f in must3r_feats
    ]
    from training_utils import load_checkpoint, BatchedVideoDatapoint, positional_encoding, postprocess_must3r_output
    from must3r.model import ActivationType, apply_activation
    # must3r_outputs = postprocess_must3r_output(must3r_outputs, pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True)
    # Postprocess frame-by-frame on CPU (presumably to bound GPU memory — confirm).
    must3r_output_all = []
    for f in tqdm(must3r_outputs):
        must3r_output_all.append(postprocess_must3r_output(f.cpu()[None], pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True))
    must3r_outputs = {'pts3d': torch.cat([c['pts3d'] for c in must3r_output_all], dim = 0).squeeze(), 'ray_plucker': torch.cat([c['ray_plucker'] for c in must3r_output_all], dim = 0).squeeze()}
    must3r_outputs = {k: v.cpu() for k, v in must3r_outputs.items()}
    return must3r_feats, must3r_outputs


class FeatureFusion(nn.Module):
    """Fuses MUSt3R 3D decoder features into a SAM2 2D feature map.

    Four cross-attention blocks (deep copies of MUSt3R decoder blocks) refine
    the 3D features, conditioned on an explicit 3D embedding built from
    positional-encoded pointmaps plus 6 Pluecker-ray channels; the result is
    projected to the 2D channel count and merged with the 2D features by a
    1x1 conv.
    """

    def __init__(self, cross_attn_blocks_3d, in_channels_2d = 256, in_channels_3d = 768):
        """
        Args:
            cross_attn_blocks_3d: iterable of attention blocks to deep-copy
                (one refinement stage per block; forward asserts 4 levels).
            in_channels_2d: channel count of the SAM2 feature map to merge with.
            in_channels_3d: nominal 3D feature channels (unused directly here;
                the 1024->768 Linear below hard-codes the projection).
        """
        super().__init__()
        from einops.layers.torch import Rearrange
        import copy
        # Number of sin/cos frequency bands for the pointmap positional encoding.
        self.freqs = 6
        self.position_getter = PositionGetter()
        # Separate refinement stacks for the 224 and 512 MUSt3R variants; each
        # ends with a Linear(1024 -> 768) applied to the level-0 features.
        self.feat_conv_3d_224 = nn.ModuleList([
            copy.deepcopy(block) for block in cross_attn_blocks_3d
        ] + [nn.Linear(in_features = 1024, out_features = 768)])
        self.feat_conv_3d_512 = nn.ModuleList([
            copy.deepcopy(block) for block in cross_attn_blocks_3d
        ] + [nn.Linear(in_features = 1024, out_features = 768)])
        self.out = nn.Conv2d(in_channels = 768, out_channels = in_channels_2d, kernel_size = 3, padding = 1)
        # Merges concatenated [refined-3D, original-2D] back to in_channels_2d.
        self.merge = nn.Conv2d(in_channels = in_channels_2d * 2, out_channels = in_channels_2d, kernel_size = 1, padding = 0)
        # Patch-embeds the explicit 3D input: 3 xyz channels expanded by the
        # positional encoding (2*freqs + 1 terms each) plus 6 Pluecker channels.
        self.explicit_3d_embedding = nn.Conv2d(in_channels = 3 * (2 * self.freqs + 1) + 6, out_channels = 768, kernel_size = 16, padding = 0, stride = 16)

    def forward(self, feat_2d, feat_3d, explicit_3d = None, must3r_size = 512):
        """Return `feat_2d` merged with refined 3D features (same spatial size).

        Args:
            feat_2d: SAM2 feature map, (B, in_channels_2d, H, W).
            feat_3d: list of exactly 4 feature maps, each (B, C, h, w).
            explicit_3d: (B, 9, H', W') pointmap (xyz) + Pluecker-ray channels —
                assumed layout; TODO confirm against callers.
            must3r_size: 224 or 512, selects the refinement stack.
        """
        refinenets_3d = self.feat_conv_3d_224 if must3r_size == 224 else self.feat_conv_3d_512
        assert len(feat_3d) == 4, f'Expected 4 levels of 3D features, got {len(feat_3d)}'
        # Positional-encode xyz, keep ray channels as-is, then patch-embed.
        explicit_3d = torch.cat((positional_encoding(explicit_3d[:, :3], self.freqs, dim = 1), explicit_3d[:, 3:]), dim = 1)
        explicit_3d = self.explicit_3d_embedding(explicit_3d)
        pe_3d = rearrange(explicit_3d, 'b c h w -> b (h w) c')
        B = pe_3d.shape[0]
        # Current implementation only supports a single frame per call.
        assert B == 1
        pe_2d = self.position_getter(B, explicit_3d.shape[2], explicit_3d.shape[3], device = explicit_3d.device)
        feat_3d = [rearrange(f, 'b c h w -> b (h w) c') for f in feat_3d]
        N = feat_3d[0].shape[1]
        # Causal-style cross-attention mask over the B*N flattened keys:
        # batch element i may attend to key blocks 0..i (False = attend).
        ca_attn_mask = torch.ones((B, 1, N, N * B), dtype = torch.bool, device = feat_3d[0].device)
        for i in range(B):
            ca_attn_mask[i, :, :, :(i + 1) * N] = False
        feat_3d_post = feat_3d[0]
        for i in range(len(feat_3d)):
            if i == 0:
                # Level 0: project 1024 -> 768 via the trailing Linear, add the
                # 3D positional embedding, then self-attend (y = x).
                feat_3d_post = refinenets_3d[-1](feat_3d_post) + pe_3d
                feat_3d_post = refinenets_3d[i](x = feat_3d_post, y = feat_3d_post, xpos = pe_2d)
            else:
                # Deeper levels: cross-attend against the i-th feature level.
                feat_3d_post = refinenets_3d[i](x = feat_3d_post + pe_3d, y = repeat(feat_3d[i] + pe_3d, 'b n c -> k (b n) c', k = B), xpos = pe_2d, ca_attn_mask = ca_attn_mask)
        # Back to a spatial map, upsample to feat_2d's resolution, project channels.
        feat_3d_post = self.out(F.interpolate(rearrange(feat_3d_post, 'b (h w) c -> b c h w', b = B, h = explicit_3d.shape[2], w = explicit_3d.shape[3]), size = feat_2d.shape[-2:], mode = 'bilinear', align_corners = False))
        feat_merged = self.merge(torch.cat([feat_3d_post, feat_2d], dim = 1))
        return feat_merged


def get_must3r_cross_attn_layers(device = 'cuda'):
    """Load the MUSt3R 512 model and return its decoder attention blocks."""
    from must3r.model import load_model
    from must3r.demo.gradio import get_args_parser
    cmd_params = ["--weights", "./private/MUSt3R_512.pth", "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth", "--image_size", "512", "--amp", "bf16", "--viser", "--allow_local_files", "--device", device]
    parser = get_args_parser()
    args = parser.parse_args(cmd_params)
    model = load_model(args.weights, encoder=args.encoder, decoder=args.decoder, device=args.device, img_size=args.image_size, memory_mode=args.memory_mode, verbose=args.verbose)
    return model[1].blocks_dec


def get_predictors(device = 'cuda'):
    """Build the baseline SAM2 predictor and the 3D-fused predictor.

    The fused predictor gets a `FeatureFusion` built from MUSt3R decoder blocks
    [0, 4, 7, 11] and is restored from a fine-tuned checkpoint.  Both are
    returned moved to CPU (callers move them to a device later).
    """
    predictor_original = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l.yaml", "./sam2-src/checkpoints/sam2.1_hiera_large.pt").to(device).eval()
    # 3D variant is built without pretrained weights; the checkpoint below supplies them.
    predictor = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l_3d.yaml").to(device).eval()
    cross_attn_blocks_3d = get_must3r_cross_attn_layers(device = device)
    predictor.fusion_3d = FeatureFusion(cross_attn_blocks_3d = [copy.deepcopy(cross_attn_blocks_3d[i]) for i in [0, 4, 7, 11]])
    predictor = load_checkpoint(predictor, torch.load('./private/sam2.1-must3r-fixed-vision-v1-decomp-standalone-regional-best-2.7851.pt', map_location = 'cpu'))
    return predictor_original.cpu(), predictor.cpu()


@torch.no_grad()
def get_image_feature(
    predictor: SAM2Base,
    images: torch.Tensor,
    device = 'cuda'
):
    """Run the SAM2 image backbone and prepare its FPN features.

    Returns (backbone_out, vision_feats, vision_pos_embeds, feat_sizes) as
    produced by `predictor._prepare_backbone_features`.  `device` is unused —
    the caller is expected to have moved `images` already.
    """
    backbone_out = predictor.forward_image(images)
    # Shallow-copy the lists so _prepare_backbone_features cannot mutate the
    # predictor's cached output in place.
    backbone_out = {
        "backbone_fpn": backbone_out["backbone_fpn"].copy(),
        "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
    }
    backbone_out, vision_feats, vision_pos_embeds, feat_sizes = predictor._prepare_backbone_features(backbone_out)
    return backbone_out, vision_feats, vision_pos_embeds, feat_sizes


class Tracker(nn.Module):
    """Stepwise SAM2 mask tracker with optional MUSt3R 3D feature fusion.

    Usage: `init(...)` to set up a sequence, then alternate `step()` (predict a
    mask for the next frame in `processing_order`) and `postprocess(current_out)`
    (encode the frame into the memory bank and advance).
    """

    def __init__(self, predictor, predictor_original = None, device = 'cuda'):
        super().__init__()
        self.predictor = predictor.to(device)
        # Optional unmodified SAM2 predictor (only used by commented-out
        # refinement code in step()).
        self.predictor_original = predictor_original.to(device) if predictor_original is not None else None
        self.device = device

    def init(self, images, processing_order, points = None, labels = None, mask_inputs = None, must3r_feats = None, explicit_3d = None, image_features = None):
        """Reset tracker state for a new sequence.

        Args:
            images: (B, T, C, H, W) frame tensor — assumed layout based on the
                indexing in step(); TODO confirm.
            processing_order: sequence of frame indices to visit.
            points/labels: optional first-frame point prompt.
            mask_inputs: optional first-frame mask prompt.
            must3r_feats: per-layer lists of per-frame 3D features (or None to
                disable fusion).
            explicit_3d: per-frame pointmap+ray tensor for FeatureFusion.
            image_features: unused cache slot, stored as-is.
        """
        self.images = images
        self.point_inputs = {"point_coords": points.to(self.device), "point_labels": labels.to(self.device)} if points is not None and labels is not None else None
        self.mask_inputs = mask_inputs
        self.processing_order = processing_order
        self.output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
        self.pred_maskses = []
        self.num_frames = len(processing_order)
        self.current_idx = 0
        self.image_features = image_features
        self.must3r_feats = must3r_feats
        self.explicit_3d = explicit_3d

    @torch.no_grad()
    @torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
    def step(self, mask_inputs = None, point_inputs = None):
        """Predict the mask for the current frame; does NOT update memory.

        Per-call `mask_inputs`/`point_inputs` allow injecting prompts on
        non-initial frames; the init-time prompts are used on frame 0.
        Returns SAM2's `current_out` dict with filled-hole `pred_masks` and
        upsampled `pred_masks_high_res` added.
        """
        # NOTE(review): this assert forbids per-call prompts on the FIRST frame
        # (current_idx == 0) — init-time prompts are used there instead; confirm
        # that is the intended contract.
        assert (mask_inputs is None or self.current_idx > 0) and (point_inputs is None or self.current_idx > 0), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
        frame_idx = self.processing_order[self.current_idx]
        (
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = get_image_feature(self.predictor, images = self.images[:, frame_idx, :3].to(self.device))
        if self.must3r_feats is not None:
            # Lowest-resolution SAM2 feature map is 64x64 tokens here.
            feat_2d_original = rearrange(current_vision_feats[-1], '(x y) b c -> b c x y', x = 64, y = 64)
            feat_2d = self.predictor.fusion_3d(
                feat_2d = feat_2d_original,
                feat_3d = [f[frame_idx].to(self.device).squeeze()[None] for f in self.must3r_feats],
                explicit_3d = self.explicit_3d[frame_idx].to(self.device).squeeze()[None],
                # 14x14 patch grid implies the 224 MUSt3R variant; otherwise 512.
                must3r_size = 224 if self.must3r_feats[0][frame_idx].shape[-1] == 14 and self.must3r_feats[0][frame_idx].shape[-2] == 14 else 512
            )
            current_vision_feats[-1] = rearrange(feat_2d, 'b c h w -> (h w) b c')
            # Sanity check that fusion actually modified the features.
            assert not torch.allclose(feat_2d.float(), feat_2d_original.float(), 1e-4), 'Feature fusion did not change features'
        # Stashed for postprocess()'s memory encoding.
        self.current_vision_feats = current_vision_feats
        self.feat_sizes = feat_sizes
        # Memory bank for track_step: all conditioning frames, plus non-cond
        # frames whose mask is non-empty, always including the immediately
        # previous frame (walrus: only if it exists and is truthy).
        memory_dict = {
            'cond_frame_outputs': self.output_dict['cond_frame_outputs'],
            'non_cond_frame_outputs': {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()} | ({self.current_idx - 1: d} if (d := self.output_dict['non_cond_frame_outputs'].get(self.current_idx - 1)) else {})
        }
        # Cap each bank at the 32 frames nearest to current_idx; entries are
        # re-keyed to current_idx - i (relative recency), not original indices.
        if len(memory_dict['non_cond_frame_outputs']) > 32:
            memory_dict['non_cond_frame_outputs'] = {self.current_idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['non_cond_frame_outputs'].items(), key = lambda x: abs(x[0] - self.current_idx))[:32])}
        if len(memory_dict['cond_frame_outputs']) > 32:
            memory_dict['cond_frame_outputs'] = {self.current_idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['cond_frame_outputs'].items(), key = lambda x: abs(x[0] - self.current_idx))[:32])}
        current_out = self.predictor.track_step(
            frame_idx = self.current_idx,
            is_init_cond_frame = self.current_idx == 0,
            current_vision_feats = current_vision_feats,
            current_vision_pos_embeds = current_vision_pos_embeds,
            feat_sizes = feat_sizes,
            point_inputs = self.point_inputs if self.current_idx == 0 else point_inputs,
            mask_inputs = self.mask_inputs.to(self.device) if self.current_idx == 0 else mask_inputs,
            output_dict = memory_dict,
            num_frames = self.num_frames,
            track_in_reverse = False,
            run_mem_encoder = False,  # memory is encoded later, in postprocess()
            prev_sam_mask_logits = None,
        )
        current_out["pred_masks"] = fill_holes_in_mask_scores(
            current_out["pred_masks"], self.predictor.fill_hole_area
        )
        current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
            current_out["pred_masks"],
            size = (self.predictor.image_size, self.predictor.image_size),
            mode = "bilinear",
            align_corners = False,
        )
        # Disabled refinement path: re-run the ORIGINAL predictor with the fused
        # prediction as a mask prompt (kept for reference).
        # if self.predictor_original is not None and self.current_idx != 0:
        #     current_out['pred_masks_high_res_lq'] = current_out['pred_masks_high_res'].clone()
        #     self.predictor_original.use_mask_input_as_output_without_sam = False
        #     current_vision_feats_original = current_vision_feats.copy()
        #     current_vision_feats_original[-1] = rearrange(feat_2d_original, 'b c h w -> (h w) b c')
        #     current_out_original = self.predictor_original.track_step(
        #         frame_idx = 0,
        #         is_init_cond_frame = True,
        #         current_vision_feats = current_vision_feats_original,
        #         current_vision_pos_embeds = current_vision_pos_embeds,
        #         feat_sizes = feat_sizes,
        #         point_inputs = None,
        #         mask_inputs = current_out["pred_masks_high_res"].to(self.device).squeeze()[None, None],
        #         output_dict = {},
        #         num_frames = self.num_frames,
        #         track_in_reverse = False,
        #         run_mem_encoder = False,
        #         prev_sam_mask_logits = None,
        #     )
        #
        #     if (current_out['pred_masks_high_res'] > 0).sum() > 0: assert (current_out_original['pred_masks'] > 0).sum() > 0, 'Original predictor produced empty mask'
        #     current_out["pred_masks"] = fill_holes_in_mask_scores(
        #         current_out_original["pred_masks"], self.predictor.fill_hole_area
        #     )
        #     current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
        #         current_out["pred_masks"],
        #         size = (self.predictor.image_size, self.predictor.image_size),
        #         mode = "bilinear",
        #         align_corners = False,
        #     )
        return current_out

    @torch.no_grad()
    @torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
    def postprocess(self, current_out):
        """Encode `current_out` into the memory bank and advance to the next frame.

        Memory features are stored in bfloat16; both banks are pruned to the
        last 256 frames.  Frame 0 goes into `cond_frame_outputs`, all later
        frames into `non_cond_frame_outputs`.
        """
        maskmem_features, maskmem_pos_enc = self.predictor._encode_new_memory(
            current_vision_feats = self.current_vision_feats,
            feat_sizes = self.feat_sizes,
            pred_masks_high_res = current_out["pred_masks_high_res"],
            object_score_logits = current_out['object_score_logits'],
            is_mask_from_pts = False
        )
        current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
        current_out["maskmem_pos_enc"] = maskmem_pos_enc
        self.pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        self.output_dict['cond_frame_outputs'if self.current_idx == 0 else 'non_cond_frame_outputs'][self.current_idx] = current_out
        if len(self.output_dict['non_cond_frame_outputs']) > 256:
            self.output_dict['non_cond_frame_outputs'] = {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if k >= self.current_idx - 256}
        if len(self.output_dict['cond_frame_outputs']) > 256:
            self.output_dict['cond_frame_outputs'] = {k: v for k, v in self.output_dict['cond_frame_outputs'].items() if k >= self.current_idx - 256}
        self.current_idx += 1


@torch.no_grad()
@torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
def forward_original(predictor: SAM2Base, images, points = None, labels = None, mask_inputs = None, processing_order = None, device = 'cuda'):
    """Run plain SAM2 tracking over a frame sequence (no 3D fusion).

    Exactly one of (points+labels) or mask_inputs must be given; the prompt is
    applied to the first frame of `processing_order` only.  Returns the stacked
    high-res mask logits of shape (B, T, H, W).
    """
    B, T, _, H, W = images.shape
    point_inputs = {"point_coords": points, "point_labels": labels} if points is not None and labels is not None else None
    assert (mask_inputs is None) ^ (point_inputs is None), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
    processing_order = list(range(images.shape[1])) if processing_order is None else processing_order
    num_frames = len(processing_order)
    pred_maskses = []
    ious = []
    output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
    for idx, frame_idx in enumerate(tqdm(processing_order)):
        (
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = get_image_feature(predictor, images = images[:, frame_idx, :3].to(device))
        # Memory: all conditioning frames plus non-empty non-cond frames,
        # pruned to the 24 nearest (re-keyed to idx - i, relative recency).
        memory_dict = {'cond_frame_outputs': output_dict['cond_frame_outputs'], 'non_cond_frame_outputs': {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()}}
        if len(memory_dict['non_cond_frame_outputs']) > 0:
            memory_dict['non_cond_frame_outputs'] = {idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['non_cond_frame_outputs'].items(), key = lambda x: abs(x[0] - idx))[:24])}
        current_out = predictor.track_step(
            frame_idx = idx,
            is_init_cond_frame = idx == 0,
            current_vision_feats = current_vision_feats,
            current_vision_pos_embeds = current_vision_pos_embeds,
            feat_sizes = feat_sizes,
            point_inputs = point_inputs if idx == 0 else None,
            mask_inputs = mask_inputs if idx == 0 else None,
            output_dict = memory_dict,
            num_frames = num_frames,
            track_in_reverse = False,
            run_mem_encoder = False,
            prev_sam_mask_logits = None,
        )
        # NOTE(review): key 'ppred_masks_high_res_lq' looks like a typo for
        # 'pred_masks_high_res_lq' (cf. the commented code in Tracker.step) —
        # confirm no reader depends on this spelling before fixing.
        current_out['ppred_masks_high_res_lq'] = current_out['pred_masks_high_res']
        current_out["pred_masks"] = fill_holes_in_mask_scores(
            current_out["pred_masks"], predictor.fill_hole_area
        )
        current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
            current_out["pred_masks"],
            size = (predictor.image_size, predictor.image_size),
            mode = "bilinear",
            align_corners = False,
        )
        maskmem_features, maskmem_pos_enc = predictor._encode_new_memory(
            current_vision_feats = current_vision_feats,
            feat_sizes = feat_sizes,
            pred_masks_high_res = current_out["pred_masks_high_res"],
            object_score_logits = current_out['object_score_logits'],
            is_mask_from_pts = True
        )
        current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
        current_out["maskmem_pos_enc"] = maskmem_pos_enc
        pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        output_dict['cond_frame_outputs'if idx == 0 else 'non_cond_frame_outputs'][idx] = current_out
        if len(output_dict['non_cond_frame_outputs']) > 256:
            output_dict['non_cond_frame_outputs'] = {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if k >= idx - 256}
    pred_maskses = torch.stack(pred_maskses, dim = 1).squeeze(2)  # (B, T, H, W)
    assert pred_maskses.shape == (B, len(processing_order), H, W)
    return pred_maskses


@torch.no_grad()
def get_single_frame_mask(image: torch.Tensor, predictor_original, points, labels, device = 'cuda'):
    '''
    Predict a mask for a single image from point/box prompts.

    points: 1 x N x 2
    labels: 1 x N (positive 1, negative 0, box (top left 2, low right 3))
    '''
    return forward_original(
        predictor_original.to(device),
        images = image.squeeze()[None, None],
        points = points,
        labels = labels,
        processing_order = [0],
        device = device
    )


@torch.no_grad()
def get_tracked_masks(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask, predictor, predictor_original, device = 'cuda'):
    """Track `first_frame_mask` forward then backward from `start_idx`.

    Runs two Tracker passes over all frames: start_idx..end, then
    start_idx..0 (the backward pass re-processes start_idx and overwrites its
    entry).  `explicit_3d` concatenates pointmaps and Pluecker rays per frame.
    Returns {frame_idx: bool ndarray mask}.
    """
    tracker = Tracker(predictor, predictor_original = predictor_original, device = device)
    tracker.init(
        images = sam2_input_images.squeeze()[None],
        processing_order = range(start_idx, sam2_input_images.shape[0]),
        mask_inputs = first_frame_mask.squeeze()[None, None] > 0,
        must3r_feats = must3r_feats,
        explicit_3d = torch.cat((must3r_outputs['pts3d'], must3r_outputs['ray_plucker']), dim = -1).permute(0, 3, 1, 2)
    )
    output_masks = {}
    # Forward pass from the prompted frame.
    for idx, frame_idx in enumerate(tqdm(tracker.processing_order)):
        current_out = tracker.step()
        output_masks[frame_idx] = current_out['pred_masks_high_res'].squeeze().cpu().numpy() > 0
        tracker.postprocess(current_out)
    # Backward pass: re-init with reversed order down to frame 0.
    tracker.init(
        images = sam2_input_images.squeeze()[None],
        processing_order = range(start_idx, -1, -1),
        mask_inputs = first_frame_mask.squeeze()[None, None] > 0,
        must3r_feats = must3r_feats,
        explicit_3d = torch.cat((must3r_outputs['pts3d'], must3r_outputs['ray_plucker']), dim = -1).permute(0, 3, 1, 2)
    )
    for idx, frame_idx in enumerate(tqdm(tracker.processing_order)):
        current_out = tracker.step()
        output_masks[frame_idx] = current_out['pred_masks_high_res'].squeeze().cpu().numpy() > 0
        tracker.postprocess(current_out)
    return output_masks