# 3AM / engine.py
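"""Tracking engine that fuses MUSt3R 3D features into a SAM2 video predictor.

Rough flow (a usage sketch is at the bottom of this file): load frames with get_views /
prepare_sam2_inputs, extract 3D features with must3r_features_and_output, build the fused
predictor with get_predictors, then run get_tracked_masks (which drives the Tracker class).
"""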
import spaces
from dust3r.models.blocks import PositionGetter
from dust3r.post_process import estimate_focal_knowing_depth
from must3r.model.blocks.attention import has_xformers, toggle_memory_efficient_attention
toggle_memory_efficient_attention(enabled = has_xformers)
from hydra import compose
from hydra.utils import instantiate
from sam2.build_sam import build_sam2_video_predictor
from einops import rearrange, repeat
from collections import OrderedDict
import copy
import torch
from tqdm import tqdm
import json
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
import torch.nn.functional as F
from torchvision.transforms import functional as TF
import torchvision.transforms as T
import numpy as np
from torch import nn
from training_utils import load_checkpoint, BatchedVideoDatapoint, positional_encoding, postprocess_must3r_output
from sam2.modeling.sam2_utils import LayerNorm2d
import os
MUST3R_SIZE = 512
def get_views(pil_imgs):
    """Load a list of PIL Images as MUSt3R views; also return the resize transforms that were applied."""
from data import load_images
views, resize_funcs = load_images(pil_imgs, size = MUST3R_SIZE, patch_size = 16)
return views, resize_funcs
def prepare_sam2_inputs(views, pil_imgs, resize_funcs):
image_transform = T.Compose([
T.Resize((1024, 1024), interpolation = T.InterpolationMode.BILINEAR),
T.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
])
    # Stack the frames (to_tensor already scales to [0, 1]) and apply the first resize transform returned by load_images.
    images = resize_funcs[0].transforms[0](torch.stack([TF.to_tensor(p) for p in pil_imgs], dim = 0).cpu())
    sam2_input_images = image_transform(images)  # resize to 1024x1024 and apply ImageNet normalization
return sam2_input_images, images
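# Usage sketch (assumes `pil_imgs` is a list of PIL Images):
#   views, resize_funcs = get_views(pil_imgs)
#   sam2_input_images, images = prepare_sam2_inputs(views, pil_imgs, resize_funcs)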
@torch.no_grad()
def must3r_features_and_output(views, device = 'cuda'):
import functools
from must3r.model import load_model, get_pointmaps_activation
from must3r.demo.gradio import get_args_parser, main_demo, get_reconstructed_scene
from must3r.demo.inference import must3r_inference_video, slam_is_keyframe, slam_update_scene_state, must3r_inference
from must3r.slam.model import get_searcher
from must3r.model import ActivationType
from must3r.demo.inference import get_pointmaps_activation
from must3r.tools.geometry import apply_exp_to_norm
cmd_params = ["--weights", "./private/MUSt3R_512.pth", "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth", "--image_size", "512", "--amp", "bf16", "--viser", "--allow_local_files", "--device", device]
parser = get_args_parser()
args = parser.parse_args(cmd_params)
weights_path = args.weights
model = load_model(weights_path, encoder=args.encoder, decoder=args.decoder, device=args.device,
img_size=args.image_size, memory_mode=args.memory_mode, verbose=args.verbose)
model = [m.eval() for m in model]
assert model[0].patch_size == 16
assert get_pointmaps_activation(model[1]) == ActivationType.NORM_EXP
# model_224 = load_model("./private/MUSt3R_224.pth", encoder=args.encoder, decoder=args.decoder, device=args.device,
# img_size = 224, memory_mode=args.memory_mode, verbose=args.verbose)
# model_224 = [m.eval() for m in model_224]
# assert get_pointmaps_activation(model_224[1]) == ActivationType.NORM_EXP
retrieval = "./private/MUSt3R_512_retrieval_trainingfree.pth"
retrieval_224 = "./private/MUSt3R_224_retrieval_trainingfree.pth"
verbose = False
image_size = 512
image_size_224 = 224
amp = "bf16"
amp_224 = "fp16"
max_bs = 1
num_refinements_iterations = 0
execution_mode = "vidslam"
num_mem_images = 0
render_once = False
vidseq_local_context_size = 0
keyframe_interval = 0
slam_local_context_size = 0
subsample = 2
min_conf_keyframe = 1.5
keyframe_overlap_thr = 0.05
overlap_percentile = 85
min_conf_thr = 3
as_pointcloud = True
transparent_cams = False
local_pointmaps = False
cam_size = 0.05
camera_conf_thr = 1.5
local_context_size = slam_local_context_size
overlap_mode = "nn-norm"
assert MUST3R_SIZE == 512
    # Cleared here; filled during inference and read back after must3r_inference_video below.
    model[1].recorded_feats = []
    model[1].all_feats = []
is_keyframe_function = functools.partial(slam_is_keyframe, subsample, min_conf_keyframe, keyframe_overlap_thr, overlap_percentile, overlap_mode)
scene_state = get_searcher("kdtree-scipy-quadrant_x2")
scene_state_update_function = functools.partial(slam_update_scene_state, subsample, min_conf_keyframe)
    must3r_inference_video(model, device, image_size, amp, filelist = None, max_bs = max_bs, init_num_images = 2, batch_num_views = 1,
viser_server = None, num_refinements_iterations = num_refinements_iterations,
local_context_size = local_context_size, is_keyframe_function = is_keyframe_function,
scene_state = scene_state, scene_state_update_function = scene_state_update_function,
verbose = True, views = views)
    # Re-run the prediction head on the recorded decoder features to get per-view outputs.
    must3r_feats = torch.cat(model[1].recorded_feats, dim = 0).to(device)
    must3r_outputs = model[1]._compute_prediction_head(
torch.stack([torch.from_numpy(view['true_shape']).squeeze() for view in views]).to(device)[:, None],
len(views),
1,
[must3r_feats],
norm = False
).squeeze()
    # Keep four feature levels (indices 0, 4, 7, 11) per frame for the fusion module.
    must3r_feats = [[f[0], f[4], f[7], f[11]] for f in model[1].all_feats]
    must3r_feats = [torch.cat(f, dim = 0).to(device) for f in zip(*must3r_feats)]
must3r_feats = [
rearrange(f, 'b (h w) c -> b c h w', h = views[0]['true_shape'][0] // 16, w = views[0]['true_shape'][1] // 16).cpu()
for f in must3r_feats
]
from must3r.model import ActivationType, apply_activation
# must3r_outputs = postprocess_must3r_output(must3r_outputs, pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True)
must3r_output_all = []
for f in tqdm(must3r_outputs):
must3r_output_all.append(postprocess_must3r_output(f.cpu()[None], pointmaps_activation = ActivationType.NORM_EXP, compute_cam = True))
must3r_outputs = {'pts3d': torch.cat([c['pts3d'] for c in must3r_output_all], dim = 0).squeeze(),
'ray_plucker': torch.cat([c['ray_plucker'] for c in must3r_output_all], dim = 0).squeeze()}
must3r_outputs = {k: v.cpu() for k, v in must3r_outputs.items()}
return must3r_feats, must3r_outputs
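# Sketch of the expected outputs (based on the code above; shapes are not guaranteed):
#   must3r_feats, must3r_outputs = must3r_features_and_output(views)
#   # must3r_feats   : list of 4 per-frame feature maps (levels 0, 4, 7, 11), laid out as (T, C, H/16, W/16)
#   # must3r_outputs : {'pts3d': (T, H, W, 3), 'ray_plucker': (T, H, W, 6)} on CPU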
class FeatureFusion(nn.Module):
def __init__(self, cross_attn_blocks_3d, in_channels_2d = 256, in_channels_3d = 768):
super().__init__()
from einops.layers.torch import Rearrange
import copy
self.freqs = 6
self.position_getter = PositionGetter()
self.feat_conv_3d_224 = nn.ModuleList([
copy.deepcopy(block) for block in cross_attn_blocks_3d
] + [nn.Linear(in_features = 1024, out_features = 768)])
self.feat_conv_3d_512 = nn.ModuleList([
copy.deepcopy(block) for block in cross_attn_blocks_3d
] + [nn.Linear(in_features = 1024, out_features = 768)])
self.out = nn.Conv2d(in_channels = 768, out_channels = in_channels_2d, kernel_size = 3, padding = 1)
self.merge = nn.Conv2d(in_channels = in_channels_2d * 2, out_channels = in_channels_2d, kernel_size = 1, padding = 0)
self.explicit_3d_embedding = nn.Conv2d(in_channels = 3 * (2 * self.freqs + 1) + 6, out_channels = 768, kernel_size = 16, padding = 0, stride = 16)
def forward(self, feat_2d, feat_3d, explicit_3d = None, must3r_size = 512):
refinenets_3d = self.feat_conv_3d_224 if must3r_size == 224 else self.feat_conv_3d_512
assert len(feat_3d) == 4, f'Expected 4 levels of 3D features, got {len(feat_3d)}'
explicit_3d = torch.cat((positional_encoding(explicit_3d[:, :3], self.freqs, dim = 1), explicit_3d[:, 3:]), dim = 1)
explicit_3d = self.explicit_3d_embedding(explicit_3d)
pe_3d = rearrange(explicit_3d, 'b c h w -> b (h w) c')
B = pe_3d.shape[0]
assert B == 1
pe_2d = self.position_getter(B, explicit_3d.shape[2], explicit_3d.shape[3], device = explicit_3d.device)
feat_3d = [rearrange(f, 'b c h w -> b (h w) c') for f in feat_3d]
N = feat_3d[0].shape[1]
ca_attn_mask = torch.ones((B, 1, N, N * B), dtype = torch.bool, device = feat_3d[0].device)
for i in range(B):
ca_attn_mask[i, :, :, :(i + 1) * N] = False
feat_3d_post = feat_3d[0]
for i in range(len(feat_3d)):
if i == 0:
feat_3d_post = refinenets_3d[-1](feat_3d_post) + pe_3d
feat_3d_post = refinenets_3d[i](x = feat_3d_post, y = feat_3d_post, xpos = pe_2d)
else:
feat_3d_post = refinenets_3d[i](x = feat_3d_post + pe_3d, y = repeat(feat_3d[i] + pe_3d, 'b n c -> k (b n) c', k = B), xpos = pe_2d, ca_attn_mask = ca_attn_mask)
feat_3d_post = self.out(F.interpolate(rearrange(feat_3d_post, 'b (h w) c -> b c h w', b = B, h = explicit_3d.shape[2], w = explicit_3d.shape[3]), size = feat_2d.shape[-2:], mode = 'bilinear', align_corners = False))
feat_merged = self.merge(torch.cat([feat_3d_post, feat_2d], dim = 1))
return feat_merged
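# Note: FeatureFusion.forward embeds the explicit 3D signal (pts3d + Plucker rays, 3 + 6
# channels) with a patch-16 convolution, refines the four selected MUSt3R feature levels
# with the cross-attention blocks, and merges the result with SAM2's 2D features at the
# 2D feature map's resolution.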
def get_must3r_cross_attn_layers(device = 'cuda'):
from must3r.model import load_model
from must3r.demo.gradio import get_args_parser
cmd_params = ["--weights", "./private/MUSt3R_512.pth", "--retrieval", "./private/MUSt3R_512_retrieval_trainingfree.pth", "--image_size", "512", "--amp", "bf16", "--viser", "--allow_local_files", "--device", device]
parser = get_args_parser()
args = parser.parse_args(cmd_params)
model = load_model(args.weights, encoder=args.encoder, decoder=args.decoder, device=args.device,
img_size=args.image_size, memory_mode=args.memory_mode, verbose=args.verbose)
return model[1].blocks_dec
def get_predictors(device = 'cuda'):
predictor_original = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l.yaml", "./sam2-src/checkpoints/sam2.1_hiera_large.pt").to(device).eval()
predictor = build_sam2_video_predictor("configs/sam2.1/sam2.1_hiera_l_3d.yaml").to(device).eval()
cross_attn_blocks_3d = get_must3r_cross_attn_layers(device = device)
predictor.fusion_3d = FeatureFusion(cross_attn_blocks_3d = [copy.deepcopy(cross_attn_blocks_3d[i]) for i in [0, 4, 7, 11]])
predictor = load_checkpoint(predictor, torch.load('./private/sam2.1-must3r-fixed-vision-v1-decomp-standalone-regional-best-2.7851.pt', map_location = 'cpu'))
return predictor_original.cpu(), predictor.cpu()
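# Usage sketch (the checkpoint paths hard-coded above must exist locally):
#   predictor_original, predictor = get_predictors(device = 'cuda')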
@torch.no_grad()
def get_image_feature(
predictor: SAM2Base,
images: torch.Tensor,
device = 'cuda'
):
backbone_out = predictor.forward_image(images)
backbone_out = {
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
}
backbone_out, vision_feats, vision_pos_embeds, feat_sizes = predictor._prepare_backbone_features(backbone_out)
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
class Tracker(nn.Module):
    """Frame-by-frame SAM2 video tracker whose image features are fused with MUSt3R 3D features."""
    def __init__(self, predictor, predictor_original = None, device = 'cuda'):
super().__init__()
self.predictor = predictor.to(device)
self.predictor_original = predictor_original.to(device) if predictor_original is not None else None
self.device = device
def init(self, images, processing_order, points = None, labels = None, mask_inputs = None, must3r_feats = None, explicit_3d = None, image_features = None):
self.images = images
self.point_inputs = {"point_coords": points.to(self.device), "point_labels": labels.to(self.device)} if points is not None and labels is not None else None
self.mask_inputs = mask_inputs
self.processing_order = processing_order
self.output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
self.pred_maskses = []
self.num_frames = len(processing_order)
self.current_idx = 0
self.image_features = image_features
self.must3r_feats = must3r_feats
self.explicit_3d = explicit_3d
@torch.no_grad()
@torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
def step(self, mask_inputs = None, point_inputs = None):
assert (mask_inputs is None or self.current_idx > 0) and (point_inputs is None or self.current_idx > 0), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
frame_idx = self.processing_order[self.current_idx]
(
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = get_image_feature(self.predictor, images = self.images[:, frame_idx, :3].to(self.device))
        if self.must3r_feats is not None:
            # Fuse this frame's MUSt3R features into SAM2's lowest-resolution (64x64) feature map.
            feat_2d_original = rearrange(current_vision_feats[-1], '(x y) b c -> b c x y', x = 64, y = 64)
feat_2d = self.predictor.fusion_3d(
feat_2d = feat_2d_original,
feat_3d = [f[frame_idx].to(self.device).squeeze()[None] for f in self.must3r_feats],
explicit_3d = self.explicit_3d[frame_idx].to(self.device).squeeze()[None],
must3r_size = 224 if self.must3r_feats[0][frame_idx].shape[-1] == 14 and self.must3r_feats[0][frame_idx].shape[-2] == 14 else 512
)
current_vision_feats[-1] = rearrange(feat_2d, 'b c h w -> (h w) b c')
assert not torch.allclose(feat_2d.float(), feat_2d_original.float(), 1e-4), 'Feature fusion did not change features'
self.current_vision_feats = current_vision_feats
self.feat_sizes = feat_sizes
        # Memory for this step: all conditioning frames, the non-conditioning frames with a
        # non-empty predicted mask, and (always) the immediately previous frame if it exists.
        prev_out = self.output_dict['non_cond_frame_outputs'].get(self.current_idx - 1)
        memory_dict = {
            'cond_frame_outputs': self.output_dict['cond_frame_outputs'],
            'non_cond_frame_outputs': {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()} | ({self.current_idx - 1: prev_out} if prev_out else {}),
        }
        # Keep at most the 32 entries closest to the current frame, re-keyed as the most recent indices.
        for bank in ('non_cond_frame_outputs', 'cond_frame_outputs'):
            if len(memory_dict[bank]) > 32:
                memory_dict[bank] = {self.current_idx - i: v for i, (k, v) in enumerate(sorted(memory_dict[bank].items(), key = lambda x: abs(x[0] - self.current_idx))[:32])}
current_out = self.predictor.track_step(
frame_idx = self.current_idx,
is_init_cond_frame = self.current_idx == 0,
current_vision_feats = current_vision_feats,
current_vision_pos_embeds = current_vision_pos_embeds,
feat_sizes = feat_sizes,
point_inputs = self.point_inputs if self.current_idx == 0 else point_inputs,
mask_inputs = self.mask_inputs.to(self.device) if self.current_idx == 0 else mask_inputs,
output_dict = memory_dict,
num_frames = self.num_frames,
track_in_reverse = False,
run_mem_encoder = False,
prev_sam_mask_logits = None,
)
current_out["pred_masks"] = fill_holes_in_mask_scores(
current_out["pred_masks"], self.predictor.fill_hole_area
)
current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
current_out["pred_masks"],
size = (self.predictor.image_size, self.predictor.image_size),
mode = "bilinear",
align_corners = False,
)
# if self.predictor_original is not None and self.current_idx != 0:
# current_out['pred_masks_high_res_lq'] = current_out['pred_masks_high_res'].clone()
# self.predictor_original.use_mask_input_as_output_without_sam = False
# current_vision_feats_original = current_vision_feats.copy()
# current_vision_feats_original[-1] = rearrange(feat_2d_original, 'b c h w -> (h w) b c')
# current_out_original = self.predictor_original.track_step(
# frame_idx = 0,
# is_init_cond_frame = True,
# current_vision_feats = current_vision_feats_original,
# current_vision_pos_embeds = current_vision_pos_embeds,
# feat_sizes = feat_sizes,
# point_inputs = None,
# mask_inputs = current_out["pred_masks_high_res"].to(self.device).squeeze()[None, None],
# output_dict = {},
# num_frames = self.num_frames,
# track_in_reverse = False,
# run_mem_encoder = False,
# prev_sam_mask_logits = None,
# )
# # if (current_out['pred_masks_high_res'] > 0).sum() > 0: assert (current_out_original['pred_masks'] > 0).sum() > 0, 'Original predictor produced empty mask'
# current_out["pred_masks"] = fill_holes_in_mask_scores(
# current_out_original["pred_masks"], self.predictor.fill_hole_area
# )
# current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
# current_out["pred_masks"],
# size = (self.predictor.image_size, self.predictor.image_size),
# mode = "bilinear",
# align_corners = False,
# )
return current_out
@torch.no_grad()
@torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
def postprocess(self, current_out):
maskmem_features, maskmem_pos_enc = self.predictor._encode_new_memory(
current_vision_feats = self.current_vision_feats,
feat_sizes = self.feat_sizes,
pred_masks_high_res = current_out["pred_masks_high_res"],
object_score_logits = current_out['object_score_logits'],
is_mask_from_pts = False
)
current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
current_out["maskmem_pos_enc"] = maskmem_pos_enc
self.pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        self.output_dict['cond_frame_outputs' if self.current_idx == 0 else 'non_cond_frame_outputs'][self.current_idx] = current_out
        # Drop stored outputs older than 256 frames to bound memory usage.
        if len(self.output_dict['non_cond_frame_outputs']) > 256:
            self.output_dict['non_cond_frame_outputs'] = {k: v for k, v in self.output_dict['non_cond_frame_outputs'].items() if k >= self.current_idx - 256}
        if len(self.output_dict['cond_frame_outputs']) > 256:
            self.output_dict['cond_frame_outputs'] = {k: v for k, v in self.output_dict['cond_frame_outputs'].items() if k >= self.current_idx - 256}
self.current_idx += 1
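# Per-frame usage of Tracker (see get_tracked_masks below for the full loop): call init()
# once with the frames and prompts, then for every frame in processing_order call step()
# to get current_out (masks in current_out['pred_masks_high_res']) and postprocess(current_out)
# to encode the new memory and advance to the next frame.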
@torch.no_grad()
@torch.autocast(device_type = 'cuda', dtype = torch.bfloat16)
def forward_original(predictor: SAM2Base, images, points = None, labels = None, mask_inputs = None, processing_order = None, device = 'cuda'):
B, T, _, H, W = images.shape
point_inputs = {"point_coords": points, "point_labels": labels} if points is not None and labels is not None else None
assert (mask_inputs is None) ^ (point_inputs is None), f"mask_inputs: {mask_inputs}, point_inputs: {point_inputs}"
processing_order = list(range(images.shape[1])) if processing_order is None else processing_order
num_frames = len(processing_order)
pred_maskses = []
ious = []
output_dict = {'cond_frame_outputs': {}, 'non_cond_frame_outputs': {}}
for idx, frame_idx in enumerate(tqdm(processing_order)):
(
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = get_image_feature(predictor, images = images[:, frame_idx, :3].to(device))
memory_dict = {'cond_frame_outputs': output_dict['cond_frame_outputs'], 'non_cond_frame_outputs': {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if (v['pred_masks'] > 0).any()}}
if len(memory_dict['non_cond_frame_outputs']) > 0:
memory_dict['non_cond_frame_outputs'] = {idx - i: v for i, (k, v) in enumerate(sorted(memory_dict['non_cond_frame_outputs'].items(), key = lambda x: abs(x[0] - idx))[:24])}
current_out = predictor.track_step(
frame_idx = idx,
is_init_cond_frame = idx == 0,
current_vision_feats = current_vision_feats,
current_vision_pos_embeds = current_vision_pos_embeds,
feat_sizes = feat_sizes,
point_inputs = point_inputs if idx == 0 else None,
mask_inputs = mask_inputs if idx == 0 else None,
output_dict = memory_dict,
num_frames = num_frames,
track_in_reverse = False,
run_mem_encoder = False,
prev_sam_mask_logits = None,
)
        # Keep the raw high-res logits (before hole filling) under a separate key.
        current_out['pred_masks_high_res_lq'] = current_out['pred_masks_high_res']
current_out["pred_masks"] = fill_holes_in_mask_scores(
current_out["pred_masks"], predictor.fill_hole_area
)
current_out["pred_masks_high_res"] = torch.nn.functional.interpolate(
current_out["pred_masks"],
size = (predictor.image_size, predictor.image_size),
mode = "bilinear",
align_corners = False,
)
maskmem_features, maskmem_pos_enc = predictor._encode_new_memory(
current_vision_feats = current_vision_feats,
feat_sizes = feat_sizes,
pred_masks_high_res = current_out["pred_masks_high_res"],
object_score_logits = current_out['object_score_logits'],
is_mask_from_pts = True
)
current_out["maskmem_features"] = maskmem_features.to(torch.bfloat16)
current_out["maskmem_pos_enc"] = maskmem_pos_enc
pred_maskses.append(current_out['pred_masks_high_res'].cpu())
        output_dict['cond_frame_outputs' if idx == 0 else 'non_cond_frame_outputs'][idx] = current_out
if len(output_dict['non_cond_frame_outputs']) > 256:
output_dict['non_cond_frame_outputs'] = {k: v for k, v in output_dict['non_cond_frame_outputs'].items() if k >= idx - 256}
pred_maskses = torch.stack(pred_maskses, dim = 1).squeeze(2) # (B, T, H, W)
assert pred_maskses.shape == (B, len(processing_order), H, W)
return pred_maskses
@torch.no_grad()
def get_single_frame_mask(image: torch.Tensor, predictor_original, points, labels, device = 'cuda'):
    '''
    points: 1 x N x 2
    labels: 1 x N (positive point: 1, negative point: 0, box top-left corner: 2, box bottom-right corner: 3)
    '''
return forward_original(
predictor_original.to(device),
images = image.squeeze()[None, None],
points = points,
labels = labels,
processing_order = [0],
device = device
)
@torch.no_grad()
def get_tracked_masks(sam2_input_images, must3r_feats, must3r_outputs, start_idx, first_frame_mask, predictor, predictor_original, device = 'cuda'):
tracker = Tracker(predictor, predictor_original = predictor_original, device = device)
    # Track forward from the annotated frame to the end of the sequence.
    tracker.init(
images = sam2_input_images.squeeze()[None],
processing_order = range(start_idx, sam2_input_images.shape[0]),
mask_inputs = first_frame_mask.squeeze()[None, None] > 0,
must3r_feats = must3r_feats,
explicit_3d = torch.cat((must3r_outputs['pts3d'], must3r_outputs['ray_plucker']), dim = -1).permute(0, 3, 1, 2)
)
output_masks = {}
for idx, frame_idx in enumerate(tqdm(tracker.processing_order)):
current_out = tracker.step()
output_masks[frame_idx] = current_out['pred_masks_high_res'].squeeze().cpu().numpy() > 0
tracker.postprocess(current_out)
    # Then track backward from the annotated frame to the start of the sequence.
    tracker.init(
images = sam2_input_images.squeeze()[None],
processing_order = range(start_idx, -1, -1),
mask_inputs = first_frame_mask.squeeze()[None, None] > 0,
must3r_feats = must3r_feats,
explicit_3d = torch.cat((must3r_outputs['pts3d'], must3r_outputs['ray_plucker']), dim = -1).permute(0, 3, 1, 2)
)
for idx, frame_idx in enumerate(tqdm(tracker.processing_order)):
current_out = tracker.step()
output_masks[frame_idx] = current_out['pred_masks_high_res'].squeeze().cpu().numpy() > 0
tracker.postprocess(current_out)
return output_masks
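# End-to-end sketch (hypothetical driver; assumes `pil_imgs` is a list of PIL Images and
# `points` / `labels` are SAM2-style prompts for frame `start_idx`):
#   views, resize_funcs = get_views(pil_imgs)
#   sam2_input_images, images = prepare_sam2_inputs(views, pil_imgs, resize_funcs)
#   must3r_feats, must3r_outputs = must3r_features_and_output(views)
#   predictor_original, predictor = get_predictors()
#   first_frame_mask = get_single_frame_mask(sam2_input_images[start_idx], predictor_original, points, labels)
#   masks = get_tracked_masks(sam2_input_images, must3r_feats, must3r_outputs, start_idx,
#                             first_frame_mask, predictor, predictor_original)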