| |
|
|
| import os |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from huggingface_hub import hf_hub_download |
| from iopath.common.file_io import g_pathmgr |
| from sam3.model.decoder import ( |
| TransformerDecoder, |
| TransformerDecoderLayer, |
| TransformerDecoderLayerv2, |
| TransformerEncoderCrossAttention, |
| ) |
| from sam3.model.encoder import TransformerEncoderFusion, TransformerEncoderLayer |
| from sam3.model.geometry_encoders import SequenceGeometryEncoder |
| from sam3.model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead |
| from sam3.model.memory import ( |
| CXBlock, |
| SimpleFuser, |
| SimpleMaskDownSampler, |
| SimpleMaskEncoder, |
| ) |
| from sam3.model.model_misc import ( |
| DotProductScoring, |
| MLP, |
| MultiheadAttentionWrapper as MultiheadAttention, |
| TransformerWrapper, |
| ) |
| from sam3.model.necks import Sam3DualViTDetNeck |
| from sam3.model.position_encoding import PositionEmbeddingSine |
| from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor |
| from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU |
| from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor |
| from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity |
| from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU |
| from sam3.model.text_encoder_ve import VETextEncoder |
| from sam3.model.tokenizer_ve import SimpleTokenizer |
| from sam3.model.vitdet import ViT |
| from sam3.model.vl_combiner import SAM3VLBackbone |
| from sam3.sam.transformer import RoPEAttention |
|
|
|
|
| |
def _setup_tf32() -> None:
    """Turn on TensorFloat-32 math when CUDA device 0 is Ampere or newer.

    Does nothing when CUDA is unavailable or when device 0 reports a compute
    capability major version below 8. Only device 0 is inspected.
    """
    if not torch.cuda.is_available():
        return
    if torch.cuda.get_device_properties(0).major < 8:
        return
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# Executed once at import time; these are process-wide torch backend flags.
_setup_tf32()
|
|
|
|
def _create_position_encoding(precompute_resolution=None):
    """Build a 256-feature sine position embedding.

    Used for the visual backbone and for the geometry encoder.

    Args:
        precompute_resolution: Forwarded unchanged to ``PositionEmbeddingSine``.

    Returns:
        A normalized ``PositionEmbeddingSine`` with temperature 10000 and no
        explicit scale.
    """
    sine_settings = dict(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
    )
    return PositionEmbeddingSine(
        precompute_resolution=precompute_resolution, **sine_settings
    )
|
|
|
|
def _create_vit_backbone(compile_mode=None):
    """Instantiate the ViT trunk used for visual feature extraction.

    Args:
        compile_mode: Forwarded to ``ViT`` as ``compile_mode``.

    Returns:
        A 32-block, 1024-wide ``ViT`` for 1008px inputs with 14px patches.
    """
    trunk_config = {
        # input geometry
        "img_size": 1008,
        "pretrain_img_size": 336,
        "patch_size": 14,
        # model size
        "embed_dim": 1024,
        "depth": 32,
        "num_heads": 16,
        "mlp_ratio": 4.625,
        "norm_layer": "LayerNorm",
        "drop_path_rate": 0.1,
        "qkv_bias": True,
        # positional information
        "use_abs_pos": True,
        "tile_abs_pos": True,
        "use_rope": True,
        "use_interp_rope": True,
        # attention layout: block indices with global attention, plus window size
        "global_att_blocks": (7, 15, 23, 31),
        "rel_pos_blocks": (),
        "window_size": 24,
        # cls token, norms and outputs
        "pretrain_use_cls_token": True,
        "retain_cls_token": False,
        "ln_pre": True,
        "ln_post": False,
        "return_interm_layers": False,
        "bias_patch_embed": False,
    }
    return ViT(compile_mode=compile_mode, **trunk_config)
|
|
|
|
def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
    """Wrap a ViT trunk in a ``Sam3DualViTDetNeck`` feature pyramid.

    Args:
        position_encoding: Position encoding module handed to the neck.
        vit_backbone: The ViT trunk (passed as ``trunk``).
        enable_inst_interactivity: Forwarded as ``add_sam2_neck``.

    Returns:
        A 256-dim neck with scale factors 4.0, 2.0, 1.0 and 0.5.
    """
    pyramid_scales = [4.0, 2.0, 1.0, 0.5]
    return Sam3DualViTDetNeck(
        trunk=vit_backbone,
        position_encoding=position_encoding,
        d_model=256,
        scale_factors=pyramid_scales,
        add_sam2_neck=enable_inst_interactivity,
    )
|
|
|
|
def _create_vl_backbone(vit_neck, text_encoder):
    """Combine a visual neck and a text encoder into a ``SAM3VLBackbone``.

    Args:
        vit_neck: Visual feature module (passed as ``visual``).
        text_encoder: Text encoder module (passed as ``text``).

    Returns:
        The combined backbone, built with ``scalp=1``.
    """
    vl_backbone = SAM3VLBackbone(scalp=1, text=text_encoder, visual=vit_neck)
    return vl_backbone
|
|
|
|
def _create_transformer_encoder() -> TransformerEncoderFusion:
    """Build the fusion encoder of the SAM3 detector transformer.

    Returns:
        A ``TransformerEncoderFusion`` configured with ``num_layers=6`` and a
        pre-norm ``TransformerEncoderLayer`` (d_model 256, FFN 2048, ReLU,
        8-head batch-first self- and cross-attention) as its layer.
    """

    def _attention():
        # Identical settings for the self- and cross-attention modules.
        return MultiheadAttention(
            num_heads=8, dropout=0.1, embed_dim=256, batch_first=True
        )

    # Self-attention is created before cross-attention, matching the order in
    # which the modules were originally instantiated.
    self_attn = _attention()
    cross_attn = _attention()

    layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pre_norm=True,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        self_attention=self_attn,
        cross_attention=cross_attn,
    )
    return TransformerEncoderFusion(
        layer=layer,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
|
|
|
|
def _create_transformer_decoder() -> TransformerDecoder:
    """Build the query decoder of the SAM3 detector transformer.

    Returns:
        A ``TransformerDecoder`` with 6 layers, 200 queries, box refinement
        and ``presence_token=True``, built around a ``TransformerDecoderLayer``
        (d_model 256, FFN 2048, 8 heads, text cross-attention enabled).
    """
    cross_attn = MultiheadAttention(num_heads=8, dropout=0.1, embed_dim=256)
    layer = TransformerDecoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        n_heads=8,
        cross_attention=cross_attn,
        use_text_cross_attention=True,
    )

    decoder_settings = dict(
        num_layers=6,
        num_queries=200,
        num_o2m_queries=0,
        d_model=256,
        return_intermediate=True,
        box_refine=True,
        dac=True,
        dac_use_selfatt_ln=True,
        boxRPB="log",
        frozen=False,
        interaction_layer=None,
        resolution=1008,
        stride=14,
        use_act_checkpoint=True,
        presence_token=True,
    )
    return TransformerDecoder(layer=layer, **decoder_settings)
|
|
|
|
def _create_dot_product_scoring():
    """Build a ``DotProductScoring`` module with its residual prompt MLP.

    Returns:
        ``DotProductScoring(d_model=256, d_proj=256)`` whose ``prompt_mlp`` is
        a 2-layer residual MLP (256 -> 2048 -> 256, dropout 0.1) followed by a
        256-dim ``LayerNorm``.
    """
    output_norm = nn.LayerNorm(256)
    mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=output_norm,
    )
    scoring = DotProductScoring(d_model=256, d_proj=256, prompt_mlp=mlp)
    return scoring
|
|
|
|
def _create_segmentation_head(compile_mode=None):
    """Build the universal segmentation head and its pixel decoder.

    Args:
        compile_mode: Forwarded to ``PixelDecoder`` as ``compile_mode``.

    Returns:
        A ``UniversalSegmentationHead`` (hidden dim 256, 3 upsampling stages)
        with prompt cross-attention and no dot-product scorer.
    """
    decoder = PixelDecoder(
        num_upsampling_stages=3,
        interpolation_mode="nearest",
        hidden_dim=256,
        compile_mode=compile_mode,
    )
    prompt_attn = MultiheadAttention(num_heads=8, dropout=0, embed_dim=256)
    return UniversalSegmentationHead(
        hidden_dim=256,
        upsampling_stages=3,
        pixel_decoder=decoder,
        cross_attend_prompt=prompt_attn,
        dot_product_scorer=None,
        presence_head=False,
        aux_masks=False,
        act_ckpt=True,
    )
|
|
|
|
def _create_geometry_encoder():
    """Create the geometry (point/box prompt) encoder.

    Returns:
        A ``SequenceGeometryEncoder`` with ``d_model=256`` and ``num_layers=3``.
        Its layer is a pre-norm ``TransformerEncoderLayer`` whose self- and
        cross-attention modules use ``batch_first=False``. Points and boxes are
        both directly projected, pooled and position-encoded.
    """
    # Position encoding handed to the geometry encoder as ``pos_enc``.
    geo_pos_enc = _create_position_encoding()

    geo_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
    )

    return SequenceGeometryEncoder(
        pos_enc=geo_pos_enc,
        encode_boxes_as_points=False,
        points_direct_project=True,
        points_pool=True,
        points_pos_enc=True,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=geo_layer,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )
|
|
|
|
def _create_sam3_model(
    backbone,
    transformer,
    input_geometry_encoder,
    segmentation_head,
    dot_prod_scoring,
    inst_interactive_predictor,
    eval_mode,
):
    """Assemble a ``Sam3Image`` model from pre-built components.

    Args:
        backbone: Vision-language backbone.
        transformer: Detector transformer (encoder + decoder wrapper).
        input_geometry_encoder: Prompt geometry encoder.
        segmentation_head: Segmentation head, or ``None``.
        dot_prod_scoring: Dot-product scoring module.
        inst_interactive_predictor: Optional instance-interactivity predictor.
        eval_mode: When falsy, a ``BinaryHungarianMatcherV2`` (imported from
            ``sam3.train.matcher`` inside this function) is passed as
            ``matcher``. Otherwise ``matcher`` is ``None``.

    Returns:
        The constructed ``Sam3Image``.
    """
    if eval_mode:
        matcher = None
    else:
        # Local import: presumably keeps training-only code off the eval path.
        from sam3.train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )

    return Sam3Image(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
        inst_interactive_predictor=inst_interactive_predictor,
        matcher=matcher,
    )
|
|
|
|
def _create_tracker_maskmem_backbone():
    """Create the SAM3 tracker's memory encoder.

    Returns:
        A ``SimpleMaskEncoder`` with ``out_dim=64``. It is built from three parts:
        - a 64-feature sine position encoding precomputed at 1008;
        - a ``SimpleMaskDownSampler`` (kernel 3, stride 2, interpolation size
          1152x1152);
        - a ``SimpleFuser`` with two ``CXBlock`` layers.
    """
    mem_pos_enc = PositionEmbeddingSine(
        num_pos_feats=64,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    downsampler = SimpleMaskDownSampler(
        kernel_size=3, stride=2, padding=1, interpol_size=[1152, 1152]
    )
    fuser = SimpleFuser(
        layer=CXBlock(
            dim=256,
            kernel_size=7,
            padding=3,
            layer_scale_init_value=1.0e-06,
            use_dwconv=True,
        ),
        num_layers=2,
    )
    return SimpleMaskEncoder(
        out_dim=64,
        position_encoding=mem_pos_enc,
        mask_downsampler=downsampler,
        fuser=fuser,
    )
|
|
|
|
def _create_tracker_transformer():
    """Create the SAM3 tracker's memory-attention transformer.

    Returns:
        A ``TransformerWrapper`` (d_model 256, no decoder) whose encoder is a
        4-layer ``TransformerEncoderCrossAttention``. That encoder uses a
        ``TransformerDecoderLayerv2`` with single-head RoPE self- and
        cross-attention.
    """

    def _rope_settings():
        # A fresh dict (and fresh feat_sizes list) per attention module.
        return dict(
            embedding_dim=256,
            num_heads=1,
            downsample_rate=1,
            dropout=0.1,
            rope_theta=10000.0,
            feat_sizes=[72, 72],
            use_fa3=False,
            use_rope_real=False,
        )

    self_attn = RoPEAttention(**_rope_settings())
    # Cross-attention additionally takes 64-dim keys/values and sets
    # rope_k_repeat=True.
    cross_attn = RoPEAttention(kv_in_dim=64, rope_k_repeat=True, **_rope_settings())

    layer = TransformerDecoderLayerv2(
        cross_attention_first=False,
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pre_norm=True,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention=self_attn,
        cross_attention=cross_attn,
    )
    memory_encoder = TransformerEncoderCrossAttention(
        layer=layer,
        num_layers=4,
        d_model=256,
        batch_first=True,
        frozen=False,
        pos_enc_at_input=True,
        remove_cross_attention_layers=[],
        use_act_checkpoint=False,
    )
    return TransformerWrapper(encoder=memory_encoder, decoder=None, d_model=256)
|
|
|
|
def build_tracker(
    apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
) -> Sam3TrackerPredictor:
    """
    Build the SAM3 Tracker module for video tracking.

    Args:
        apply_temporal_disambiguation: Forwarded to the predictor as
            ``use_memory_selection``.
        with_backbone: When true, attach a vision-only ``SAM3VLBackbone``
            (``text=None``). Otherwise ``backbone`` is ``None``.
        compile_mode: Passed to the vision backbone factory. Only used when
            ``with_backbone`` is true.

    Returns:
        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
    """
    maskmem_backbone = _create_tracker_maskmem_backbone()
    transformer = _create_tracker_transformer()
    backbone = (
        SAM3VLBackbone(
            scalp=1,
            visual=_create_vision_backbone(compile_mode=compile_mode),
            text=None,
        )
        if with_backbone
        else None
    )
    decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    return Sam3TrackerPredictor(
        image_size=1008,
        backbone=backbone,
        backbone_stride=14,
        transformer=transformer,
        maskmem_backbone=maskmem_backbone,
        num_maskmem=7,
        max_cond_frames_in_attn=4,
        # SAM mask-decoder outputs
        multimask_output_in_sam=True,
        multimask_output_for_tracking=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        sam_mask_decoder_extra_args=decoder_extra_args,
        # evaluation-time settings
        forward_backbone_per_frame_for_eval=True,
        trim_past_non_cond_mem_for_eval=False,
        offload_output_to_cpu_for_eval=False,
        # memory / output handling
        always_start_from_first_ann_frame=False,
        non_overlap_masks_for_mem_enc=False,
        non_overlap_masks_for_output=False,
        clear_non_cond_mem_around_input=True,
        fill_hole_area=0,
        use_memory_selection=apply_temporal_disambiguation,
    )
|
|
|
|
def _create_text_encoder(bpe_path: str) -> VETextEncoder:
    """Create the SAM3 text encoder and its BPE tokenizer.

    Args:
        bpe_path: Path to the BPE vocabulary used by ``SimpleTokenizer``.

    Returns:
        A ``VETextEncoder`` (d_model 256, width 1024, 16 heads, 24 layers).
    """
    encoder_dims = dict(d_model=256, width=1024, heads=16, layers=24)
    return VETextEncoder(tokenizer=SimpleTokenizer(bpe_path=bpe_path), **encoder_dims)
|
|
|
|
def _create_vision_backbone(
    compile_mode=None, enable_inst_interactivity=True
) -> Sam3DualViTDetNeck:
    """Create the SAM3 visual backbone: a ViT trunk wrapped in its neck.

    Args:
        compile_mode: Forwarded to the ViT factory.
        enable_inst_interactivity: Forwarded to the neck factory.

    Returns:
        The ``Sam3DualViTDetNeck`` wrapping the ViT trunk.
    """
    # Position encoding is created before the trunk, as in the original order.
    pos_enc = _create_position_encoding(precompute_resolution=1008)
    trunk = _create_vit_backbone(compile_mode=compile_mode)
    return _create_vit_neck(
        pos_enc, trunk, enable_inst_interactivity=enable_inst_interactivity
    )
|
|
|
|
def _create_sam3_transformer(has_presence_token: bool = True) -> TransformerWrapper:
    """Create the SAM3 detector transformer (fusion encoder + query decoder).

    Args:
        has_presence_token: Not forwarded to the decoder factory, which builds
            its decoder with a fixed presence-token setting. Passing ``False``
            emits a ``UserWarning`` so the mismatch is not silently ignored.

    Returns:
        A ``TransformerWrapper`` combining the encoder and decoder with
        ``d_model=256``.
    """
    if not has_presence_token:
        import warnings

        warnings.warn(
            "has_presence_token=False is ignored: the SAM3 transformer decoder "
            "is always built with a presence token.",
            stacklevel=2,
        )
    encoder: TransformerEncoderFusion = _create_transformer_encoder()
    decoder: TransformerDecoder = _create_transformer_decoder()
    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
|
|
|
|
def _load_checkpoint(model, checkpoint_path):
    """Load SAM3 image-model weights from ``checkpoint_path`` into ``model``.

    The checkpoint may be a raw state dict or a dict holding it under
    ``"model"``. Key handling:
    - Keys containing ``"detector"`` are kept, with ``"detector."`` removed.
    - When ``model.inst_interactive_predictor`` is not ``None``, keys
      containing ``"tracker"`` are also kept, with ``"tracker."`` rewritten
      to ``"inst_interactive_predictor.model."``.

    Loading is non-strict. Any missing or unexpected keys are printed.

    Args:
        model: Module to load into; must expose ``inst_interactive_predictor``.
        checkpoint_path: Path opened through ``g_pathmgr``.
    """
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location="cpu", weights_only=True)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]

    sam3_image_ckpt = {
        k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
    }
    if model.inst_interactive_predictor is not None:
        sam3_image_ckpt.update(
            {
                k.replace("tracker.", "inst_interactive_predictor.model."): v
                for k, v in ckpt.items()
                if "tracker" in k
            }
        )

    missing_keys, unexpected_keys = model.load_state_dict(
        sam3_image_ckpt, strict=False
    )
    # The message advertises both kinds of mismatch, so report (and trigger on)
    # both instead of silently dropping the unexpected keys.
    if missing_keys or unexpected_keys:
        print(
            f"loaded {checkpoint_path} and found "
            f"missing and/or unexpected keys:\n{missing_keys=}\n{unexpected_keys=}"
        )
|
|
|
|
def _setup_device_and_mode(model, device, eval_mode):
    """Move ``model`` to ``device`` and optionally switch it to eval mode.

    Args:
        model: ``torch.nn.Module`` to configure.
        device: Any target accepted by ``nn.Module.to``, for example ``"cpu"``,
            ``"cuda"``, ``"cuda:1"`` or a ``torch.device``. ``None`` leaves the
            model where it is.
        eval_mode: When truthy, call ``model.eval()``.

    Returns:
        The configured module (``nn.Module.to`` returns the module itself).
    """
    if device is not None:
        # ``.to`` understands every device spelling; comparing against the
        # literal "cuda" silently ignored indexed devices and other backends.
        model = model.to(device)
    if eval_mode:
        model.eval()
    return model
|
|
|
|
def build_sam3_image_model(
    bpe_path=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    eval_mode=True,
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
):
    """
    Build SAM3 image model

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary. Defaults to
            ``assets/bpe_simple_vocab_16e6.txt.gz`` one directory above this
            module.
        device: Device to load the model on (e.g. 'cuda' or 'cpu'). Defaults
            to 'cuda' when available, otherwise 'cpu'.
        eval_mode: Whether to set the model to evaluation mode. When false, a
            training matcher is also attached to the model.
        checkpoint_path: Optional path to model checkpoint
        load_from_HF: If true and ``checkpoint_path`` is ``None``, download
            the checkpoint with ``download_ckpt_from_hf``.
        enable_segmentation: Whether to enable segmentation head
        enable_inst_interactivity: Whether to enable instance interactivity (SAM 1 task)
        compile: If true, build the vision backbone and segmentation head with
            ``compile_mode="default"``. Otherwise they are built with
            ``compile_mode=None``.

    Returns:
        A SAM3 image model
    """
    if bpe_path is None:
        bpe_path = os.path.join(
            os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
        )

    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )
    text_encoder = _create_text_encoder(bpe_path)
    backbone = _create_vl_backbone(vision_encoder, text_encoder)
    transformer = _create_sam3_transformer()
    dot_prod_scoring = _create_dot_product_scoring()
    segmentation_head = (
        _create_segmentation_head(compile_mode=compile_mode)
        if enable_segmentation
        else None
    )
    input_geometry_encoder = _create_geometry_encoder()

    if enable_inst_interactivity:
        sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
    else:
        inst_predictor = None

    model = _create_sam3_model(
        backbone,
        transformer,
        input_geometry_encoder,
        segmentation_head,
        dot_prod_scoring,
        inst_predictor,
        eval_mode,
    )

    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf()
    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)

    model = _setup_device_and_mode(model, device, eval_mode)
    return model
|
|
|
|
def download_ckpt_from_hf():
    """Download the SAM3 checkpoint from the Hugging Face Hub.

    Fetches ``config.json`` and then ``sam3.pt`` from the ``facebook/sam3``
    repository using ``hf_hub_download``.

    Returns:
        The local path returned by ``hf_hub_download`` for ``sam3.pt``.
    """
    repo_id = "facebook/sam3"
    # config.json is downloaded, but its local path is not used here.
    hf_hub_download(repo_id=repo_id, filename="config.json")
    return hf_hub_download(repo_id=repo_id, filename="sam3.pt")
|
|
|
|
def build_sam3_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    bpe_path: Optional[str] = None,
    has_presence_token: bool = True,
    geo_encoder_use_img_cross_attn: bool = True,
    strict_state_dict_loading: bool = True,
    apply_temporal_disambiguation: bool = True,
    device="cuda" if torch.cuda.is_available() else "cpu",
    compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
    """
    Build SAM3 dense tracking model.

    Args:
        checkpoint_path: Optional path to checkpoint file. When ``None`` and
            ``load_from_HF`` is true, it is downloaded via ``download_ckpt_from_hf``.
        load_from_HF: Whether to download the checkpoint when no path is given.
        bpe_path: Path to the BPE tokenizer file. Defaults to
            ``assets/bpe_simple_vocab_16e6.txt.gz`` one directory above this
            module.
        has_presence_token: Passed to the transformer factory and used as the
            detector's ``supervise_joint_box_scores``.
        geo_encoder_use_img_cross_attn: Currently unused by this function.
        strict_state_dict_loading: ``strict`` flag for ``load_state_dict``.
        apply_temporal_disambiguation: Enables memory selection in the tracker
            and selects the hotstart / reconditioning settings below.
        device: Device the final model is moved to.
        compile: Forwarded to the inference model as ``compile_model``.

    Returns:
        Sam3VideoInferenceWithInstanceInteractivity: The instantiated dense tracking model
    """
    if bpe_path is None:
        bpe_path = os.path.join(
            os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
        )

    tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)

    visual_neck = _create_vision_backbone()
    text_encoder = _create_text_encoder(bpe_path)
    backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
    transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
    segmentation_head: UniversalSegmentationHead = _create_segmentation_head()
    input_geometry_encoder = _create_geometry_encoder()
    # Same MLP + DotProductScoring configuration as the shared factory.
    main_dot_prod_scoring = _create_dot_product_scoring()

    detector = Sam3ImageOnVideoMultiGPU(
        num_feature_levels=1,
        backbone=backbone,
        transformer=transformer,
        segmentation_head=segmentation_head,
        semantic_segmentation_head=None,
        input_geometry_encoder=input_geometry_encoder,
        use_early_fusion=True,
        use_dot_prod_scoring=True,
        dot_prod_scoring=main_dot_prod_scoring,
        supervise_joint_box_scores=has_presence_token,
    )

    # Settings shared by both temporal-disambiguation modes.
    inference_kwargs = dict(
        score_threshold_detection=0.5,
        assoc_iou_thresh=0.1,
        det_nms_thresh=0.1,
        new_det_thresh=0.7,
        suppress_unmatched_only_within_hotstart=True,
        min_trk_keep_alive=-1,
        max_trk_keep_alive=30,
        init_trk_keep_alive=30,
        suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
        suppress_det_close_to_boundary=False,
        fill_hole_area=16,
        masklet_confirmation_enable=False,
        decrease_trk_keep_alive_for_empty_masklets=False,
        image_size=1008,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        compile_model=compile,
    )
    if apply_temporal_disambiguation:
        inference_kwargs.update(
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            recondition_every_nth_frame=16,
        )
    else:
        # All set to 0 (presumably disabling hotstart and reconditioning).
        inference_kwargs.update(
            hotstart_delay=0,
            hotstart_unmatch_thresh=0,
            hotstart_dup_thresh=0,
            recondition_every_nth_frame=0,
        )
    model = Sam3VideoInferenceWithInstanceInteractivity(
        detector=detector, tracker=tracker, **inference_kwargs
    )

    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf()
    if checkpoint_path is not None:
        with g_pathmgr.open(checkpoint_path, "rb") as f:
            ckpt = torch.load(f, map_location="cpu", weights_only=True)
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            ckpt = ckpt["model"]
        missing_keys, unexpected_keys = model.load_state_dict(
            ckpt, strict=strict_state_dict_loading
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")

    model.to(device=device)
    return model
|
|
|
|
def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
    """Create a ``Sam3VideoPredictorMultiGPU``.

    Positional and keyword arguments are forwarded unchanged. ``gpus_to_use``
    (default ``None``) is passed explicitly as a keyword argument.
    """
    predictor = Sam3VideoPredictorMultiGPU(
        *model_args, gpus_to_use=gpus_to_use, **model_kwargs
    )
    return predictor
|
|