| | |
| | |
| |
|
| | |
| | |
| |
|
| | import math |
| | from typing import Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d |
| |
|
| |
|
class MaskDownSampler(nn.Module):
    """
    Progressively downsample a mask by ``total_stride``, ``stride`` at a time.

    Each stage is Conv2d(stride) -> LayerNorm2d -> activation. The first stage
    takes a single input channel. With each downsample (by a factor
    ``stride**2`` in area), channel capacity increases by the same factor.
    In the end, a 1x1 conv linearly projects to ``embed_dim`` channels.
    Note that LayerNorm is applied per *token*, like in ViT.

    Args:
        embed_dim: Output channels of the final 1x1 projection.
        kernel_size: Kernel size of each downsampling conv.
        stride: Downsampling factor of each stage; must be >= 2.
        padding: Padding of each downsampling conv.
        total_stride: Overall downsampling factor; must be an integer power of
            ``stride`` (``1`` means no downsampling stages, only the projection).
        activation: Activation class; instantiated once per stage.

    Raises:
        ValueError: If ``stride`` < 2, ``total_stride`` < 1, or ``total_stride``
            is not an integer power of ``stride``.
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        super().__init__()
        # Exact integer computation: a float log-ratio with floor division can
        # round down for non-power-of-two strides and reject valid configs.
        num_layers = self._num_downsample_layers(stride, total_stride)
        self.encoder = nn.Sequential()
        mask_in_chans, mask_out_chans = 1, 1
        for _ in range(num_layers):
            # Grow channels by stride**2 to compensate for the spatial reduction.
            mask_out_chans = mask_in_chans * (stride**2)
            self.encoder.append(
                nn.Conv2d(
                    mask_in_chans,
                    mask_out_chans,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            self.encoder.append(LayerNorm2d(mask_out_chans))
            self.encoder.append(activation())
            mask_in_chans = mask_out_chans

        # Final linear (1x1) projection to the requested embedding width.
        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))

    @staticmethod
    def _num_downsample_layers(stride, total_stride):
        """Return n such that ``stride**n == total_stride``, using integer math.

        Raises:
            ValueError: If ``stride`` < 2, ``total_stride`` < 1, or no such
                non-negative integer n exists.
        """
        if stride < 2:
            raise ValueError(f"stride must be >= 2, got {stride}")
        if total_stride < 1:
            raise ValueError(f"total_stride must be >= 1, got {total_stride}")
        num_layers, remaining = 0, total_stride
        # Terminates: remaining >= 1 and divisibility implies remaining >= stride.
        while remaining % stride == 0:
            remaining //= stride
            num_layers += 1
        if remaining != 1:
            raise ValueError(
                f"total_stride={total_stride} is not a power of stride={stride}"
            )
        return num_layers

    def forward(self, x):
        """Downsample ``x`` (expected to have 1 channel) through the encoder stack."""
        return self.encoder(x)
| |
|
| |
|
| | |
class CXBlock(nn.Module):
    r"""ConvNeXt-style residual block.

    Computes ``x + drop_path(gamma * pwconv2(act(pwconv1(norm(dwconv(x))))))``.
    ``dwconv`` and ``norm`` (LayerNorm2d) run in the (N, C, H, W) layout. The
    tensor is then permuted to (N, H, W, C) so the two pointwise layers can be
    ``nn.Linear`` (equivalent to 1x1 convs; the original implementation notes
    this is slightly faster in PyTorch). Finally it is permuted back before the
    residual add.

    Args:
        dim (int): Number of input (and output) channels.
        kernel_size (int): Kernel size of the spatial conv. Default: 7
        padding (int): Padding of the spatial conv. Must keep the spatial size
            unchanged for the residual add to line up. Default: 3
        drop_path (float): Stochastic depth rate; 0 disables it. Default: 0.0
        layer_scale_init_value (float): Initial value of the per-channel layer
            scale ``gamma``; a value <= 0 disables layer scale. Default: 1e-6.
        use_dwconv (bool): If True the spatial conv is depthwise
            (``groups=dim``), otherwise a dense conv. Default: True
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        super().__init__()
        # Submodules are created in a fixed order so parameter initialisation
        # (RNG draws) and state-dict key order stay stable.
        conv_groups = dim if use_dwconv else 1
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=padding, groups=conv_groups
        )
        self.norm = LayerNorm2d(dim, eps=1e-6)
        hidden_dim = 4 * dim  # inverted-bottleneck expansion ratio of 4
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        residual = x
        y = self.norm(self.dwconv(x))
        # (N, C, H, W) -> (N, H, W, C) so the Linear layers act on channels.
        y = y.permute(0, 2, 3, 1)
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            # Per-channel layer scale broadcasts over the trailing channel dim.
            y = self.gamma * y
        # (N, H, W, C) -> (N, C, H, W) to match the residual.
        y = y.permute(0, 3, 1, 2)
        return residual + self.drop_path(y)
| |
|
| |
|
class Fuser(nn.Module):
    """Optional 1x1 input projection followed by a stack of repeated layers.

    Args:
        layer: Module template; ``num_layers`` copies are built with
            ``get_clones`` (presumably independent deep copies — confirm).
        num_layers: Number of copies of ``layer`` applied in sequence.
        dim: Channel count for the input projection; required when
            ``input_projection`` is True.
        input_projection: If True, apply a ``dim -> dim`` 1x1 conv before the
            layers; otherwise the projection is an identity.
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        super().__init__()
        # Register "proj" before "layers" (as an identity placeholder) so the
        # module/state-dict order is the same whether or not it gets replaced.
        self.proj = nn.Identity()
        self.layers = get_clones(layer, num_layers)

        if not input_projection:
            return
        assert dim is not None
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        out = self.proj(x)
        for block in self.layers:
            out = block(out)
        return out
| |
|
| |
|
class MemoryEncoder(nn.Module):
    """Fuse pixel features with downsampled mask predictions into memory features.

    Pipeline: optional sigmoid on the masks -> ``mask_downsampler`` -> add to
    1x1-projected ``pix_feat`` -> ``fuser`` -> output projection (a 1x1 conv
    only when ``out_dim != in_dim``). ``position_encoding`` is then evaluated
    on the result.

    Args:
        out_dim: Channels of the returned features.
        mask_downsampler: Module applied to the masks. Its output must be
            addable to the projected pixel features (presumably
            (N, in_dim, H, W) — confirm against the configured downsampler).
        fuser: Module applied to the fused features.
        position_encoding: Callable mapping the output features to a
            positional-encoding tensor.
        in_dim: Channels of ``pix_feat``.
    """

    def __init__(
        self,
        out_dim,
        mask_downsampler,
        fuser,
        position_encoding,
        in_dim=256,
    ):
        super().__init__()

        self.mask_downsampler = mask_downsampler

        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.fuser = fuser
        self.position_encoding = position_encoding
        self.out_proj = nn.Identity()
        if out_dim != in_dim:
            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

    def forward(
        self,
        pix_feat: torch.Tensor,
        masks: torch.Tensor,
        skip_mask_sigmoid: bool = False,
    ) -> dict:
        """Encode ``pix_feat`` together with ``masks``.

        Args:
            pix_feat: Pixel features with ``in_dim`` channels. They are moved to
                the device of the downsampled masks before fusion.
            masks: Mask input for ``mask_downsampler``. Treated as logits unless
                ``skip_mask_sigmoid`` is True.
            skip_mask_sigmoid: If True, feed ``masks`` to the downsampler as-is.

        Returns:
            dict with ``"vision_features"`` (the fused, projected features) and
            ``"vision_pos_enc"`` (a one-element list holding the positional
            encoding, cast to the features' dtype).
        """
        # Inputs follow the dtype of this module's weights instead of a
        # hard-coded bfloat16, so both bf16-cast and fp32 models work.
        compute_dtype = self.pix_feat_proj.weight.dtype

        if not skip_mask_sigmoid:
            # Map mask logits into [0, 1].
            masks = torch.sigmoid(masks)
        masks = self.mask_downsampler(masks.to(dtype=compute_dtype))

        # pix_feat may live on another device (e.g. kept on CPU); align it with
        # the masks and cast to the compute dtype in one step.
        pix_feat = pix_feat.to(device=masks.device, dtype=compute_dtype)

        x = self.pix_feat_proj(pix_feat)
        x = x + masks
        x = self.fuser(x)
        x = self.out_proj(x)

        pos = self.position_encoding(x).to(x.dtype)

        return {"vision_features": x, "vision_pos_enc": [pos]}
| |
|