| |
|
|
| |
|
|
| """Provides utility to combine a vision backbone with a language backbone.""" |
|
|
| from copy import copy |
| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.attention import sdpa_kernel, SDPBackend |
|
|
| from .act_ckpt_utils import activation_ckpt_wrapper |
| from .data_misc import NestedTensor |
| from .necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck |
|
|
|
|
| class SAM3VLBackbone(nn.Module): |
| """This backbone combines a vision backbone and a language backbone without fusion. |
| As such it is more of a convenience wrapper to handle the two backbones together. |
| |
| It adds support for activation checkpointing and compilation. |
| """ |
|
|
| def __init__( |
| self, |
| visual: Sam3DualViTDetNeck, |
| text, |
| compile_visual: bool = False, |
| act_ckpt_whole_vision_backbone: bool = False, |
| act_ckpt_whole_language_backbone: bool = False, |
| scalp=0, |
| ): |
| """Initialize the backbone combiner. |
| |
| :param visual: The vision backbone to use |
| :param text: The text encoder to use |
| """ |
| super().__init__() |
| self.vision_backbone: Sam3DualViTDetNeck = ( |
| torch.compile(visual) if compile_visual else visual |
| ) |
| self.language_backbone = text |
| self.scalp = scalp |
| |
| self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone |
| self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone |
|
|
| def forward( |
| self, |
| samples: torch.Tensor, |
| captions: List[str], |
| input_boxes: Optional[torch.Tensor] = None, |
| additional_text: Optional[List[str]] = None, |
| ): |
| """Forward pass of the backbone combiner. |
| |
| :param samples: The input images |
| :param captions: The input captions |
| :param input_boxes: If the text contains place-holders for boxes, this |
| parameter contains the tensor containing their spatial features |
| :param additional_text: This can be used to encode some additional text |
| (different from the captions) in the same forward of the backbone |
| :return: Output dictionary with the following keys: |
| - vision_features: The output of the vision backbone |
| - language_features: The output of the language backbone |
| - language_mask: The attention mask of the language backbone |
| - vision_pos_enc: The positional encoding of the vision backbone |
| - (optional) additional_text_features: The output of the language |
| backbone for the additional text |
| - (optional) additional_text_mask: The attention mask of the |
| language backbone for the additional text |
| """ |
| output = self.forward_image(samples) |
| device = output["vision_features"].device |
| output.update(self.forward_text(captions, input_boxes, additional_text, device)) |
| return output |
|
|
| def forward_image(self, samples: torch.Tensor): |
| return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)( |
| samples=samples, |
| act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training, |
| ) |
|
|
| def _forward_image_no_act_ckpt(self, samples): |
| |
| sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward( |
| samples |
| ) |
| if self.scalp > 0: |
| |
| sam3_features, sam3_pos = ( |
| sam3_features[: -self.scalp], |
| sam3_pos[: -self.scalp], |
| ) |
| if sam2_features is not None and sam2_pos is not None: |
| sam2_features, sam2_pos = ( |
| sam2_features[: -self.scalp], |
| sam2_pos[: -self.scalp], |
| ) |
|
|
| sam2_output = None |
|
|
| if sam2_features is not None and sam2_pos is not None: |
| sam2_src = sam2_features[-1] |
| sam2_output = { |
| "vision_features": sam2_src, |
| "vision_pos_enc": sam2_pos, |
| "backbone_fpn": sam2_features, |
| } |
|
|
| sam3_src = sam3_features[-1] |
| output = { |
| "vision_features": sam3_src, |
| "vision_pos_enc": sam3_pos, |
| "backbone_fpn": sam3_features, |
| "sam2_backbone_out": sam2_output, |
| } |
|
|
| return output |
|
|
| def forward_text( |
| self, captions, input_boxes=None, additional_text=None, device="cuda" |
| ): |
| return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)( |
| captions=captions, |
| input_boxes=input_boxes, |
| additional_text=additional_text, |
| device=device, |
| act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training, |
| ) |
|
|
| def _forward_text_no_ack_ckpt( |
| self, |
| captions, |
| input_boxes=None, |
| additional_text=None, |
| device="cuda", |
| ): |
| output = {} |
|
|
| |
| text_to_encode = copy(captions) |
| if additional_text is not None: |
| |
| |
| text_to_encode += additional_text |
|
|
| sdpa_context = sdpa_kernel( |
| [ |
| SDPBackend.MATH, |
| SDPBackend.EFFICIENT_ATTENTION, |
| SDPBackend.FLASH_ATTENTION, |
| ] |
| ) |
|
|
| with sdpa_context: |
| text_attention_mask, text_memory, text_embeds = self.language_backbone( |
| text_to_encode, input_boxes, device=device |
| ) |
|
|
| if additional_text is not None: |
| output["additional_text_features"] = text_memory[:, -len(additional_text) :] |
| output["additional_text_mask"] = text_attention_mask[ |
| -len(additional_text) : |
| ] |
|
|
| text_memory = text_memory[:, : len(captions)] |
| text_attention_mask = text_attention_mask[: len(captions)] |
| text_embeds = text_embeds[:, : len(captions)] |
| output["language_features"] = text_memory |
| output["language_mask"] = text_attention_mask |
| output["language_embeds"] = ( |
| text_embeds |
| ) |
|
|
| return output |
|
|
|
|
class SAM3VLBackboneTri(SAM3VLBackbone):
    """VL backbone with triple-head vision (sam3, interactive, propagation) + text encoder.

    The image forward returns the outputs of each head only when the matching
    ``need_*_out`` flag is set.
    """

    def __init__(
        self,
        visual,
        text,
        compile_visual=False,
        scalp=0,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
    ):
        """Initialize the tri-head VL backbone.

        :param visual: The tri-head vision neck; must be a ``Sam3TriViTDetNeck``
        :param text: The text encoder to use
        :param compile_visual: If True, the base class stores
            ``torch.compile(visual)``
        :param scalp: If > 0, number of trailing entries dropped from each
            head's feature and positional-encoding lists
        :param act_ckpt_whole_vision_backbone: Forwarded to the base class; read
            by ``forward_image`` as the activation-checkpoint switch (training only)
        :param act_ckpt_whole_language_backbone: Forwarded to the base class for
            the text forward
        """
        # Validate the raw module *before* it may be compiled: torch.compile
        # returns a wrapper module that is not an instance of the neck class,
        # so checking ``self.vision_backbone`` afterwards rejected every
        # compile_visual=True configuration.
        assert isinstance(visual, Sam3TriViTDetNeck), (
            f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(visual)}"
        )
        super().__init__(
            visual=visual,
            text=text,
            compile_visual=compile_visual,
            act_ckpt_whole_vision_backbone=act_ckpt_whole_vision_backbone,
            act_ckpt_whole_language_backbone=act_ckpt_whole_language_backbone,
            scalp=scalp,
        )

    def forward_image(
        self,
        samples,
        *,
        need_sam3_out: bool = True,
        need_interactive_out: bool = True,
        need_propagation_out: bool = True,
    ):
        """Run the tri-head vision neck, optionally under activation checkpointing."""
        return activation_ckpt_wrapper(self._forward_image_tri_no_act_ckpt)(
            samples=samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_tri_no_act_ckpt(
        self,
        samples,
        need_sam3_out=True,
        need_interactive_out=True,
        need_propagation_out=True,
    ):
        """Run the neck and package the requested heads.

        :return: dict with the sam3 keys at top level (``vision_features``,
            ``vision_mask``, ``vision_pos_enc``, ``backbone_fpn``), plus
            ``interactive`` and ``sam2_backbone_out`` sub-dicts for the other
            heads, each present only if requested.
        """
        (
            sam3_features,
            sam3_pos,
            interactive_features,
            interactive_pos,
            propagation_features,
            propagation_pos,
        ) = self.vision_backbone.forward(
            samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
        )
        sam3_features, sam3_pos = self._apply_scalp(sam3_features, sam3_pos)
        interactive_features, interactive_pos = self._apply_scalp(
            interactive_features, interactive_pos
        )
        propagation_features, propagation_pos = self._apply_scalp(
            propagation_features, propagation_pos
        )

        output = {}
        if need_sam3_out:
            output.update(self._head_output(sam3_features, sam3_pos))
        if need_interactive_out:
            output["interactive"] = self._head_output(
                interactive_features, interactive_pos
            )
        if need_propagation_out:
            output["sam2_backbone_out"] = self._head_output(
                propagation_features, propagation_pos
            )
        return output

    def _apply_scalp(self, features, pos):
        """Drop the last ``self.scalp`` entries of ``features`` and ``pos``.

        No-op when ``scalp <= 0``. ``None`` values are passed through instead of
        being sliced. NOTE(review): this assumes the neck may return None for a
        head that was not requested; confirm against Sam3TriViTDetNeck.
        """
        if self.scalp <= 0:
            return features, pos
        cut = -self.scalp
        return (
            None if features is None else features[:cut],
            None if pos is None else pos[:cut],
        )

    @staticmethod
    def _head_output(features, pos):
        """Build one head's output dict; the last entry must expose ``.tensors``/``.mask``."""
        last = features[-1]
        return {
            "vision_features": last.tensors,
            "vision_mask": last.mask,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
|
|
|
|
class VisionOnly(nn.Module):
    """Vision-only backbone with the same ``forward_image``/``forward_text`` API as the VL backbones.

    ``forward_text`` returns zero-length language tensors in place of a text
    encoder's output.
    """

    def __init__(
        self,
        visual,
        n_features,
        forward_in_chunk_for_eval=False,
        eval_chunk_size=4,
        eval_cast_to_cpu=False,
        scalp=0,
        compile_mode: Optional[str] = None,
        compile_extra_args: Optional[dict] = None,
    ):
        """Initialize the vision-only backbone.

        :param visual: Callable vision backbone returning ``(features, pos)``
        :param n_features: Feature width of the (empty) language features
        :param forward_in_chunk_for_eval: Stored as an attribute; not read in
            this class (NOTE(review): presumably consumed elsewhere)
        :param eval_chunk_size: Stored as an attribute; not read in this class
        :param eval_cast_to_cpu: Stored as an attribute; not read in this class
        :param scalp: If > 0, drop that many trailing feature / pos-enc
            entries. If < 0, remove the single entry at index ``scalp``.
        :param compile_mode: ``mode`` forwarded to ``torch.compile``
        :param compile_extra_args: Extra keyword arguments for ``torch.compile``
        """
        super().__init__()
        self.vision_backbone = visual
        # Compilation is requested if either knob is given, and is deferred to
        # the first forward_image call (see _compile).
        self.should_compile = compile_mode is not None or compile_extra_args is not None
        self.compile_mode = compile_mode
        self.compile_extra_args = compile_extra_args or {}
        self.compiled = False
        self.n_features = n_features
        self.forward_in_chunk_for_eval = forward_in_chunk_for_eval
        self.eval_chunk_size = eval_chunk_size
        self.eval_cast_to_cpu = eval_cast_to_cpu
        self.scalp = scalp

    def _compile(self):
        """Wrap the vision backbone with torch.compile once, if compilation was requested."""
        if self.should_compile and not self.compiled:
            self.vision_backbone = torch.compile(
                self.vision_backbone, mode=self.compile_mode, **self.compile_extra_args
            )
            self.compiled = True

    def forward_image(self, samples):
        """Run the vision backbone and package its last level plus the full pyramid.

        :return: dict with ``vision_features`` and ``vision_mask`` (from
            ``features[-1].decompose()``), ``vision_pos_enc`` and ``backbone_fpn``.
        """
        self._compile()

        features, pos = self.vision_backbone(samples)
        if self.scalp > 0:
            # Drop the last ``scalp`` entries.
            features, pos = features[: -self.scalp], pos[: -self.scalp]
        elif self.scalp < 0:
            # Remove exactly one entry at index ``scalp``, in place on the
            # lists returned by the backbone.
            features.pop(self.scalp)
            pos.pop(self.scalp)

        src, mask = features[-1].decompose()
        return {
            "vision_features": src,
            "vision_mask": mask,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }

    def forward_text(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        """Return empty language outputs for ``len(captions)`` captions.

        ``input_boxes`` and ``additional_text`` are accepted for API
        compatibility and ignored.

        :return: ``language_features`` of shape ``(0, bs, n_features)`` and
            ``language_mask`` of shape ``(bs, 0)``.
        """
        bs = len(captions)
        return {
            "language_features": torch.zeros((0, bs, self.n_features), device=device),
            "language_mask": torch.zeros((bs, 0), device=device),
        }
|
|
|
|
class TriHeadVisionOnly(VisionOnly):
    """Vision-only backbone for a tri-head neck (sam3, interactive, propagation)."""

    def __init__(
        self,
        visual,
        n_features,
        forward_in_chunk_for_eval=False,
        eval_chunk_size=4,
        eval_cast_to_cpu=False,
        scalp=0,
        compile_mode: Optional[str] = None,
        compile_extra_args: Optional[dict] = None,
    ):
        """Initialize the tri-head vision-only backbone.

        All arguments are forwarded unchanged to ``VisionOnly``. ``visual`` must
        be a ``Sam3TriViTDetNeck``.
        """
        # Validate the module that was passed in (before any compile wrapper
        # could be applied to it).
        assert isinstance(visual, Sam3TriViTDetNeck), (
            f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(visual)}"
        )
        super().__init__(
            visual=visual,
            n_features=n_features,
            forward_in_chunk_for_eval=forward_in_chunk_for_eval,
            eval_chunk_size=eval_chunk_size,
            eval_cast_to_cpu=eval_cast_to_cpu,
            scalp=scalp,
            compile_mode=compile_mode,
            compile_extra_args=compile_extra_args,
        )

    def forward_image(
        self,
        samples,
        *,
        need_sam3_out: bool = True,
        need_interactive_out: bool = True,
        need_propagation_out: bool = True,
    ):
        """Run the tri-head neck and package the requested heads.

        Only a positive ``scalp`` is applied here; a negative ``scalp`` is
        ignored.

        :return: dict with the sam3 keys at top level (``vision_features``,
            ``vision_mask``, ``vision_pos_enc``, ``backbone_fpn``), plus
            ``interactive`` and ``sam2_backbone_out`` sub-dicts for the other
            heads, each present only if requested.
        """
        self._compile()

        (
            sam3_features,
            sam3_pos,
            interactive_features,
            interactive_pos,
            propagation_features,
            propagation_pos,
        ) = self.vision_backbone(
            samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
        )
        sam3_features, sam3_pos = self._apply_scalp(sam3_features, sam3_pos)
        interactive_features, interactive_pos = self._apply_scalp(
            interactive_features, interactive_pos
        )
        propagation_features, propagation_pos = self._apply_scalp(
            propagation_features, propagation_pos
        )

        output = {}
        if need_sam3_out:
            output.update(self._head_output(sam3_features, sam3_pos))
        if need_interactive_out:
            output["interactive"] = self._head_output(
                interactive_features, interactive_pos
            )
        if need_propagation_out:
            output["sam2_backbone_out"] = self._head_output(
                propagation_features, propagation_pos
            )
        return output

    def _apply_scalp(self, features, pos):
        """Drop the last ``self.scalp`` entries of ``features`` and ``pos``.

        No-op when ``scalp <= 0``. ``None`` values are passed through instead of
        being sliced. NOTE(review): this assumes the neck may return None for a
        head that was not requested; confirm against Sam3TriViTDetNeck.
        """
        if self.scalp <= 0:
            return features, pos
        cut = -self.scalp
        return (
            None if features is None else features[:cut],
            None if pos is None else pos[:cut],
        )

    @staticmethod
    def _head_output(features, pos):
        """Build one head's output dict; the last entry must expose ``.tensors``/``.mask``."""
        last = features[-1]
        return {
            "vision_features": last.tensors,
            "vision_mask": last.mask,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
|
|