|
|
|
|
|
|
|
|
|
|
|
|
| import logging
|
|
|
| import numpy as np
|
| import torch
|
| import torch.distributed
|
| from sam2.modeling.sam2_base import SAM2Base
|
| from sam2.modeling.sam2_utils import (
|
| get_1d_sine_pe,
|
| get_next_point,
|
| sample_box_points,
|
| select_closest_cond_frames,
|
| )
|
|
|
| from sam2.utils.misc import concat_points
|
|
|
| from training.utils.data_utils import BatchedVideoDatapoint
|
|
|
|
|
| class SAM2Train(SAM2Base):
|
def __init__(
    self,
    image_encoder,
    memory_attention=None,
    memory_encoder=None,
    prob_to_use_pt_input_for_train=0.0,
    prob_to_use_pt_input_for_eval=0.0,
    prob_to_use_box_input_for_train=0.0,
    prob_to_use_box_input_for_eval=0.0,
    num_frames_to_correct_for_train=1,
    num_frames_to_correct_for_eval=1,
    rand_frames_to_correct_for_train=False,
    rand_frames_to_correct_for_eval=False,
    num_init_cond_frames_for_train=1,
    num_init_cond_frames_for_eval=1,
    rand_init_cond_frames_for_train=True,
    rand_init_cond_frames_for_eval=False,
    add_all_frames_to_correct_as_cond=False,
    num_correction_pt_per_frame=7,
    pt_sampling_for_eval="center",
    prob_to_sample_from_gt_for_train=0.0,
    use_act_ckpt_iterative_pt_sampling=False,
    forward_backbone_per_frame_for_eval=False,
    freeze_image_encoder=False,
    **kwargs,
):
    """Store the prompt-simulation settings on top of the base model.

    Args:
        image_encoder, memory_attention, memory_encoder, **kwargs: forwarded
            unchanged to the parent constructor.
        prob_to_use_pt_input_for_train / _for_eval: probability of prompting
            the initial conditioning frames with points instead of GT masks.
        prob_to_use_box_input_for_train / _for_eval: given a point prompt,
            probability of sampling it with `sample_box_points` rather than
            `get_next_point`.
        num_frames_to_correct_for_train / _for_eval: (maximum) number of
            frames that receive correction points.
        rand_frames_to_correct_for_train / _for_eval: when true, the number
            of frames to correct is drawn between the number of initial
            conditioning frames and the maximum above (inclusive).
        num_init_cond_frames_for_train / _for_eval: (maximum) number of
            initial conditioning frames.
        rand_init_cond_frames_for_train / _for_eval: when true, the number of
            initial conditioning frames is drawn between 1 and the maximum
            above (inclusive).
        add_all_frames_to_correct_as_cond: when true, outputs of frames that
            receive correction points are stored as conditioning outputs.
        num_correction_pt_per_frame: number of iterative correction rounds
            run on each frame selected for correction.
        pt_sampling_for_eval: `method` passed to `get_next_point` when the
            module is not in training mode ("uniform" is used in training).
        prob_to_sample_from_gt_for_train: in training, probability that a
            correction point is sampled with `pred_masks=None`.
        use_act_ckpt_iterative_pt_sampling: wrap the SAM-head call of the
            correction rounds in `torch.utils.checkpoint` whenever a single
            mask is requested.
        forward_backbone_per_frame_for_eval: outside training, skip the
            up-front backbone pass and compute features frame by frame.
        freeze_image_encoder: disable gradients for every image-encoder
            parameter.
    """
    super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)

    # Point prompts are in play as soon as either mode may sample them; only
    # then do the "frames to correct" budgets have to cover the initial
    # conditioning frames.
    point_prompts_enabled = (
        prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0
    )
    if point_prompts_enabled:
        logging.info(
            f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
        )
        assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
        assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

    # --- settings used while `self.training` is true ---
    self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
    self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
    self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
    self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
    self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
    self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
    self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train

    # --- settings used otherwise ---
    self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
    self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
    self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
    self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
    self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
    self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
    self.pt_sampling_for_eval = pt_sampling_for_eval
    self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

    # --- settings shared by both modes ---
    self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
    self.num_correction_pt_per_frame = num_correction_pt_per_frame
    self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling

    # Fixed seed: every instance starts from the same random stream for all
    # of its prompt-sampling decisions.
    self.rng = np.random.default_rng(seed=42)

    if freeze_image_encoder:
        for param in self.image_encoder.parameters():
            param.requires_grad = False
|
|
|
def forward(self, input: BatchedVideoDatapoint):
    """Run prompt preparation and tracking over one batched video datapoint.

    The backbone is run once on `input.flat_img_batch` unless the module is
    outside training mode and `forward_backbone_per_frame_for_eval` is set;
    in that case placeholder `None` features are passed on so that tracking
    computes the features frame by frame.

    Returns:
        Whatever `forward_tracking` returns for the prepared inputs.
    """
    defer_backbone = (
        not self.training and self.forward_backbone_per_frame_for_eval
    )
    if defer_backbone:
        # `backbone_fpn is None` is the signal `forward_tracking` checks to
        # compute image features on the fly.
        backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
    else:
        backbone_out = self.forward_image(input.flat_img_batch)

    backbone_out = self.prepare_prompt_inputs(backbone_out, input)
    return self.forward_tracking(backbone_out, input)
|
|
|
def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
    """Compute the image backbone features on the fly for the given img_ids.

    When more than one id is requested, the backbone only runs on the unique
    ids and the results are re-expanded to the requested order.

    Args:
        img_batch: indexable batch of images (indexed by `img_ids` along its
            first axis).
        img_ids: tensor of image indices into `img_batch`.

    Returns:
        Tuple `(image, vision_feats, vision_pos_embeds, feat_sizes)`, where
        the feature lists are indexed along dim 1 to follow `img_ids`.
    """
    ids_to_encode, restore_index = img_ids, None
    if img_ids.numel() > 1:
        # Several requested entries may share an image; encode each one once.
        ids_to_encode, restore_index = torch.unique(img_ids, return_inverse=True)

    image = img_batch[ids_to_encode]
    encoded = self.forward_image(image)
    _, vision_feats, vision_pos_embeds, feat_sizes = (
        self._prepare_backbone_features(encoded)
    )

    if restore_index is None:
        return image, vision_feats, vision_pos_embeds, feat_sizes

    # Map the de-duplicated results back onto the original `img_ids` layout.
    image = image[restore_index]
    vision_feats = [feat[:, restore_index] for feat in vision_feats]
    vision_pos_embeds = [pos[:, restore_index] for pos in vision_pos_embeds]
    return image, vision_feats, vision_pos_embeds, feat_sizes
|
|
|
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
    """
    Prepare input mask, point or box prompts. Optionally, we allow tracking from
    a custom `start_frame_idx` to the end of the video (for evaluation purposes).

    The following keys are written into `backbone_out`, which is also returned:
    `gt_masks_per_frame`, `num_frames`, `use_pt_input`, `init_cond_frames`,
    `frames_not_in_init_cond`, `mask_inputs_per_frame`,
    `point_inputs_per_frame` and `frames_to_add_correction_pt`.

    All random decisions are drawn from `self.rng`, in a fixed order.
    """
    # One ground-truth entry per frame, with a singleton axis added at dim 1.
    gt_by_frame = {}
    for frame_idx, frame_masks in enumerate(input.masks):
        gt_by_frame[frame_idx] = frame_masks.unsqueeze(1)
    backbone_out["gt_masks_per_frame"] = gt_by_frame

    num_frames = input.num_frames
    backbone_out["num_frames"] = num_frames

    # Pick the train or eval flavour of every sampling setting.
    if self.training:
        settings = (
            self.prob_to_use_pt_input_for_train,
            self.prob_to_use_box_input_for_train,
            self.num_frames_to_correct_for_train,
            self.rand_frames_to_correct_for_train,
            self.num_init_cond_frames_for_train,
            self.rand_init_cond_frames_for_train,
        )
    else:
        settings = (
            self.prob_to_use_pt_input_for_eval,
            self.prob_to_use_box_input_for_eval,
            self.num_frames_to_correct_for_eval,
            self.rand_frames_to_correct_for_eval,
            self.num_init_cond_frames_for_eval,
            self.rand_init_cond_frames_for_eval,
        )
    (
        pt_prob,
        box_prob,
        n_to_correct,
        randomize_n_to_correct,
        n_init_cond,
        randomize_n_init_cond,
    ) = settings

    if num_frames == 1:
        # Single-frame input: always use a point prompt on that one frame.
        pt_prob, n_to_correct, n_init_cond = 1.0, 1, 1
    assert n_init_cond >= 1

    # Draw 1: point prompt vs. mask prompt.
    use_pt_input = self.rng.random() < pt_prob
    # Draw 2 (conditional): how many initial conditioning frames, in
    # [1, n_init_cond] inclusive.
    if randomize_n_init_cond and n_init_cond > 1:
        n_init_cond = self.rng.integers(1, n_init_cond, endpoint=True)
    # Draw 3 (conditional, point prompts only): how many frames get
    # correction points, in [n_init_cond, n_to_correct] inclusive.
    if use_pt_input and randomize_n_to_correct and n_to_correct > n_init_cond:
        n_to_correct = self.rng.integers(n_init_cond, n_to_correct, endpoint=True)
    backbone_out["use_pt_input"] = use_pt_input

    # The start frame is always a conditioning frame; any further ones are
    # drawn without replacement from the frames after it.
    init_cond_frames = [start_frame_idx]
    if n_init_cond != 1:
        later_picks = self.rng.choice(
            range(start_frame_idx + 1, num_frames),
            n_init_cond - 1,
            replace=False,
        )
        init_cond_frames = init_cond_frames + later_picks.tolist()
    backbone_out["init_cond_frames"] = init_cond_frames
    backbone_out["frames_not_in_init_cond"] = [
        frame_idx
        for frame_idx in range(start_frame_idx, num_frames)
        if frame_idx not in init_cond_frames
    ]

    # Build the prompt (GT mask, box points, or one sampled point) for every
    # initial conditioning frame.
    mask_inputs_per_frame = {}
    point_inputs_per_frame = {}
    backbone_out["mask_inputs_per_frame"] = mask_inputs_per_frame
    backbone_out["point_inputs_per_frame"] = point_inputs_per_frame
    for frame_idx in init_cond_frames:
        frame_gt = gt_by_frame[frame_idx]
        if not use_pt_input:
            mask_inputs_per_frame[frame_idx] = frame_gt
            continue

        if self.rng.random() < box_prob:
            points, labels = sample_box_points(frame_gt)
        else:
            # No prediction exists yet, so the point is sampled from the
            # ground truth alone (`pred_masks=None`).
            sampling_method = (
                "uniform" if self.training else self.pt_sampling_for_eval
            )
            points, labels = get_next_point(
                gt_masks=frame_gt,
                pred_masks=None,
                method=sampling_method,
            )
        point_inputs_per_frame[frame_idx] = {
            "point_coords": points,
            "point_labels": labels,
        }

    # Decide which frames receive correction points during tracking.
    if not use_pt_input:
        # Mask prompts are never corrected.
        frames_to_add_correction_pt = []
    elif n_to_correct == n_init_cond:
        frames_to_add_correction_pt = init_cond_frames
    else:
        assert n_to_correct > n_init_cond
        # Conditioning frames plus extra frames drawn without replacement
        # from the remaining ones.
        extra_num = n_to_correct - n_init_cond
        extra_picks = self.rng.choice(
            backbone_out["frames_not_in_init_cond"], extra_num, replace=False
        )
        frames_to_add_correction_pt = init_cond_frames + extra_picks.tolist()
    backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

    return backbone_out
|
|
|
def forward_tracking(
    self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
):
    """Forward video tracking on each frame (and sample correction clicks).

    Initial conditioning frames are processed first, followed by the
    remaining frames in the order stored in `backbone_out`.

    Args:
        backbone_out: dict produced by `prepare_prompt_inputs`; when its
            `"backbone_fpn"` entry is `None`, image features are computed
            per frame instead of being sliced from a precomputed batch.
        input: batch providing `flat_obj_to_img_idx` and `flat_img_batch`.
        return_dict: if true, return the raw output dict with the
            `"cond_frame_outputs"` / `"non_cond_frame_outputs"` buckets.

    Returns:
        The output dict (see `return_dict`), or a list with one output dict
        per frame index `0..num_frames-1` with the `"obj_ptr"` entry removed.
    """
    features_precomputed = backbone_out["backbone_fpn"] is not None
    if features_precomputed:
        _, vision_feats, vision_pos_embeds, feat_sizes = (
            self._prepare_backbone_features(backbone_out)
        )

    num_frames = backbone_out["num_frames"]
    init_cond_frames = backbone_out["init_cond_frames"]
    frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
    point_prompts = backbone_out["point_inputs_per_frame"]
    mask_prompts = backbone_out["mask_inputs_per_frame"]
    gt_by_frame = backbone_out["gt_masks_per_frame"]

    output_dict = {
        "cond_frame_outputs": {},
        "non_cond_frame_outputs": {},
    }
    # Conditioning frames go first so that their outputs are already stored
    # in `output_dict` when the other frames are tracked.
    for stage_id in init_cond_frames + backbone_out["frames_not_in_init_cond"]:
        img_ids = input.flat_obj_to_img_idx[stage_id]
        if features_precomputed:
            # Slice this frame's entries (dim 1) out of the batch features.
            current_vision_feats = [feat[:, img_ids] for feat in vision_feats]
            current_vision_pos_embeds = [
                pos[:, img_ids] for pos in vision_pos_embeds
            ]
        else:
            # Run the backbone only on the images needed for this frame.
            (
                _,
                current_vision_feats,
                current_vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features_per_frame(
                input.flat_img_batch, img_ids
            )

        is_init_cond = stage_id in init_cond_frames
        current_out = self.track_step(
            frame_idx=stage_id,
            is_init_cond_frame=is_init_cond,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_prompts.get(stage_id),
            mask_inputs=mask_prompts.get(stage_id),
            gt_masks=gt_by_frame.get(stage_id),
            frames_to_add_correction_pt=frames_to_add_correction_pt,
            output_dict=output_dict,
            num_frames=num_frames,
        )

        # Corrected frames are optionally promoted to conditioning outputs.
        store_as_cond = is_init_cond or (
            self.add_all_frames_to_correct_as_cond
            and stage_id in frames_to_add_correction_pt
        )
        bucket = "cond_frame_outputs" if store_as_cond else "non_cond_frame_outputs"
        output_dict[bucket][stage_id] = current_out

    if return_dict:
        return output_dict

    # Flatten both buckets into a frame-ordered list and drop "obj_ptr"
    # from every per-frame output.
    outputs_by_frame = {
        **output_dict["cond_frame_outputs"],
        **output_dict["non_cond_frame_outputs"],
    }
    ordered_outputs = [outputs_by_frame[t] for t in range(num_frames)]
    return [
        {key: value for key, value in frame_out.items() if key != "obj_ptr"}
        for frame_out in ordered_outputs
    ]
|
|
|
def track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse=False,
    run_mem_encoder=True,
    prev_sam_mask_logits=None,
    frames_to_add_correction_pt=None,
    gt_masks=None,
):
    """Track one frame and, if requested, refine it with correction points.

    The frame is first run through `self._track_step`. If `frame_idx` is
    listed in `frames_to_add_correction_pt` (and at least one correction
    round is configured), the prediction is refined by
    `self._iter_correct_pt_sampling`, whose final outputs replace the
    initial ones. The memory for the frame is then encoded into the output.

    Args:
        frame_idx: index of the frame being processed.
        is_init_cond_frame: whether this is an initial conditioning frame.
        current_vision_feats, current_vision_pos_embeds, feat_sizes: image
            features for this frame, passed through to `_track_step`.
        point_inputs, mask_inputs: prompts for this frame (either may be
            `None`).
        output_dict: outputs of the frames processed so far.
        num_frames: total number of frames being tracked.
        track_in_reverse, prev_sam_mask_logits: passed through to
            `_track_step`.
        run_mem_encoder: passed through to `_encode_memory_in_output`.
        frames_to_add_correction_pt: frame indices that receive correction
            points; `None` is treated as an empty list.
        gt_masks: ground-truth masks for this frame; required by the
            correction step.

    Returns:
        The output dict of this frame, including the `multistep_*` entries,
        `pred_masks`, `pred_masks_high_res` and `obj_ptr`.
    """
    if frames_to_add_correction_pt is None:
        frames_to_add_correction_pt = []
    current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    )

    (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    ) = sam_outputs

    # Record the initial prediction as step 0 of the multi-step outputs;
    # the correction step overwrites these entries when it runs.
    current_out["multistep_pred_masks"] = low_res_masks
    current_out["multistep_pred_masks_high_res"] = high_res_masks
    current_out["multistep_pred_multimasks"] = [low_res_multimasks]
    current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
    current_out["multistep_pred_ious"] = [ious]
    current_out["multistep_point_inputs"] = [point_inputs]
    current_out["multistep_object_score_logits"] = [object_score_logits]

    # With zero configured rounds the refinement loop never runs and has no
    # final SAM outputs to hand back, so calling it would fail. The initial
    # outputs recorded above are already the final ones in that case.
    run_correction = (
        frame_idx in frames_to_add_correction_pt
        and self.num_correction_pt_per_frame > 0
    )
    if run_correction:
        point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
            is_init_cond_frame,
            point_inputs,
            gt_masks,
            high_res_features,
            pix_feat,
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = final_sam_outputs

    # Final (possibly corrected) prediction for this frame.
    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Encode this frame's prediction into memory, using the final masks and
    # the accumulated point prompts.
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    )
    return current_out
|
|
|
def _iter_correct_pt_sampling(
    self,
    is_init_cond_frame,
    point_inputs,
    gt_masks,
    high_res_features,
    pix_feat_with_mem,
    low_res_multimasks,
    high_res_multimasks,
    ious,
    low_res_masks,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Iteratively add correction points and re-run the SAM heads.

    Each of the `self.num_correction_pt_per_frame` rounds samples one new
    point with `get_next_point`, appends it to the prompt, feeds the previous
    low-res mask back in as mask input and re-predicts. The prediction passed
    in is kept as step 0, and every round's outputs are appended after it.

    Side effects:
        Overwrites the `multistep_*` entries of `current_out` with the
        per-step history (mask tensors are concatenated along dim 1, the
        other entries are lists with one item per step).

    Returns:
        `(point_inputs, sam_outputs)`: the accumulated point prompt and the
        SAM-head outputs of the last round. `sam_outputs` is only assigned
        inside the loop, so at least one round must be configured.
    """
    assert gt_masks is not None

    # Step 0 of every history is the prediction we were handed.
    masks_history = [low_res_masks]
    high_res_masks_history = [high_res_masks]
    multimasks_history = [low_res_multimasks]
    high_res_multimasks_history = [high_res_multimasks]
    ious_history = [ious]
    points_history = [point_inputs]
    score_logits_history = [object_score_logits]

    for _ in range(self.num_correction_pt_per_frame):
        # In training, occasionally ignore the prediction; the RNG is only
        # consulted when that option is enabled.
        sample_from_gt = (
            self.training
            and self.prob_to_sample_from_gt_for_train > 0
            and self.rng.random() < self.prob_to_sample_from_gt_for_train
        )
        if sample_from_gt:
            # `pred_masks=None`: the point is sampled without a prediction.
            pred_for_new_pt = None
        else:
            pred_for_new_pt = high_res_masks > 0

        new_points, new_labels = get_next_point(
            gt_masks=gt_masks,
            pred_masks=pred_for_new_pt,
            method="uniform" if self.training else self.pt_sampling_for_eval,
        )
        point_inputs = concat_points(point_inputs, new_points, new_labels)

        # The previous round's low-res mask goes back in as the mask prompt,
        # alongside the grown set of points.
        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
        sam_kwargs = {
            "backbone_features": pix_feat_with_mem,
            "point_inputs": point_inputs,
            "mask_inputs": low_res_masks,
            "high_res_features": high_res_features,
            "multimask_output": multimask_output,
        }
        checkpoint_this_step = (
            self.use_act_ckpt_iterative_pt_sampling and not multimask_output
        )
        if checkpoint_this_step:
            sam_outputs = torch.utils.checkpoint.checkpoint(
                self._forward_sam_heads,
                use_reentrant=False,
                **sam_kwargs,
            )
        else:
            sam_outputs = self._forward_sam_heads(**sam_kwargs)

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            _,
            object_score_logits,
        ) = sam_outputs

        masks_history.append(low_res_masks)
        high_res_masks_history.append(high_res_masks)
        multimasks_history.append(low_res_multimasks)
        high_res_multimasks_history.append(high_res_multimasks)
        ious_history.append(ious)
        points_history.append(point_inputs)
        score_logits_history.append(object_score_logits)

    # Publish the whole history on the frame output: selected masks are
    # stacked along dim 1, everything else stays a per-step list.
    current_out["multistep_pred_masks"] = torch.cat(masks_history, dim=1)
    current_out["multistep_pred_masks_high_res"] = torch.cat(
        high_res_masks_history, dim=1
    )
    current_out["multistep_pred_multimasks"] = multimasks_history
    current_out["multistep_pred_multimasks_high_res"] = high_res_multimasks_history
    current_out["multistep_pred_ious"] = ious_history
    current_out["multistep_point_inputs"] = points_history
    current_out["multistep_object_score_logits"] = score_logits_history

    return point_inputs, sam_outputs
|
|
|