| |
|
|
| """Various utility models""" |
|
|
| import copy |
| import math |
| import weakref |
| from collections.abc import Iterator |
| from contextlib import AbstractContextManager |
| from enum import auto, Enum |
| from typing import Dict, List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import nn, Tensor |
| from typing_extensions import override |
|
|
|
|
def inverse_sigmoid(x, eps=1e-3):
    """
    Compute the logit (the inverse of the sigmoid) of ``x``.

    ``x`` is first clamped to ``[0, 1]``. The numerator ``x`` and the
    denominator ``1 - x`` are then each floored at ``eps`` so the logarithm
    never receives zero.

    Note: a very small ``eps`` may cause numerical issues in fp16.

    Args:
        x: tensor of values interpreted as probabilities.
        eps: lower bound applied to both ``x`` and ``1 - x``.

    Returns:
        Tensor holding ``log(x / (1 - x))`` computed on the clamped values.
    """
    prob = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(prob, min=eps)
    denominator = torch.clamp(1 - prob, min=eps)
    return torch.log(numerator / denominator)
|
|
|
|
class MultiheadAttentionWrapper(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` that always runs with ``need_weights=False``."""

    def forward(self, *args, **kwargs):
        """
        Forward to ``nn.MultiheadAttention.forward`` with ``need_weights`` forced off.

        Any ``need_weights`` keyword supplied by the caller is overridden.
        All other positional and keyword arguments are passed through as-is.
        """
        forced_kwargs = dict(kwargs, need_weights=False)
        return super().forward(*args, **forced_kwargs)
|
|
|
|
class DotProductScoring(torch.nn.Module):
    """
    Score hidden states against a mean-pooled prompt with a scaled dot product.

    Both inputs are linearly projected from ``d_model`` to ``d_proj``. Their
    dot product is multiplied by ``1 / sqrt(d_proj)`` and, optionally, clamped
    to ``[-clamp_max_val, clamp_max_val]``.
    """

    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        """
        Args:
            d_model: size of the last axis of both ``hs`` and the prompt.
            d_proj: size of the shared projection space.
            prompt_mlp: optional module applied to the prompt before pooling.
            clamp_logits: whether the scores are clamped in ``forward``.
            clamp_max_val: absolute clamp bound; only stored when
                ``clamp_logits`` is truthy.
        """
        super().__init__()
        self.d_proj = d_proj
        assert prompt_mlp is None or isinstance(prompt_mlp, torch.nn.Module)
        self.prompt_mlp = prompt_mlp
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        if clamp_logits:
            self.clamp_max_val = clamp_max_val

    def mean_pool_text(self, prompt, prompt_mask):
        """
        Average ``prompt`` over its first axis, skipping masked positions.

        Args:
            prompt: tensor whose first two axes match the transpose of
                ``prompt_mask`` (presumably (seq, batch, dim) — TODO confirm).
            prompt_mask: 2-D boolean tensor; ``True`` entries are excluded.

        Returns:
            The masked mean over axis 0. The divisor is floored at 1, so a
            fully masked column yields zeros rather than NaN.
        """
        # Invert the mask, then transpose it and add a trailing axis so it
        # broadcasts against ``prompt``.
        keep = (~prompt_mask).float().permute(1, 0).unsqueeze(-1)
        keep_count = keep.sum(dim=0).clamp(min=1.0)
        masked_total = (prompt * keep).sum(dim=0)
        return masked_total / keep_count

    def forward(self, hs, prompt, prompt_mask):
        """
        Compute one score per hidden-state vector.

        Args:
            hs: 4-D tensor whose last axis has size ``d_model``.
            prompt: 3-D tensor of prompt features.
            prompt_mask: 2-D boolean mask, ``True`` where the prompt is ignored.

        Returns:
            Tensor shaped like ``hs`` with the last axis replaced by size 1.
        """
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2

        # Optionally transform the prompt before pooling.
        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt)

        pooled = self.mean_pool_text(prompt, prompt_mask)

        # Project both sides into the shared space; the prompt becomes a
        # column vector so matmul performs the dot product.
        prompt_column = self.prompt_proj(pooled).unsqueeze(-1)
        scores = torch.matmul(self.hs_proj(hs), prompt_column)
        scores *= self.scale

        if self.clamp_logits:
            bound = self.clamp_max_val
            scores.clamp_(min=-bound, max=bound)

        return scores
|
|
|
|
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel scale ``gamma``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """
        Args:
            dim: length of the ``gamma`` vector.
            init_values: initial value(s) of ``gamma``; multiplied with a
                ones-vector of length ``dim``.
            inplace: if true, ``forward`` scales its input tensor in place.
        """
        super().__init__()
        self.inplace = inplace
        initial_scale = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma``, broadcast over the last axis."""
        if self.inplace:
            # mul_ mutates x and returns that same tensor.
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
class LayerNorm2d(nn.Module):
    """Layer normalization over axis 1 with a per-channel affine transform."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        """
        Args:
            num_channels: size of axis 1 of the expected input.
            eps: constant added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` along axis 1, then apply ``weight`` and ``bias``.

        ``weight`` and ``bias`` are reshaped to (C, 1, 1), so two trailing
        axes are expected after the channel axis.
        """
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        # Biased variance: the mean of squared deviations.
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        gain = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return gain * normalized + shift
|
|
|
|
class TransformerWrapper(nn.Module):
    """
    Container holding an encoder and a decoder.

    On construction, every multi-dimensional parameter is re-initialized
    with Xavier-uniform, except those whose name contains ``box_embed``,
    ``query_embed`` or ``reference_points``.
    """

    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",
        pos_enc_at_input_dec=True,
    ):
        """
        Args:
            encoder: encoder module, stored as ``self.encoder``.
            decoder: decoder module or ``None``; when given, its
                ``num_queries`` attribute is copied onto the wrapper.
            d_model: model width, stored as ``self.d_model``.
            two_stage_type: only ``"none"`` is accepted.
            pos_enc_at_input_dec: flag stored on the wrapper; not used here.
        """
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = None if decoder is None else decoder.num_queries
        self.pos_enc_at_input_dec = pos_enc_at_input_dec

        # "none" is the only two-stage mode this wrapper knows about.
        assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        self.two_stage_type = two_stage_type

        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        """Xavier-initialize all parameters with more than one dimension, minus the exempt ones."""
        exempt_tags = ("box_embed", "query_embed", "reference_points")
        for name, param in self.named_parameters():
            # Biases and other 1-D parameters keep their existing values.
            if param.dim() <= 1:
                continue
            if any(tag in name for tag in exempt_tags):
                continue
            nn.init.xavier_uniform_(param)
|
|
|
|
class MLP(nn.Module):
    """Stack of linear layers with ReLU in between (a simple FFN)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
        residual: bool = False,
        out_norm: Optional[nn.Module] = None,
    ):
        """
        Args:
            input_dim: width of the input features.
            hidden_dim: width of every intermediate layer.
            output_dim: width of the final layer.
            num_layers: number of linear layers.
            dropout: dropout probability after each hidden activation;
                values <= 0 disable dropout.
            residual: add the input to the output of the last layer.
            out_norm: optional module applied to the final result.

        Raises:
            ValueError: if ``residual`` is set and ``input_dim != output_dim``.
        """
        super().__init__()
        self.num_layers = num_layers

        hidden_widths = [hidden_dim] * (num_layers - 1)
        fan_ins = [input_dim] + hidden_widths
        fan_outs = hidden_widths + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(fan_ins, fan_outs)
        )

        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

        # A skip connection is only possible when the shapes line up.
        if residual and input_dim != output_dim:
            raise ValueError("residual is only supported if input_dim == output_dim")
        self.residual = residual

        assert out_norm is None or isinstance(out_norm, nn.Module)
        self.out_norm = out_norm if out_norm else nn.Identity()

    def forward(self, x):
        """Apply all layers, then the optional residual and output norm."""
        skip = x
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            # Every layer except the last is followed by ReLU and dropout.
            if index < last_index:
                x = self.drop(F.relu(x))
        if self.residual:
            x = x + skip
        return self.out_norm(x)
|
|
|
|
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
|
|
|
|
def get_clones_seq(module, N):
    """Return an ``nn.Sequential`` that chains ``N`` independent deep copies of ``module``."""
    stages = []
    for _ in range(N):
        stages.append(copy.deepcopy(module))
    return nn.Sequential(*stages)
|
|
|
|
def get_activation_fn(activation):
    """
    Return the functional activation matching ``activation``.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        ``F.relu``, ``F.gelu`` or ``F.glu`` respectively.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    # Compared with == (not a dict lookup) so unhashable inputs still end up
    # in the RuntimeError below.
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
|
|
|
def get_activation_module(activation):
    """
    Return the activation module class matching ``activation``.

    The class itself is returned, not an instance; the caller instantiates it.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        ``nn.ReLU``, ``nn.GELU`` or ``nn.GLU`` respectively.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    # Compared with == (not a dict lookup) so unhashable inputs still end up
    # in the RuntimeError below.
    supported = (("relu", nn.ReLU), ("gelu", nn.GELU), ("glu", nn.GLU))
    for name, module_cls in supported:
        if activation == name:
            return module_cls
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
|
|
|
def get_valid_ratio(mask):
    """
    Compute the fraction of unmasked width and height for each mask.

    Validity is read from the first column (for height) and the first row
    (for width) only, where ``False`` counts as valid.

    Args:
        mask: boolean tensor of shape (B, H, W).

    Returns:
        Float tensor of shape (B, 2) ordered as (width ratio, height ratio).
    """
    _, height, width = mask.shape
    unmasked = ~mask
    rows_valid = unmasked[:, :, 0].sum(1)
    cols_valid = unmasked[:, 0, :].sum(1)
    ratio_h = rows_valid.float() / height
    ratio_w = cols_valid.float() / width
    return torch.stack([ratio_w, ratio_h], -1)
|
|
|
|
def gen_sineembed_for_position(pos_tensor, num_feats=256):
    """
    Build sinusoidal embeddings for 2-D points or 4-D boxes.

    Each coordinate is multiplied by ``2 * pi``, divided by a bank of
    ``num_feats // 2`` frequencies, and encoded as interleaved sin/cos values.

    Args:
        pos_tensor: 3-D tensor whose last axis holds either (x, y) or
            (x, y, w, h).
        num_feats: even number; each coordinate yields ``num_feats // 2``
            features.

    Returns:
        Tensor with the per-coordinate embeddings concatenated on the last
        axis, ordered (y, x) for points and (y, x, w, h) for boxes.

    Raises:
        ValueError: if the last axis of ``pos_tensor`` is neither 2 nor 4.
    """
    assert num_feats % 2 == 0
    feats_per_coord = num_feats // 2
    two_pi = 2 * math.pi

    # Pairs of consecutive channels share one frequency.
    freqs = torch.arange(
        feats_per_coord, dtype=torch.float32, device=pos_tensor.device
    )
    freqs = 10000 ** (
        2 * (torch.div(freqs, 2, rounding_mode="floor")) / feats_per_coord
    )

    def encode(channel):
        # Even slots take the sine and odd slots the cosine; stack + flatten
        # interleaves them back into a single feature axis.
        angles = (pos_tensor[:, :, channel] * two_pi)[:, :, None] / freqs
        paired = torch.stack(
            (angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3
        )
        return paired.flatten(2)

    pos_x = encode(0)
    pos_y = encode(1)

    coord_count = pos_tensor.size(-1)
    if coord_count == 2:
        return torch.cat((pos_y, pos_x), dim=2)
    if coord_count == 4:
        pos_w = encode(2)
        pos_h = encode(3)
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    raise ValueError("Unknown pos_tensor shape(-1):{}".format(coord_count))
|
|
|
|
class SAM3Output(list):
    """
    A class representing the output of a SAM3 model.

    The output is a list of stages, where each stage is a list of steps. The
    container can be iterated, indexed and measured in three modes: all steps
    per stage, last step per stage, or flattened across stages.

    Note: the stages live in ``self.output``. The inherited ``list`` storage
    is never populated, so only the methods overridden here reflect the data.

    Attributes:
        output: The output of the SAM3 model, represented as a list of lists.
        iter_mode: The current iteration mode.
        loss_stages: Optional list of ints stored as given; not used by this class.

    Example:
        >>> sam3_output = SAM3Output([[1, 2], [3, 4], [5, 6]])
        >>> for step in sam3_output:
        ...     print(step)
        [1, 2]
        [3, 4]
        [5, 6]
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.LAST_STEP_PER_STAGE
        ... ) as sam3_last_step_out:
        ...     for step in sam3_last_step_out:
        ...         print(step)
        2
        4
        6
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.FLATTENED
        ... ) as sam3_flattened_out:
        ...     for step in sam3_flattened_out:
        ...         print(step)
        1
        2
        3
        4
        5
        6
    """

    class IterMode(Enum):
        """How a SAM3Output is traversed by iteration, indexing and ``len``."""

        # Yield every stage as its full list of steps.
        ALL_STEPS_PER_STAGE = auto()
        # Yield only the last step of every stage.
        LAST_STEP_PER_STAGE = auto()
        # Yield every step of every stage, in order.
        FLATTENED = auto()

    def __init__(
        self,
        output: Optional[List[List[Dict]]] = None,
        iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE,
        loss_stages: Optional[List[int]] = None,
    ):
        """
        Args:
            output: non-empty list of lists, or ``None`` to start empty.
            iter_mode: initial iteration mode.
            loss_stages: stored unchanged on the instance.
        """
        super().__init__()
        if output is not None:
            assert (
                isinstance(output, list)
                and len(output) > 0
                and isinstance(output[0], list)
            ), "Expected output to be a list of lists"
            self.output = output
        else:
            self.output = []
        assert isinstance(
            iter_mode, SAM3Output.IterMode
        ), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"

        self.iter_mode = iter_mode
        # The factories below are stored on the instance itself, so they hold
        # it through a weak reference (presumably to avoid a reference cycle).
        self_ref = weakref.ref(self)
        self._mode2iter = {
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output),
            SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: (
                inner_list[-1] for inner_list in self_ref().output
            ),
            SAM3Output.IterMode.FLATTENED: lambda: (
                element for inner_list in self_ref().output for element in inner_list
            ),
        }
        self.loss_stages = loss_stages

    @override
    def __iter__(self) -> Iterator:
        """Return an iterator matching the current ``iter_mode``."""
        return self._mode2iter[self.iter_mode]()

    def __getitem__(self, index):
        """
        Returns the item at the specified index, as seen in the current mode.

        Args:
            index (int): The index of the item to return. Slices are not supported.
        Returns:
            list or element: A whole stage (ALL_STEPS_PER_STAGE), the last step
            of a stage (LAST_STEP_PER_STAGE), or a single step (FLATTENED).
        """
        assert isinstance(index, int), f"index should be an integer. Got {type(index)}"
        if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE:
            return self.output[index]
        elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE:
            return self.output[index][-1]
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            if index == -1:
                # Fast path: the last flattened element is the last step of
                # the last stage, so no flattening is needed.
                return self.output[-1][-1]
            flattened_output = [step for stage in self.output for step in stage]
            return flattened_output[index]

    class _IterationMode(AbstractContextManager):
        """
        A context manager that temporarily changes the iteration mode of a SAM3Output object.
        This class is used internally by the SAM3Output.iteration_mode method.
        """

        def __init__(
            self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"
        ):
            self._model_output = model_output
            # Captured at construction so __exit__ can restore it.
            self._orig_iter_mode = model_output.iter_mode
            self._new_iter_mode = iter_mode

        @override
        def __enter__(self) -> "SAM3Output":
            self._model_output.iter_mode = self._new_iter_mode
            return self._model_output

        @override
        def __exit__(self, exc_type, exc_value, traceback):
            self._model_output.iter_mode = self._orig_iter_mode
            return super().__exit__(exc_type, exc_value, traceback)

    @staticmethod
    def iteration_mode(
        model_output: "SAM3Output", iter_mode: IterMode
    ) -> _IterationMode:
        """
        Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object.
        Args:
            model_output: The SAM3Output object.
            iter_mode: The new iteration mode.
        Returns:
            SAM3Output._IterationMode: A context manager that changes the iteration mode of the SAM3Output object.
        """
        return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)

    def append(self, item: list):
        """Append one stage (a list of steps) to ``self.output``."""
        assert isinstance(
            item, list
        ), f"Only list items are supported. Got {type(item)}"
        self.output.append(item)

    def __repr__(self):
        return self.output.__repr__()

    def __len__(self):
        """Number of stages, or the total number of steps in FLATTENED mode."""
        if self.iter_mode in [
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE,
            SAM3Output.IterMode.LAST_STEP_PER_STAGE,
        ]:
            return len(self.output)
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            # Add up stage lengths instead of building a flattened copy.
            return sum(len(stage) for stage in self.output)
|
|