|
|
|
|
|
|
|
|
|
|
|
|
| from typing import Optional
|
|
|
| import torch
|
| from torch import nn, Tensor
|
|
|
| from sam2.modeling.sam.transformer import RoPEAttention
|
|
|
| from sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
|
|
|
|
class MemoryAttentionLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention to memory, MLP.

    Each of the three sub-layers normalizes its input with a dedicated
    ``LayerNorm``, runs the sub-layer, applies dropout, and adds the result
    back to the residual stream.

    Positional encodings are optional. A ``None`` encoding is treated as
    absent, so nothing is added even when the corresponding ``pos_enc_at_*``
    flag is enabled.

    Args:
        activation: Name of the MLP activation, resolved with
            ``get_activation_fn``.
        cross_attention: Attention module for query-to-memory attention,
            called as ``cross_attn_image(q=..., k=..., v=...)``.
        d_model: Token feature width. It is the input/output size of the MLP
            and the size of every ``LayerNorm``.
        dim_feedforward: Hidden width of the MLP.
        dropout: Dropout probability shared by all dropout layers.
        pos_enc_at_attn: If True, add ``query_pos`` to the self-attention
            queries and keys.
        pos_enc_at_cross_attn_keys: If True, add ``pos`` to the
            cross-attention keys.
        pos_enc_at_cross_attn_queries: If True, add ``query_pos`` to the
            cross-attention queries.
        self_attention: Attention module called as ``self_attn(q, k, v=...)``.
    """

    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Feed-forward (MLP) sub-layer: Linear -> activation -> Dropout -> Linear.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One pre-norm and one residual dropout per sub-layer
        # (1: self-attention, 2: cross-attention, 3: MLP).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        # Where positional encodings are injected.
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    @staticmethod
    def _with_pos_embed(
        tensor: Tensor, pos: Optional[Tensor], enabled: bool
    ) -> Tensor:
        """Return ``tensor + pos`` if ``enabled`` and ``pos`` is given, else ``tensor``."""
        if enabled and pos is not None:
            return tensor + pos
        return tensor

    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
        """Self-attention sub-layer (pre-norm, residual)."""
        tgt2 = self.norm1(tgt)
        # Queries and keys share the (optionally position-encoded) input;
        # values always use the un-encoded normalized input.
        q = k = self._with_pos_embed(tgt2, query_pos, self.pos_enc_at_attn)
        tgt2 = self.self_attn(q, k, v=tgt2)
        return tgt + self.dropout1(tgt2)

    def _forward_ca(
        self,
        tgt: Tensor,
        memory: Tensor,
        query_pos: Optional[Tensor],
        pos: Optional[Tensor],
        num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """Cross-attention sub-layer: normalized ``tgt`` attends to ``memory``."""
        kwds = {}
        if num_k_exclude_rope > 0:
            # The keyword is only forwarded when the cross-attention module is
            # a RoPEAttention (enforced here).
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=self._with_pos_embed(
                tgt2, query_pos, self.pos_enc_at_cross_attn_queries
            ),
            k=self._with_pos_embed(memory, pos, self.pos_enc_at_cross_attn_keys),
            v=memory,
            **kwds,
        )
        return tgt + self.dropout2(tgt2)

    def _forward_ff(self, tgt: Tensor) -> Tensor:
        """MLP sub-layer (pre-norm, residual)."""
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        return tgt + self.dropout3(tgt2)

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """Run self-attention, cross-attention to ``memory``, then the MLP.

        Args:
            tgt: Query tokens updated by this layer.
            memory: Tokens used as cross-attention keys and values.
            pos: Optional positional encoding added to the cross-attention
                keys (when ``pos_enc_at_cross_attn_keys``).
            query_pos: Optional positional encoding added to the self-attention
                queries/keys and/or the cross-attention queries, depending on
                the configured flags.
            num_k_exclude_rope: Forwarded to the cross-attention module when
                positive. It requires that module to be a ``RoPEAttention``.

        Returns:
            The updated ``tgt`` tensor.
        """
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        return self._forward_ff(tgt)
|
|
|
|
|
class MemoryAttention(nn.Module):
    """Stack of cloned memory-attention layers followed by a final LayerNorm.

    Args:
        d_model: Feature width of the final ``LayerNorm``.
        pos_enc_at_input: If True and ``curr_pos`` is given, add
            ``0.1 * curr_pos`` to ``curr`` before the first layer.
        layer: Prototype layer, cloned ``num_layers`` times via ``get_clones``.
            Each clone must expose a ``cross_attn_image`` attribute.
        num_layers: Number of stacked layers.
        batch_first: If True, the first two dimensions of the inputs are
            swapped before the layers run, and swapped back on the output.
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,
        memory: torch.Tensor,
        curr_pos: Optional[Tensor] = None,
        memory_pos: Optional[Tensor] = None,
        num_obj_ptr_tokens: int = 0,
    ):
        """Attend ``curr`` to ``memory`` through every layer and normalize.

        Args:
            curr: Query tokens, or a single-element list holding them. A list
                requires ``curr_pos`` to be a single-element list as well.
                The batch-size assert compares dimension 1, so inputs
                presumably arrive as (seq, batch, ...) — TODO confirm with
                callers.
            memory: Memory tokens used as cross-attention keys/values. It must
                match ``curr`` in dimension 1.
            curr_pos: Optional positional encoding for ``curr`` (same layout).
            memory_pos: Optional positional encoding for ``memory``.
            num_obj_ptr_tokens: Passed as ``num_k_exclude_rope`` to every
                layer whose ``cross_attn_image`` is a ``RoPEAttention``.

        Returns:
            The normalized output, in the same dimension order as ``curr``.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert (
            curr.shape[1] == memory.shape[1]
        ), "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            # Fixed 0.1 scale on the input positional encoding.
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Swap dims 0 and 1 for the layers. Positional encodings are
            # optional, so only transpose them when present.
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Restore the caller's dimension order.
            normed_output = normed_output.transpose(0, 1)

        return normed_output
|
|
|