ColabWan / preprocessing /sam3 /model /vl_combiner.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
14.9 kB
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Provides utility to combine a vision backbone with a language backbone."""
from copy import copy
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend
from .act_ckpt_utils import activation_ckpt_wrapper
from .data_misc import NestedTensor
from .necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck
class SAM3VLBackbone(nn.Module):
    """Combine a vision backbone and a language backbone without fusion.

    This is a convenience wrapper that runs the two backbones and merges
    their outputs into a single dictionary. It can optionally compile the
    vision backbone and request activation checkpointing for either backbone.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp=0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        :param compile_visual: If True, wrap ``visual`` with ``torch.compile``
        :param act_ckpt_whole_vision_backbone: Request activation
            checkpointing around the whole vision forward (training only)
        :param act_ckpt_whole_language_backbone: Request activation
            checkpointing around the whole text forward (training only)
        :param scalp: When positive, number of trailing feature levels to
            drop from the vision backbone outputs
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = (
            torch.compile(visual) if compile_visual else visual
        )
        self.language_backbone = text
        self.scalp = scalp
        # allow running activation checkpointing on the entire vision and language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains place-holders for boxes, this
            parameter contains the tensor containing their spatial features
        :param additional_text: This can be used to encode some additional text
            (different from the captions) in the same forward of the backbone
        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - vision_pos_enc: The positional encoding of the vision backbone
            - backbone_fpn: All (remaining) feature levels of the vision backbone
            - sam2_backbone_out: Same three keys for the second head, or None
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - language_embeds: Text embeddings before the text encoder
            - (optional) additional_text_features: The output of the language
              backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the
              language backbone for the additional text
        """
        output = self.forward_image(samples)
        # The text branch is run on whatever device the vision features live on.
        device = output["vision_features"].device
        output.update(self.forward_text(captions, input_boxes, additional_text, device))
        return output

    def forward_image(self, samples: torch.Tensor):
        """Run the vision backbone, with checkpointing requested only in training."""
        return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
            samples=samples,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_no_act_ckpt(self, samples):
        """Run the vision backbone and pack its outputs into a dictionary."""
        # Forward through backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
            samples
        )
        if self.scalp > 0:
            # Drop the trailing `scalp` levels (the lowest-resolution ones,
            # per the original author's note).
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )

        # The second head is optional: it is only reported when the backbone
        # returned both its features and its positional encodings.
        sam2_output = None
        if sam2_features is not None and sam2_pos is not None:
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }

        sam3_src = sam3_features[-1]
        output = {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }
        return output

    def forward_text(
        self, captions, input_boxes=None, additional_text=None, device="cuda"
    ):
        """Run the text encoder, with checkpointing requested only in training."""
        return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)(
            captions=captions,
            input_boxes=input_boxes,
            additional_text=additional_text,
            device=device,
            act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
        )

    def _forward_text_no_ack_ckpt(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        """Encode the captions (and any additional text) in one pass.

        :param captions: The captions to encode
        :param input_boxes: Forwarded unchanged to the language backbone
        :param additional_text: Extra strings encoded after the captions in
            the same call; their outputs are split off under separate keys
        :param device: Forwarded to the language backbone
        :return: Dictionary with ``language_features``, ``language_mask`` and
            ``language_embeds``, plus ``additional_text_features`` and
            ``additional_text_mask`` when ``additional_text`` is not None
        """
        output = {}

        # Forward through text_encoder
        # Work on a copy so the caller's caption list is never extended.
        text_to_encode = copy(captions)
        if additional_text is not None:
            # if there are additional_text, we piggy-back them into this forward.
            # They'll be used later for output alignment
            text_to_encode += additional_text

        sdpa_context = sdpa_kernel(
            [
                SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
            ]
        )

        with sdpa_context:
            text_attention_mask, text_memory, text_embeds = self.language_backbone(
                text_to_encode, input_boxes, device=device
            )

        if additional_text is not None:
            num_additional = len(additional_text)
            # `x[-n:]` selects *everything* when n == 0 (because -0 == 0), so
            # an empty `additional_text` list must map to an empty slice
            # instead of silently returning the captions' outputs.
            tail = slice(-num_additional, None) if num_additional else slice(0, 0)
            output["additional_text_features"] = text_memory[:, tail]
            output["additional_text_mask"] = text_attention_mask[tail]

        # Keep only the leading entries that belong to the captions. The text
        # index is dim 1 for memory/embeds and dim 0 for the mask.
        text_memory = text_memory[:, : len(captions)]
        text_attention_mask = text_attention_mask[: len(captions)]
        text_embeds = text_embeds[:, : len(captions)]
        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        output["language_embeds"] = (
            text_embeds  # Text embeddings before forward to the encoder
        )
        return output
class SAM3VLBackboneTri(SAM3VLBackbone):
    """VL backbone with triple-head vision (sam3, interactive, propagation) + text encoder."""

    def __init__(self, visual, text, compile_visual=False, scalp=0):
        """Initialize the triple-head backbone combiner.

        :param visual: The vision backbone; must be a ``Sam3TriViTDetNeck``
        :param text: The text encoder to use
        :param compile_visual: If True, the parent wraps ``visual`` with
            ``torch.compile``
        :param scalp: When positive, number of trailing feature levels to
            drop from every head's outputs
        :raises AssertionError: if ``visual`` is not a ``Sam3TriViTDetNeck``
        """
        # Check the raw module rather than `self.vision_backbone`: with
        # compile_visual=True the parent stores the `torch.compile` wrapper,
        # which is not an instance of Sam3TriViTDetNeck and would always
        # trip the assertion.
        assert isinstance(visual, Sam3TriViTDetNeck), (
            f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(visual)}"
        )
        super().__init__(
            visual=visual, text=text, compile_visual=compile_visual, scalp=scalp
        )

    def forward_image(
        self,
        samples,
        *,
        need_sam3_out: bool = True,
        need_interactive_out: bool = True,
        need_propagation_out: bool = True,
    ):
        """Run the vision backbone, with checkpointing requested only in training.

        :param samples: The input images
        :param need_sam3_out: Include the sam3 head in the output
        :param need_interactive_out: Include the interactive head in the output
        :param need_propagation_out: Include the propagation head in the output
        :return: The dictionary built by ``_forward_image_tri_no_act_ckpt``
        """
        return activation_ckpt_wrapper(self._forward_image_tri_no_act_ckpt)(
            samples=samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _drop_trailing_levels(self, features, pos):
        """Trim the last ``self.scalp`` levels from one head's outputs.

        A head whose outputs are None is passed through untouched.
        NOTE(review): assumes the neck may return None for a head that was
        not requested -- confirm against Sam3TriViTDetNeck.
        """
        if self.scalp > 0 and features is not None and pos is not None:
            return features[: -self.scalp], pos[: -self.scalp]
        return features, pos

    @staticmethod
    def _head_output(features, pos):
        """Pack one head's feature levels and positional encodings."""
        last_level = features[-1]
        return {
            "vision_features": last_level.tensors,
            "vision_mask": last_level.mask,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }

    def _forward_image_tri_no_act_ckpt(
        self,
        samples,
        need_sam3_out=True,
        need_interactive_out=True,
        need_propagation_out=True,
    ):
        """Run the triple-head vision backbone and pack the requested heads.

        :return: Dictionary where the sam3 head's keys sit at the top level,
            the interactive head is nested under ``"interactive"`` and the
            propagation head under ``"sam2_backbone_out"``. Only the heads
            whose ``need_*`` flag is True are present.
        """
        (
            sam3_features,
            sam3_pos,
            interactive_features,
            interactive_pos,
            propagation_features,
            propagation_pos,
        ) = self.vision_backbone.forward(
            samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
        )

        sam3_features, sam3_pos = self._drop_trailing_levels(sam3_features, sam3_pos)
        interactive_features, interactive_pos = self._drop_trailing_levels(
            interactive_features, interactive_pos
        )
        propagation_features, propagation_pos = self._drop_trailing_levels(
            propagation_features, propagation_pos
        )

        output = {}
        if need_sam3_out:
            output.update(self._head_output(sam3_features, sam3_pos))
        if need_interactive_out:
            output["interactive"] = self._head_output(
                interactive_features, interactive_pos
            )
        if need_propagation_out:
            output["sam2_backbone_out"] = self._head_output(
                propagation_features, propagation_pos
            )
        return output
class VisionOnly(nn.Module):
    """Wrap a vision backbone and expose image/text entry points.

    ``forward_image`` runs the wrapped backbone; ``forward_text`` does not
    encode anything and only returns zero-sized placeholder tensors.
    """

    def __init__(
        self,
        visual,
        n_features,
        forward_in_chunk_for_eval=False,
        eval_chunk_size=4,
        eval_cast_to_cpu=False,
        scalp=0,
        compile_mode: str = None,
        compile_extra_args: Optional[dict] = None,
    ):
        """Store the backbone and its configuration.

        :param visual: The vision backbone to wrap
        :param n_features: Size of the last dim of the placeholder
            ``language_features`` tensor
        :param forward_in_chunk_for_eval: Stored on the instance
        :param eval_chunk_size: Stored on the instance
        :param eval_cast_to_cpu: Stored on the instance
        :param scalp: Positive: drop that many trailing levels by slicing;
            negative: pop the level at that index
        :param compile_mode: ``mode`` argument for ``torch.compile``
        :param compile_extra_args: Extra keyword arguments for ``torch.compile``
        """
        super().__init__()
        self.vision_backbone = visual
        self.n_features = n_features
        self.scalp = scalp

        # Evaluation-time settings; this class only records them.
        self.forward_in_chunk_for_eval = forward_in_chunk_for_eval
        self.eval_chunk_size = eval_chunk_size
        self.eval_cast_to_cpu = eval_cast_to_cpu

        # Compilation is requested as soon as either compile option is given,
        # but it is deferred until the first `forward_image` call.
        nothing_requested = compile_mode is None and compile_extra_args is None
        self.should_compile = not nothing_requested
        self.compile_mode = compile_mode
        self.compile_extra_args = compile_extra_args if compile_extra_args else {}
        self.compiled = False

    def _compile(self):
        """Compile the vision backbone once, if compilation was requested."""
        if not self.should_compile or self.compiled:
            return
        self.vision_backbone = torch.compile(
            self.vision_backbone, mode=self.compile_mode, **self.compile_extra_args
        )
        self.compiled = True

    def forward_image(self, samples):
        """Run the vision backbone and pack its outputs into a dictionary.

        :param samples: Input passed straight to the vision backbone
        :return: Dictionary with ``vision_features`` and ``vision_mask`` (the
            decomposed last level), ``vision_pos_enc`` and ``backbone_fpn``
        """
        self._compile()

        # Forward through backbone
        levels, pos_encodings = self.vision_backbone(samples)

        n_drop = self.scalp
        if n_drop > 0:
            # Positive: keep everything except the last `n_drop` levels.
            levels = levels[:-n_drop]
            pos_encodings = pos_encodings[:-n_drop]
        elif n_drop < 0:
            # Negative: remove the single level at that index, in place.
            levels.pop(n_drop)
            pos_encodings.pop(n_drop)

        src, mask = levels[-1].decompose()
        return {
            "vision_features": src,
            "vision_mask": mask,
            "vision_pos_enc": pos_encodings,
            "backbone_fpn": levels,
        }

    def forward_text(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        """Return zero-sized language outputs sized by the number of captions.

        ``input_boxes`` and ``additional_text`` are accepted but not used.

        :return: ``language_features`` of shape ``(0, len(captions), n_features)``
            and ``language_mask`` of shape ``(len(captions), 0)``
        """
        batch = len(captions)
        empty_features = torch.zeros((0, batch, self.n_features), device=device)
        empty_mask = torch.zeros((batch, 0), device=device)
        return {
            "language_features": empty_features,
            "language_mask": empty_mask,
        }
class TriHeadVisionOnly(VisionOnly):
    """Vision-only wrapper for a triple-head (sam3, interactive, propagation) backbone."""

    def __init__(
        self,
        visual,
        n_features,
        forward_in_chunk_for_eval=False,
        eval_chunk_size=4,
        eval_cast_to_cpu=False,
        scalp=0,
        compile_mode: Optional[str] = None,
        compile_extra_args: Optional[dict] = None,
    ):
        """Forward every argument to ``VisionOnly`` and validate the backbone type.

        :param visual: The vision backbone; must be a ``Sam3TriViTDetNeck``
        :param scalp: When positive, number of trailing feature levels to
            drop from every head's outputs
        :raises AssertionError: if ``visual`` is not a ``Sam3TriViTDetNeck``

        The remaining parameters are passed through to the parent unchanged.
        """
        super().__init__(
            visual=visual,
            n_features=n_features,
            forward_in_chunk_for_eval=forward_in_chunk_for_eval,
            eval_chunk_size=eval_chunk_size,
            eval_cast_to_cpu=eval_cast_to_cpu,
            scalp=scalp,
            compile_mode=compile_mode,
            compile_extra_args=compile_extra_args,
        )
        # Compilation is deferred to `_compile()`, so at this point the
        # attribute still holds the raw module and the type check is valid.
        assert isinstance(self.vision_backbone, Sam3TriViTDetNeck), (
            f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(self.vision_backbone)}"
        )

    def _drop_trailing_levels(self, features, pos):
        """Trim the last ``self.scalp`` levels from one head's outputs.

        A head whose outputs are None is passed through untouched.
        NOTE(review): assumes the neck may return None for a head that was
        not requested -- confirm against Sam3TriViTDetNeck.
        """
        if self.scalp > 0 and features is not None and pos is not None:
            return features[: -self.scalp], pos[: -self.scalp]
        return features, pos

    @staticmethod
    def _head_output(features, pos):
        """Pack one head's feature levels and positional encodings."""
        last_level = features[-1]
        return {
            "vision_features": last_level.tensors,
            "vision_mask": last_level.mask,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }

    def forward_image(
        self,
        samples,
        *,
        need_sam3_out: bool = True,
        need_interactive_out: bool = True,
        need_propagation_out: bool = True,
    ):
        """Run the triple-head vision backbone and pack the requested heads.

        :param samples: Input passed straight to the vision backbone
        :param need_sam3_out: Include the sam3 head in the output
        :param need_interactive_out: Include the interactive head in the output
        :param need_propagation_out: Include the propagation head in the output
        :return: Dictionary where the sam3 head's keys sit at the top level,
            the interactive head is nested under ``"interactive"`` and the
            propagation head under ``"sam2_backbone_out"``. Only the heads
            whose ``need_*`` flag is True are present.
        """
        self._compile()

        # Forward through backbone
        (
            sam3_features,
            sam3_pos,
            interactive_features,
            interactive_pos,
            propagation_features,
            propagation_pos,
        ) = self.vision_backbone(
            samples,
            need_sam3_out=need_sam3_out,
            need_interactive_out=need_interactive_out,
            need_propagation_out=need_propagation_out,
        )

        sam3_features, sam3_pos = self._drop_trailing_levels(sam3_features, sam3_pos)
        interactive_features, interactive_pos = self._drop_trailing_levels(
            interactive_features, interactive_pos
        )
        propagation_features, propagation_pos = self._drop_trailing_levels(
            propagation_features, propagation_pos
        )

        output = {}
        if need_sam3_out:
            output.update(self._head_output(sam3_features, sam3_pos))
        if need_interactive_out:
            output["interactive"] = self._head_output(
                interactive_features, interactive_pos
            )
        if need_propagation_out:
            output["sam2_backbone_out"] = self._head_output(
                propagation_features, propagation_pos
            )
        return output