import torch
from torch import nn
from torchvision.ops import roi_align

from .mlp import MLP
from .positional_encoding import PositionalEncodingsFixed


class OPEModule(nn.Module):
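    """Object prototype extraction: shape embeddings computed from the exemplar
    box sizes are combined with RoI-pooled appearance features and, optionally,
    refined by an iterative adaptation module. In the zero-shot setting, learned
    objectness queries replace the exemplar-derived features.
    """
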
def __init__(
self,
num_iterative_steps: int,
emb_dim: int,
kernel_dim: int,
num_objects: int,
num_heads: int,
reduction: int,
layer_norm_eps: float,
mlp_factor: int,
norm_first: bool,
activation: nn.Module,
norm: bool,
zero_shot: bool,
):
super(OPEModule, self).__init__()
self.num_iterative_steps = num_iterative_steps
self.zero_shot = zero_shot
self.kernel_dim = kernel_dim
self.num_objects = num_objects
self.emb_dim = emb_dim
self.reduction = reduction
if num_iterative_steps > 0:
self.iterative_adaptation = IterativeAdaptationModule(
num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
dropout=0, layer_norm_eps=layer_norm_eps,
mlp_factor=mlp_factor, norm_first=norm_first,
activation=activation, norm=norm,
zero_shot=zero_shot
)
if not self.zero_shot:
self.shape_or_objectness = nn.Sequential(
nn.Linear(2, 64),
nn.ReLU(),
nn.Linear(64, emb_dim),
nn.ReLU(),
nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
)
else:
self.shape_or_objectness = nn.Parameter(
torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
)
nn.init.normal_(self.shape_or_objectness)
self.pos_emb = PositionalEncodingsFixed(emb_dim)

    def forward(self, f_e, pos_emb, bboxes):
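        """
        f_e:     encoded image features, (bs, emb_dim, h, w)
        pos_emb: positional encoding of the flattened features, (h*w, bs, emb_dim)
        bboxes:  exemplar boxes per image as (x1, y1, x2, y2) image coordinates,
                 (bs, num_objects, 4); unused in the zero-shot setting
        returns: prototypes stacked per adaptation step,
                 (num_steps, num_objects * kernel_dim**2, bs, emb_dim)
        """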
bs, _, h, w = f_e.size()
# extract the shape features or objectness
if not self.zero_shot:
box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
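            # map each box's (width, height) through an MLP to kernel_dim**2 shape
            # tokens of size emb_dim, flattened to a sequence-first tensor of shape
            # (num_objects * kernel_dim**2, bs, emb_dim) as nn.MultiheadAttention expects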
shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
bs, -1, self.kernel_dim ** 2, self.emb_dim
).flatten(1, 2).transpose(0, 1)
else:
shape_or_objectness = self.shape_or_objectness.expand(
bs, -1, -1, -1
).flatten(1, 2).transpose(0, 1)
        # in the few-shot setting, add appearance features pooled from the exemplar boxes
if not self.zero_shot:
            # prepend each box's batch index: roi_align expects boxes as an
            # (N, 5) tensor of (batch_idx, x1, y1, x2, y2) rows
num_of_boxes = bboxes.size(1)
bboxes = torch.cat([
torch.arange(
bs, requires_grad=False
).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
bboxes.flatten(0, 1),
], dim=1)
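            # RoI-pool a kernel_dim x kernel_dim appearance patch per box, then
            # flatten it to the same sequence-first layout as the shape tokens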
appearance = roi_align(
f_e,
boxes=bboxes, output_size=self.kernel_dim,
spatial_scale=1.0 / self.reduction, aligned=True
).permute(0, 2, 3, 1).reshape(
bs, num_of_boxes * self.kernel_dim ** 2, -1
).transpose(0, 1)
else:
num_of_boxes = self.num_objects
appearance = None
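        # positional encoding over the kernel_dim x kernel_dim query grid, tiled per object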
query_pos_emb = self.pos_emb(
bs, self.kernel_dim, self.kernel_dim, f_e.device
).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)
if self.num_iterative_steps > 0:
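            # flatten the feature map into a sequence-first memory: (h*w, bs, emb_dim)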
memory = f_e.flatten(2).permute(2, 0, 1)
all_prototypes = self.iterative_adaptation(
shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
)
else:
if shape_or_objectness is not None and appearance is not None:
all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
else:
all_prototypes = (
shape_or_objectness if shape_or_objectness is not None else appearance
).unsqueeze(0)
return all_prototypes


class IterativeAdaptationModule(nn.Module):
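    """A stack of decoder-style layers that iteratively refines the prototype
    queries against the image features; the output of every layer is
    (optionally) normalized and returned, stacked along a new leading dimension.
    """
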
def __init__(
self,
num_layers: int,
emb_dim: int,
num_heads: int,
dropout: float,
layer_norm_eps: float,
mlp_factor: int,
norm_first: bool,
activation: nn.Module,
norm: bool,
zero_shot: bool
):
super(IterativeAdaptationModule, self).__init__()
self.layers = nn.ModuleList([
IterativeAdaptationLayer(
emb_dim, num_heads, dropout, layer_norm_eps,
mlp_factor, norm_first, activation, zero_shot
            ) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(
self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
tgt_key_padding_mask=None, memory_key_padding_mask=None
):
output = tgt
outputs = list()
        for layer in self.layers:
output = layer(
output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask
)
outputs.append(self.norm(output))
return torch.stack(outputs)


class IterativeAdaptationLayer(nn.Module):
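    """A single adaptation layer: attention from the prototype queries to the
    exemplar appearance features (self_attn, skipped in the zero-shot setting),
    cross-attention into the image feature memory (enc_dec_attn), and a
    feed-forward MLP, with either pre-norm or post-norm residual ordering.
    """
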
def __init__(
self,
emb_dim: int,
num_heads: int,
dropout: float,
layer_norm_eps: float,
mlp_factor: int,
norm_first: bool,
activation: nn.Module,
zero_shot: bool
):
super(IterativeAdaptationLayer, self).__init__()
self.norm_first = norm_first
self.zero_shot = zero_shot
if not self.zero_shot:
self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
if not self.zero_shot:
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
if not self.zero_shot:
self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
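        # add the positional embedding only when one is provided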
return x if emb is None else x + emb

    def forward(
self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask
):
if self.norm_first:
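            # pre-norm: LayerNorm before each sub-layer, residuals on the raw stream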
if not self.zero_shot:
tgt_norm = self.norm1(tgt)
tgt = tgt + self.dropout1(self.self_attn(
query=self.with_emb(tgt_norm, query_pos_emb),
key=self.with_emb(appearance, query_pos_emb),
value=appearance,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask
)[0])
tgt_norm = self.norm2(tgt)
tgt = tgt + self.dropout2(self.enc_dec_attn(
query=self.with_emb(tgt_norm, query_pos_emb),
key=memory+pos_emb,
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)[0])
tgt_norm = self.norm3(tgt)
tgt = tgt + self.dropout3(self.mlp(tgt_norm))
else:
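            # post-norm: residual addition first, then LayerNorm (original Transformer ordering)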
if not self.zero_shot:
tgt = self.norm1(tgt + self.dropout1(self.self_attn(
query=self.with_emb(tgt, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
value=appearance,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask
)[0]))
tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
query=self.with_emb(tgt, query_pos_emb),
key=memory+pos_emb,
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)[0]))
tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
return tgt
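

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, with assumed hyperparameters (emb_dim,
    # num_heads, reduction, activation, ...) rather than the project's actual
    # configuration; it exercises the few-shot path end to end.
    ope = OPEModule(
        num_iterative_steps=3, emb_dim=256, kernel_dim=3, num_objects=3,
        num_heads=8, reduction=8, layer_norm_eps=1e-5, mlp_factor=8,
        norm_first=True, activation=nn.GELU, norm=True, zero_shot=False,
    )
    bs, h, w = 2, 64, 64
    f_e = torch.randn(bs, 256, h, w)                  # encoded image features
    pos_emb = torch.randn(h * w, bs, 256)             # matching positional encoding
    bboxes = torch.rand(bs, 3, 2) * 16                # three exemplar boxes per image
    bboxes = torch.cat([bboxes, bboxes + 16], dim=2)  # (x1, y1, x2, y2) with x1 < x2, y1 < y2
    out = ope(f_e, pos_emb, bboxes)
    print(out.shape)  # expected: (3, 3 * 3 ** 2, bs, 256) = (num_steps, queries, bs, emb_dim)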