| |
| |
|
|
| |
| |
|
|
| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class ImageEncoder(nn.Module):
    """
    Run a batch through ``trunk`` and then ``neck`` in fixed-size chunks.

    The batch is split along dim 0 into chunks of at most ``chunk_size``
    samples, each chunk is passed through ``neck(trunk(chunk))``, and the
    per-level outputs of all chunks are concatenated back along dim 0.
    """

    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
        chunk_size: int = 16,
    ):
        """Initialize the encoder
        :param trunk: the backbone; must expose ``channel_list``
        :param neck: module mapping the trunk output to a pair
            ``(features, pos)`` of per-level lists; must expose
            ``backbone_channel_list`` equal to ``trunk.channel_list``
        :param scalp: number of trailing levels to drop from the output
        :param chunk_size: maximum number of samples per trunk/neck pass
        """
        super().__init__()
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        self.chunk_size = chunk_size
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    @staticmethod
    def _merge_levels(per_chunk: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Concatenate per-chunk level lists into one tensor per level."""
        if len(per_chunk) == 1:
            # Single chunk: nothing to stitch, so skip the torch.cat copy.
            return list(per_chunk[0])
        return [torch.cat(level_parts) for level_parts in zip(*per_chunk)]

    def forward(self, sample: torch.Tensor) -> dict:
        """Encode ``sample`` chunk by chunk.

        :param sample: input batch; it is split along dim 0
        :return: dict with keys ``"vision_features"`` (the last kept
            level), ``"vision_pos_enc"`` (kept positional encodings) and
            ``"backbone_fpn"`` (kept feature levels)
        """
        batch_size = sample.size(0)
        # max(..., 1) guarantees one pass even for an empty batch, so the
        # per-level outputs always exist instead of failing on features[0].
        chunk_outputs = [
            self.neck(self.trunk(sample[start:start + self.chunk_size]))
            for start in range(0, max(batch_size, 1), self.chunk_size)
        ]
        features = self._merge_levels([chunk_feats for chunk_feats, _ in chunk_outputs])
        pos = self._merge_levels([chunk_pos for _, chunk_pos in chunk_outputs])
        assert features[0].size(0) == pos[0].size(0) == batch_size

        if self.scalp > 0:
            # Drop the last `scalp` levels from both outputs.
            features, pos = features[:-self.scalp], pos[:-self.scalp]

        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output
|
|
|
|
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (the output conv is removed; top-down features are upsampled 2x with a
    configurable interpolation mode, bilinear by default, and fused into
    the lateral features by sum or average).
    """

    # Modes for which F.interpolate accepts an explicit align_corners value;
    # every other mode ("nearest", "nearest-exact", "area") must receive None.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model (output channels of
            every lateral conv)
        :param backbone_channel_list: input channels of each lateral conv;
            entry ``k`` is applied to ``xs[len(xs) - 1 - k]`` in ``forward``
        :param kernel_size: kernel size of the lateral convs
        :param stride: stride of the lateral convs
        :param padding: padding of the lateral convs
        :param fpn_interp_model: ``F.interpolate`` mode used for the
            top-down upsampling
        :param fuse_type: ``"sum"`` or ``"avg"``; how lateral and top-down
            features are combined
        :param fpn_top_down_levels: level indices that receive top-down
            features; ``None`` means every level
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        self.d_model = d_model
        for dim in backbone_channel_list:
            # Keep the Sequential wrapper and the "conv" name so parameter
            # names stay "convs.{k}.conv.*".
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        if fpn_top_down_levels is None:
            # Default: every level takes part in the top-down pathway.
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        """Fuse the backbone levels from the last index down to the first.

        :param xs: one feature map per lateral conv. Top-down features are
            upsampled by exactly 2x before being added, so ``xs[i + 1]`` is
            expected to have half the spatial size of ``xs[i]``.
        :return: ``(out, pos)``; two lists indexed like ``xs`` holding the
            fused features and their positional encodings
        """
        num_levels = len(self.convs)
        out: List[Optional[torch.Tensor]] = [None] * num_levels
        pos: List[Optional[torch.Tensor]] = [None] * num_levels
        assert len(xs) == num_levels

        mode = self.fpn_interp_model
        # F.interpolate rejects a non-None align_corners for modes that do
        # not interpolate linearly/cubically (e.g. "nearest-exact", "area").
        align_corners = False if mode in self._ALIGN_CORNERS_MODES else None

        prev_features = None
        n = num_levels - 1
        # Walk from the last level to the first, carrying the fused result.
        for i in range(n, -1, -1):
            # convs are stored in reverse order relative to xs.
            lateral_features = self.convs[n - i](xs[i])
            if i in self.fpn_top_down_levels and prev_features is not None:
                # Interpolate in float32, then cast back to the input dtype.
                top_down_features = F.interpolate(
                    prev_features.float(),
                    scale_factor=2.0,
                    mode=mode,
                    align_corners=align_corners,
                    antialias=False,
                ).to(prev_features.dtype)
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                # First visited level, or a level excluded from top-down
                # fusion: restart from the lateral features alone.
                prev_features = lateral_features
            out[i] = prev_features
            pos[i] = self.position_encoding(prev_features).to(prev_features.dtype)

        return out, pos
|
|