| import torch.nn as nn | |
| from .cnn_2d import build_2d_cnn | |
class Backbone2D(nn.Module):
    """Thin ``nn.Module`` wrapper around the network built by ``build_2d_cnn``.

    Attributes set in ``__init__``:
        cfg:       the configuration object, stored unchanged.
        backbone:  first value returned by ``build_2d_cnn(cfg, pretrained)``.
        feat_dims: second value returned by ``build_2d_cnn(cfg, pretrained)``
                   (presumably the channel sizes of the produced feature
                   maps -- TODO confirm against ``cnn_2d``).
    """

    def __init__(self, cfg, pretrained=False):
        """
        Input:
            cfg:        configuration forwarded as-is to ``build_2d_cnn``.
            pretrained: flag forwarded as-is to ``build_2d_cnn``
                        (defaults to False).
        """
        super(Backbone2D, self).__init__()
        self.cfg = cfg

        # build_2d_cnn hands back exactly two values: the network and its
        # feature-dimension description.
        network, feature_dims = build_2d_cnn(cfg, pretrained)
        self.backbone = network
        self.feat_dims = feature_dims

    def forward(self, x):
        """
        Run the wrapped backbone on ``x`` and return its output untouched.

        Input:
            x: (Tensor) -> [B, C, H, W]
        Output:
            Whatever ``self.backbone`` returns. The original documentation
            describes it as a list of three feature maps (verify against
            the concrete backbone selected by ``cfg``):
            y: (List) -> [
                (Tensor) -> [B, C1, H1, W1],
                (Tensor) -> [B, C2, H2, W2],
                (Tensor) -> [B, C3, H3, W3]
            ]
        """
        return self.backbone(x)