import torch
from torch import nn
from torchvision.ops import roi_align

from .mlp import MLP
from .positional_encoding import PositionalEncodingsFixed

class OPEModule(nn.Module):

    def __init__(
        self,
        num_iterative_steps: int,
        emb_dim: int,
        kernel_dim: int,
        num_objects: int,
        num_heads: int,
        reduction: int,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):
        super(OPEModule, self).__init__()

        self.num_iterative_steps = num_iterative_steps
        self.zero_shot = zero_shot
        self.kernel_dim = kernel_dim
        self.num_objects = num_objects
        self.emb_dim = emb_dim
        self.reduction = reduction

        if num_iterative_steps > 0:
            self.iterative_adaptation = IterativeAdaptationModule(
                num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
                dropout=0, layer_norm_eps=layer_norm_eps,
                mlp_factor=mlp_factor, norm_first=norm_first,
                activation=activation, norm=norm,
                zero_shot=zero_shot
            )

        if not self.zero_shot:
            # few-shot: map the exemplar box width/height to kernel_dim**2 shape tokens
            self.shape_or_objectness = nn.Sequential(
                nn.Linear(2, 64),
                nn.ReLU(),
                nn.Linear(64, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, self.kernel_dim ** 2 * emb_dim)
            )
        else:
            # zero-shot: learn the objectness queries directly as a parameter
            self.shape_or_objectness = nn.Parameter(
                torch.empty((self.num_objects, self.kernel_dim ** 2, emb_dim))
            )
            nn.init.normal_(self.shape_or_objectness)

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

    def forward(self, f_e, pos_emb, bboxes):
        bs, _, h, w = f_e.size()

        # extract the shape features or objectness
        if not self.zero_shot:
            # encode exemplar box width and height into shape embeddings
            box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
            box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
            box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
            shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
                bs, -1, self.kernel_dim ** 2, self.emb_dim
            ).flatten(1, 2).transpose(0, 1)
        else:
            shape_or_objectness = self.shape_or_objectness.expand(
                bs, -1, -1, -1
            ).flatten(1, 2).transpose(0, 1)

        # if not zero-shot, add exemplar appearance pooled from the feature map
        if not self.zero_shot:
            # reshape bboxes into the (batch_idx, x1, y1, x2, y2) format expected by roi_align
            num_of_boxes = bboxes.size(1)
            bboxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                bboxes.flatten(0, 1),
            ], dim=1)
            appearance = roi_align(
                f_e,
                boxes=bboxes, output_size=self.kernel_dim,
                spatial_scale=1.0 / self.reduction, aligned=True
            ).permute(0, 2, 3, 1).reshape(
                bs, num_of_boxes * self.kernel_dim ** 2, -1
            ).transpose(0, 1)
        else:
            num_of_boxes = self.num_objects
            appearance = None

        # positional embeddings for the kernel_dim x kernel_dim query grid, repeated per box
        query_pos_emb = self.pos_emb(
            bs, self.kernel_dim, self.kernel_dim, f_e.device
        ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)

        if self.num_iterative_steps > 0:
            memory = f_e.flatten(2).permute(2, 0, 1)
            all_prototypes = self.iterative_adaptation(
                shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
            )
        else:
            if shape_or_objectness is not None and appearance is not None:
                all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
            else:
                all_prototypes = (
                    shape_or_objectness if shape_or_objectness is not None else appearance
                ).unsqueeze(0)

        return all_prototypes
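
# Shape summary for OPEModule.forward, inferred from the code above:
#   f_e:     (bs, emb_dim, h, w) encoded image features at 1/reduction resolution
#   pos_emb: (h * w, bs, emb_dim) positional encodings matching the flattened features
#   bboxes:  (bs, num_objects, 4) exemplar boxes in xyxy image coordinates (few-shot only)
#   return:  (num_iterative_steps, num_objects * kernel_dim ** 2, bs, emb_dim) when
#            num_iterative_steps > 0; a single step with the same layout otherwise.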

class IterativeAdaptationModule(nn.Module):

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool
    ):
        super(IterativeAdaptationModule, self).__init__()

        self.layers = nn.ModuleList([
            IterativeAdaptationLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, zero_shot
            ) for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
        tgt_key_padding_mask=None, memory_key_padding_mask=None
    ):
        output = tgt
        outputs = list()
        for layer in self.layers:
            output = layer(
                output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
                tgt_key_padding_mask, memory_key_padding_mask
            )
            outputs.append(self.norm(output))

        return torch.stack(outputs)
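
# Note: the normalized output of every layer is stacked along a new leading dimension,
# giving (num_layers, num_queries, bs, emb_dim), so callers can, e.g., supervise the
# prototypes produced at each iterative adaptation step rather than only the last one.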

class IterativeAdaptationLayer(nn.Module):

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        zero_shot: bool
    ):
        super(IterativeAdaptationLayer, self).__init__()

        self.norm_first = norm_first
        self.zero_shot = zero_shot

        # the self-attention branch over exemplar appearance is only used in the few-shot
        # setting; the cross-attention and MLP branches are needed in both settings
        if not self.zero_shot:
            self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)

        if not self.zero_shot:
            self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        if not self.zero_shot:
            self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
        self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)

        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        return x if emb is None else x + emb

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
        tgt_key_padding_mask, memory_key_padding_mask
    ):
        if self.norm_first:
            if not self.zero_shot:
                # self-attention between shape queries and pooled exemplar appearance
                tgt_norm = self.norm1(tgt)
                tgt = tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt_norm, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0])

            # cross-attention between the queries and the encoded image features
            tgt_norm = self.norm2(tgt)
            tgt = tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt_norm, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0])

            tgt_norm = self.norm3(tgt)
            tgt = tgt + self.dropout3(self.mlp(tgt_norm))

        else:
            if not self.zero_shot:
                tgt = self.norm1(tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0]))

            tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0]))

            tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        return tgt
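

# Minimal usage sketch (illustrative only): the hyper-parameter values below are
# assumptions, not the original configuration. It presumes that the relative imports
# (.mlp.MLP, .positional_encoding.PositionalEncodingsFixed) resolve, that MLP maps
# (..., emb_dim) -> (..., emb_dim), and that MLP accepts an activation class such as
# nn.GELU. The pos_emb construction mirrors how query_pos_emb is built in forward().
if __name__ == "__main__":
    emb_dim, kernel_dim, reduction, num_objects = 256, 3, 8, 3
    ope = OPEModule(
        num_iterative_steps=3, emb_dim=emb_dim, kernel_dim=kernel_dim,
        num_objects=num_objects, num_heads=8, reduction=reduction,
        layer_norm_eps=1e-5, mlp_factor=8, norm_first=True,
        activation=nn.GELU, norm=True, zero_shot=False,
    )
    bs, h, w = 2, 64, 64
    f_e = torch.rand(bs, emb_dim, h, w)  # encoded features at 1/reduction scale
    # positional encodings for the flattened feature map: (h * w, bs, emb_dim)
    pos_emb = PositionalEncodingsFixed(emb_dim)(
        bs, h, w, f_e.device
    ).flatten(2).permute(2, 0, 1)
    # num_objects exemplar boxes per image in xyxy image coordinates
    bboxes = torch.tensor([[[16., 16., 80., 80.]] * num_objects] * bs)
    prototypes = ope(f_e, pos_emb, bboxes)
    # expected: (num_iterative_steps, num_objects * kernel_dim ** 2, bs, emb_dim)
    print(prototypes.shape)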