# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved # pyre-unsafe """Provides utility to combine a vision backbone with a language backbone.""" from copy import copy from typing import List, Optional import torch import torch.nn as nn from torch.nn.attention import sdpa_kernel, SDPBackend from .act_ckpt_utils import activation_ckpt_wrapper from .data_misc import NestedTensor from .necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck class SAM3VLBackbone(nn.Module): """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a convenience wrapper to handle the two backbones together. It adds support for activation checkpointing and compilation. """ def __init__( self, visual: Sam3DualViTDetNeck, text, compile_visual: bool = False, act_ckpt_whole_vision_backbone: bool = False, act_ckpt_whole_language_backbone: bool = False, scalp=0, ): """Initialize the backbone combiner. :param visual: The vision backbone to use :param text: The text encoder to use """ super().__init__() self.vision_backbone: Sam3DualViTDetNeck = ( torch.compile(visual) if compile_visual else visual ) self.language_backbone = text self.scalp = scalp # allow running activation checkpointing on the entire vision and language backbones self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone def forward( self, samples: torch.Tensor, captions: List[str], input_boxes: Optional[torch.Tensor] = None, additional_text: Optional[List[str]] = None, ): """Forward pass of the backbone combiner. :param samples: The input images :param captions: The input captions :param input_boxes: If the text contains place-holders for boxes, this parameter contains the tensor containing their spatial features :param additional_text: This can be used to encode some additional text (different from the captions) in the same forward of the backbone :return: Output dictionary with the following keys: - vision_features: The output of the vision backbone - language_features: The output of the language backbone - language_mask: The attention mask of the language backbone - vision_pos_enc: The positional encoding of the vision backbone - (optional) additional_text_features: The output of the language backbone for the additional text - (optional) additional_text_mask: The attention mask of the language backbone for the additional text """ output = self.forward_image(samples) device = output["vision_features"].device output.update(self.forward_text(captions, input_boxes, additional_text, device)) return output def forward_image(self, samples: torch.Tensor): return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)( samples=samples, act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training, ) def _forward_image_no_act_ckpt(self, samples): # Forward through backbone sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward( samples ) if self.scalp > 0: # Discard the lowest resolution features sam3_features, sam3_pos = ( sam3_features[: -self.scalp], sam3_pos[: -self.scalp], ) if sam2_features is not None and sam2_pos is not None: sam2_features, sam2_pos = ( sam2_features[: -self.scalp], sam2_pos[: -self.scalp], ) sam2_output = None if sam2_features is not None and sam2_pos is not None: sam2_src = sam2_features[-1] sam2_output = { "vision_features": sam2_src, "vision_pos_enc": sam2_pos, "backbone_fpn": sam2_features, } sam3_src = sam3_features[-1] output = { "vision_features": sam3_src, "vision_pos_enc": sam3_pos, "backbone_fpn": sam3_features, "sam2_backbone_out": sam2_output, } return output def forward_text( self, captions, input_boxes=None, additional_text=None, device="cuda" ): return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)( captions=captions, input_boxes=input_boxes, additional_text=additional_text, device=device, act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training, ) def _forward_text_no_ack_ckpt( self, captions, input_boxes=None, additional_text=None, device="cuda", ): output = {} # Forward through text_encoder text_to_encode = copy(captions) if additional_text is not None: # if there are additional_text, we piggy-back them into this forward. # They'll be used later for output alignment text_to_encode += additional_text sdpa_context = sdpa_kernel( [ SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, ] ) with sdpa_context: text_attention_mask, text_memory, text_embeds = self.language_backbone( text_to_encode, input_boxes, device=device ) if additional_text is not None: output["additional_text_features"] = text_memory[:, -len(additional_text) :] output["additional_text_mask"] = text_attention_mask[ -len(additional_text) : ] text_memory = text_memory[:, : len(captions)] text_attention_mask = text_attention_mask[: len(captions)] text_embeds = text_embeds[:, : len(captions)] output["language_features"] = text_memory output["language_mask"] = text_attention_mask output["language_embeds"] = ( text_embeds # Text embeddings before forward to the encoder ) return output class SAM3VLBackboneTri(SAM3VLBackbone): """VL backbone with triple-head vision (sam3, interactive, propagation) + text encoder.""" def __init__(self, visual, text, compile_visual=False, scalp=0): super().__init__( visual=visual, text=text, compile_visual=compile_visual, scalp=scalp ) assert isinstance(self.vision_backbone, Sam3TriViTDetNeck), ( f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(self.vision_backbone)}" ) def forward_image( self, samples, *, need_sam3_out: bool = True, need_interactive_out: bool = True, need_propagation_out: bool = True, ): return activation_ckpt_wrapper(self._forward_image_tri_no_act_ckpt)( samples=samples, need_sam3_out=need_sam3_out, need_interactive_out=need_interactive_out, need_propagation_out=need_propagation_out, act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training, ) def _forward_image_tri_no_act_ckpt( self, samples, need_sam3_out=True, need_interactive_out=True, need_propagation_out=True, ): ( sam3_features, sam3_pos, interactive_features, interactive_pos, propagation_features, propagation_pos, ) = self.vision_backbone.forward( samples, need_sam3_out=need_sam3_out, need_interactive_out=need_interactive_out, need_propagation_out=need_propagation_out, ) if self.scalp > 0: sam3_features, sam3_pos = ( sam3_features[: -self.scalp], sam3_pos[: -self.scalp], ) interactive_features, interactive_pos = ( interactive_features[: -self.scalp], interactive_pos[: -self.scalp], ) propagation_features, propagation_pos = ( propagation_features[: -self.scalp], propagation_pos[: -self.scalp], ) output = {} if need_sam3_out: sam3_last = sam3_features[-1] output.update( { "vision_features": sam3_last.tensors, "vision_mask": sam3_last.mask, "vision_pos_enc": sam3_pos, "backbone_fpn": sam3_features, } ) if need_interactive_out: inte_last = interactive_features[-1] output["interactive"] = { "vision_features": inte_last.tensors, "vision_mask": inte_last.mask, "vision_pos_enc": interactive_pos, "backbone_fpn": interactive_features, } if need_propagation_out: prop_last = propagation_features[-1] output["sam2_backbone_out"] = { "vision_features": prop_last.tensors, "vision_mask": prop_last.mask, "vision_pos_enc": propagation_pos, "backbone_fpn": propagation_features, } return output class VisionOnly(nn.Module): def __init__( self, visual, n_features, forward_in_chunk_for_eval=False, eval_chunk_size=4, eval_cast_to_cpu=False, scalp=0, compile_mode: str = None, compile_extra_args: Optional[dict] = None, ): super().__init__() self.vision_backbone = visual self.should_compile = compile_mode is not None or compile_extra_args is not None self.compile_mode = compile_mode self.compile_extra_args = compile_extra_args or {} self.compiled = False self.n_features = n_features self.forward_in_chunk_for_eval = forward_in_chunk_for_eval self.eval_chunk_size = eval_chunk_size self.eval_cast_to_cpu = eval_cast_to_cpu self.scalp = scalp def _compile(self): if self.should_compile and not self.compiled: self.vision_backbone = torch.compile( self.vision_backbone, mode=self.compile_mode, **self.compile_extra_args ) self.compiled = True def forward_image(self, samples): self._compile() # Forward through backbone features, pos = self.vision_backbone(samples) if self.scalp > 0: features, pos = features[: -self.scalp], pos[: -self.scalp] elif self.scalp < 0: features.pop(self.scalp) pos.pop(self.scalp) src, mask = features[-1].decompose() output = { "vision_features": src, "vision_mask": mask, "vision_pos_enc": pos, "backbone_fpn": features, } return output def forward_text( self, captions, input_boxes=None, additional_text=None, device="cuda", ): bs = len(captions) output = { "language_features": torch.zeros((0, bs, self.n_features), device=device), "language_mask": torch.zeros((bs, 0), device=device), } return output class TriHeadVisionOnly(VisionOnly): def __init__( self, visual, n_features, forward_in_chunk_for_eval=False, eval_chunk_size=4, eval_cast_to_cpu=False, scalp=0, compile_mode: str = None, compile_extra_args: Optional[dict] = None, ): super().__init__( visual=visual, n_features=n_features, forward_in_chunk_for_eval=forward_in_chunk_for_eval, eval_chunk_size=eval_chunk_size, eval_cast_to_cpu=eval_cast_to_cpu, scalp=scalp, compile_mode=compile_mode, compile_extra_args=compile_extra_args, ) assert isinstance(self.vision_backbone, Sam3TriViTDetNeck), ( f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(self.vision_backbone)}" ) def forward_image( self, samples, *, need_sam3_out: bool = True, need_interactive_out: bool = True, need_propagation_out: bool = True, ): self._compile() # Forward through backbone ( sam3_features, sam3_pos, interactive_features, interactive_pos, propagation_features, propagation_pos, ) = self.vision_backbone( samples, need_sam3_out=need_sam3_out, need_interactive_out=need_interactive_out, need_propagation_out=need_propagation_out, ) if self.scalp > 0: sam3_features, sam3_pos = ( sam3_features[: -self.scalp], sam3_pos[: -self.scalp], ) interactive_features, interactive_pos = ( interactive_features[: -self.scalp], interactive_pos[: -self.scalp], ) propagation_features, propagation_pos = ( propagation_features[: -self.scalp], propagation_pos[: -self.scalp], ) output = {} if need_sam3_out: sam3_last = sam3_features[-1] output.update( { "vision_features": sam3_last.tensors, "vision_mask": sam3_last.mask, "vision_pos_enc": sam3_pos, "backbone_fpn": sam3_features, } ) if need_interactive_out: inte_last = interactive_features[-1] output["interactive"] = { "vision_features": inte_last.tensors, "vision_mask": inte_last.mask, "vision_pos_enc": interactive_pos, "backbone_fpn": interactive_features, } if need_propagation_out: prop_last = propagation_features[-1] output["sam2_backbone_out"] = { "vision_features": prop_last.tensors, "vision_mask": prop_last.mask, "vision_pos_enc": propagation_pos, "backbone_fpn": propagation_features, } return output