import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer

|
class StandardTransformerDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        num_classes,
        mask_classification=True,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dropout=0.0,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=10,
        pre_norm=False,
        deep_supervision=True,
        mask_dim=256,
        enforce_input_project=False,
    ):
        """
        Args:
            in_channels: channels of the input features
            num_classes: number of classes (a no-object class is added internally)
            mask_classification: whether to predict per-query class logits
            hidden_dim: Transformer feature dimension
            num_queries: number of learnable queries
            nheads: number of attention heads
            dropout: dropout used inside the Transformer
            dim_feedforward: hidden dimension of the Transformer feed-forward layers
            enc_layers / dec_layers: number of Transformer encoder / decoder layers
            pre_norm: use pre-LayerNorm Transformer blocks
            deep_supervision: return intermediate decoder outputs for auxiliary losses
            mask_dim: dimension of the per-query mask embedding
            enforce_input_project: add a 1x1 input projection even when
                in_channels == hidden_dim
        """
        super().__init__()
        self.mask_classification = mask_classification

        # Sine positional encoding; half the channels encode each spatial axis.
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )

        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model

        # One learnable query embedding per output slot.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # Project input features to the Transformer dimension when necessary.
        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
            weight_init.c2_xavier_fill(self.input_proj)
        else:
            self.input_proj = nn.Sequential()
        self.aux_loss = deep_supervision

        # Output heads: per-query class logits and per-query mask embeddings.
        if self.mask_classification:
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, x, mask_features, mask=None):
        if mask is not None:
            mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        pos = self.pe_layer(x, mask)

        src = x
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # Deep supervision: predict masks from every decoder layer.
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # Predict masks from the last decoder layer only.
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # Built in a separate, TorchScript-excluded method: the caller's output
        # dict mixes Tensors and a list of dicts, which TorchScript does not support.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
|