|
|
|
|
|
|
|
|
"""Various utility models""" |
|
|
|
|
|
import copy |
|
|
import math |
|
|
import weakref |
|
|
from collections.abc import Iterator |
|
|
from contextlib import AbstractContextManager |
|
|
from enum import auto, Enum |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn, Tensor |
|
|
from typing_extensions import override |
|
|
|
|
|
|
|
|
def inverse_sigmoid(x, eps=1e-3):
    """
    Compute the logit (inverse of the sigmoid function) of ``x``.

    Input values are clamped to [0, 1] first, and both the numerator and the
    denominator are bounded below by ``eps`` to keep the log finite.
    Note: may face numerical issues with fp16 and a small eps.
    """
    clamped = x.clamp(min=0, max=1)
    numerator = clamped.clamp(min=eps)
    denominator = (1 - clamped).clamp(min=eps)
    return torch.log(numerator / denominator)
|
|
|
|
|
|
|
|
class MultiheadAttentionWrapper(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` that never returns attention weights.

    Every call is forced to ``need_weights=False``, so the second element of
    the returned tuple is always ``None``.
    """

    def forward(self, *args, **kwargs):
        # Override any caller-provided value: weights are never computed.
        kwargs.update(need_weights=False)
        return super().forward(*args, **kwargs)
|
|
|
|
|
|
|
|
class DotProductScoring(torch.nn.Module):
    """Score decoder queries against a mean-pooled text prompt.

    The prompt tokens are mean-pooled over their valid (unmasked) positions,
    then both the pooled prompt and the hidden states are linearly projected
    to ``d_proj``. Their dot product, scaled by ``1/sqrt(d_proj)``, gives one
    logit per query, optionally clamped to ``[-clamp_max_val, clamp_max_val]``.
    """

    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        super().__init__()
        assert prompt_mlp is None or isinstance(prompt_mlp, torch.nn.Module)
        self.d_proj = d_proj
        # Optional extra transform applied to the prompt tokens before pooling.
        self.prompt_mlp = prompt_mlp
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        if self.clamp_logits:
            self.clamp_max_val = clamp_max_val

    def mean_pool_text(self, prompt, prompt_mask):
        # prompt: (seq, batch, d); prompt_mask: (batch, seq), True = padding.
        valid = (~prompt_mask).float().permute(1, 0)[..., None]
        # Guard against all-padded prompts to avoid dividing by zero.
        denom = torch.clamp(valid.sum(dim=0), min=1.0)
        return (prompt * valid).sum(dim=0) / denom

    def forward(self, hs, prompt, prompt_mask):
        # hs: (..., batch, n_query, d_model) 4-D stack of hidden states;
        # prompt: (seq, batch, d_model); prompt_mask: (batch, seq).
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2

        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt)

        pooled = self.mean_pool_text(prompt, prompt_mask)
        projected_prompt = self.prompt_proj(pooled)
        projected_hs = self.hs_proj(hs)

        logits = torch.matmul(projected_hs, projected_prompt.unsqueeze(-1))
        logits = logits * self.scale

        if self.clamp_logits:
            logits = logits.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)

        return logits
|
|
|
|
|
|
|
|
class LayerScale(nn.Module):
    """Elementwise scaling by a learnable per-channel parameter ``gamma``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts at `init_values` for each of the `dim` channels.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
|
|
|
|
class LayerNorm2d(nn.Module):
    """Layer normalization over the channel dimension of NCHW tensors."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel affine parameters.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each spatial position across channels (dim 1).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
|
|
|
|
|
|
|
class TransformerWrapper(nn.Module):
    """Thin container around an encoder/decoder pair.

    Stores the two sub-modules plus shared configuration, and applies Xavier
    initialization to most multi-dimensional parameters at construction.
    """

    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",
        pos_enc_at_input_dec=True,
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = None if decoder is None else decoder.num_queries
        self.pos_enc_at_input_dec = pos_enc_at_input_dec

        # Only the "none" two-stage variant is supported here.
        assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        self.two_stage_type = two_stage_type

        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        # Xavier-init all >1-dim parameters, except embeddings and reference
        # points, which keep whatever initialization their module gave them.
        skip_keys = ("box_embed", "query_embed", "reference_points")
        for name, param in self.named_parameters():
            if param.dim() > 1 and not any(key in name for key in skip_keys):
                nn.init.xavier_uniform_(param)
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    ``num_layers`` linear layers with ReLU + dropout between them (no
    activation after the last layer), an optional residual connection from
    the input, and an optional output normalization module.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
        residual: bool = False,
        out_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.num_layers = num_layers
        # Layer sizes: input -> hidden x (num_layers - 1) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        if residual and input_dim != output_dim:
            raise ValueError("residual is only supported if input_dim == output_dim")
        self.residual = residual

        assert out_norm is None or isinstance(out_norm, nn.Module)
        self.out_norm = nn.Identity() if out_norm is None else out_norm

    def forward(self, x):
        inputs = x
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = self.drop(F.relu(x))
        if self.residual:
            x = x + inputs
        return self.out_norm(x)
|
|
|
|
|
|
|
|
def get_clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of `module`."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
|
|
|
|
|
|
|
|
def get_clones_seq(module, N):
    """Return an nn.Sequential of N independent deep copies of `module`."""
    return nn.Sequential(*(copy.deepcopy(module) for _ in range(N)))
|
|
|
|
|
|
|
|
def get_activation_fn(activation):
    """Return an activation function (functional form) given its name.

    Args:
        activation: one of "relu", "gelu" or "glu".

    Returns:
        The corresponding function from ``torch.nn.functional``.

    Raises:
        RuntimeError: if the name is not recognized.
    """
    activations = {"relu": F.relu, "gelu": F.gelu, "glu": F.glu}
    if activation not in activations:
        # Bugfix: the old message omitted "glu" even though it is supported.
        raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
    return activations[activation]
|
|
|
|
|
|
|
|
def get_activation_module(activation):
    """Return an activation module class given its name.

    Args:
        activation: one of "relu", "gelu" or "glu".

    Returns:
        The corresponding ``nn.Module`` subclass (not an instance).

    Raises:
        RuntimeError: if the name is not recognized.
    """
    modules = {"relu": nn.ReLU, "gelu": nn.GELU, "glu": nn.GLU}
    if activation not in modules:
        # Bugfix: the old message omitted "glu" even though it is supported.
        raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
    return modules[activation]
|
|
|
|
|
|
|
|
def get_valid_ratio(mask):
    """Fraction of valid rows/columns per sample in a (B, H, W) padding mask.

    ``mask`` is boolean with True marking padded positions. The valid height
    is measured along the first column and the valid width along the first
    row. Returns a (B, 2) tensor of (width_ratio, height_ratio).
    """
    _, H, W = mask.shape
    valid = ~mask
    ratio_h = valid[:, :, 0].sum(1).float() / H
    ratio_w = valid[:, 0, :].sum(1).float() / W
    return torch.stack([ratio_w, ratio_h], -1)
|
|
|
|
|
|
|
|
def gen_sineembed_for_position(pos_tensor, num_feats=256):
    """Generate sinusoidal positional embeddings for point/box coordinates.

    Args:
        pos_tensor: (n_query, bs, 2) with (x, y), or (n_query, bs, 4) with
            (x, y, w, h). Coordinates are presumably normalized to [0, 1]
            (each is scaled by 2*pi before embedding) — confirm with callers.
        num_feats: must be even; each coordinate gets num_feats // 2 embedding
            dimensions.

    Returns:
        Tensor of shape (n_query, bs, num_feats) for 2-dim inputs (ordered
        y, x) or (n_query, bs, 2 * num_feats) for 4-dim inputs (ordered
        y, x, w, h).

    Raises:
        ValueError: if the last dimension of `pos_tensor` is not 2 or 4.
    """
    assert num_feats % 2 == 0
    num_feats = num_feats // 2

    # Standard transformer frequency spectrum: pairs of (sin, cos) share a
    # frequency, so the exponent uses floor(i / 2).
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)

    def _embed(coord):
        # Interleaved sin/cos of the scaled coordinate over all frequencies.
        pos = (coord * scale)[:, :, None] / dim_t
        return torch.stack(
            (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
        ).flatten(2)

    pos_x = _embed(pos_tensor[:, :, 0])
    pos_y = _embed(pos_tensor[:, :, 1])
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        pos_w = _embed(pos_tensor[:, :, 2])
        pos_h = _embed(pos_tensor[:, :, 3])
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
|
|
|
|
|
|
|
|
class SAM3Output(list):
    """
    A class representing the output of a SAM3 model.

    The underlying data is a list of stages, each stage being a list of
    per-step outputs. An iterable interface supports different iteration
    modes: all steps per stage, last step per stage, and flattened.

    Attributes:
        output: The output of the SAM3 model, represented as a list of lists.
        iter_mode: The current iteration mode.
        loss_stages: Optional list of stage indices (semantics defined by the
            caller; only stored here).

    Example:
        >>> sam3_output = SAM3Output([[1, 2], [3, 4], [5, 6]])
        >>> for step in sam3_output:
        ...     print(step)
        [1, 2]
        [3, 4]
        [5, 6]
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.LAST_STEP_PER_STAGE
        ... ) as sam3_last_step_out:
        ...     for step in sam3_last_step_out:
        ...         print(step)
        2
        4
        6
        >>> with SAM3Output.iteration_mode(
        ...     sam3_output, SAM3Output.IterMode.FLATTENED
        ... ) as sam3_flattened_out:
        ...     for step in sam3_flattened_out:
        ...         print(step)
        1
        2
        3
        4
        5
        6
    """

    class IterMode(Enum):
        # Iterate stages; each item is the full list of steps for that stage.
        ALL_STEPS_PER_STAGE = auto()
        # Iterate stages; each item is only the last step of that stage.
        LAST_STEP_PER_STAGE = auto()
        # Iterate every step of every stage as one flat sequence.
        FLATTENED = auto()

    def __init__(
        self,
        output: Optional[List[List[Dict]]] = None,
        iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE,
        loss_stages: Optional[List[int]] = None,
    ):
        if output is not None:
            assert (
                isinstance(output, list)
                and len(output) > 0
                and isinstance(output[0], list)
            ), "Expected output to be a list of lists"
            self.output = output
        else:
            self.output = []
        assert isinstance(
            iter_mode, SAM3Output.IterMode
        ), f"iter_mode should be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"

        self.iter_mode = iter_mode

        # Use a weak self-reference inside the lambdas so the instance is not
        # kept alive through its own _mode2iter dict (avoids a ref cycle).
        self_ref = weakref.ref(self)
        self._mode2iter = {
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output),
            SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: (
                inner_list[-1] for inner_list in self_ref().output
            ),
            SAM3Output.IterMode.FLATTENED: lambda: (
                element for inner_list in self_ref().output for element in inner_list
            ),
        }
        self.loss_stages = loss_stages

    def __iter__(self) -> Iterator:
        # Dispatch to the iterator factory for the current mode.
        return self._mode2iter[self.iter_mode]()

    def __getitem__(self, index):
        """
        Returns the item at the specified index under the current iter_mode.

        Args:
            index (int): The index of the item to return.

        Returns:
            list or element: The item at the specified index.
        """
        assert isinstance(index, int), f"index should be an integer. Got {type(index)}"
        if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE:
            return self.output[index]
        elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE:
            return self.output[index][-1]
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            if index == -1:
                # Fast path: last element overall without building the
                # flattened list. (Bugfix: was `self.self.output[-1][-1]`,
                # which raised AttributeError.)
                return self.output[-1][-1]
            else:
                flattened_output = sum(self.output, [])
                return flattened_output[index]

    class _IterationMode(AbstractContextManager):
        """
        A context manager that temporarily changes the iteration mode of a
        SAM3Output object.

        This class is used internally by the SAM3Output.iteration_mode method.
        """

        def __init__(
            self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"
        ):
            self._model_output = model_output
            self._orig_iter_mode = model_output.iter_mode
            self._new_iter_mode = iter_mode

        def __enter__(self) -> "SAM3Output":
            self._model_output.iter_mode = self._new_iter_mode
            return self._model_output

        def __exit__(self, exc_type, exc_value, traceback):
            # Always restore the original mode, even if the body raised.
            self._model_output.iter_mode = self._orig_iter_mode
            return super().__exit__(exc_type, exc_value, traceback)

    @staticmethod
    def iteration_mode(
        model_output: "SAM3Output", iter_mode: IterMode
    ) -> _IterationMode:
        """
        Returns a context manager that allows you to temporarily change the
        iteration mode of the SAM3Output object.

        Args:
            model_output: The SAM3Output object.
            iter_mode: The new iteration mode.

        Returns:
            SAM3Output._IterationMode: A context manager that changes the
            iteration mode of the SAM3Output object.
        """
        return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)

    def append(self, item: list):
        """Append one stage (a list of per-step outputs)."""
        assert isinstance(
            item, list
        ), f"Only list items are supported. Got {type(item)}"
        self.output.append(item)

    def __repr__(self):
        return self.output.__repr__()

    def __len__(self):
        """Length under the current iter_mode: number of stages, or total steps."""
        if self.iter_mode in [
            SAM3Output.IterMode.ALL_STEPS_PER_STAGE,
            SAM3Output.IterMode.LAST_STEP_PER_STAGE,
        ]:
            return len(self.output)
        elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
            flattened_output = sum(self.output, [])
            return len(flattened_output)
|
|
|