import torch.nn as nn

from .cnn_2d import build_2d_cnn


class Backbone2D(nn.Module):
    """Thin ``nn.Module`` wrapper around a 2D CNN produced by ``build_2d_cnn``.

    Attributes:
        cfg: The configuration object passed at construction time.
        backbone: The network returned by ``build_2d_cnn(cfg, pretrained)``;
            it is invoked directly in ``forward``.
        feat_dims: The second value returned by ``build_2d_cnn``; presumably
            the channel widths of the emitted feature maps (TODO confirm
            against ``build_2d_cnn``).
    """

    def __init__(self, cfg, pretrained=False):
        """Build the wrapped 2D CNN.

        Args:
            cfg: Configuration forwarded unchanged to ``build_2d_cnn``.
            pretrained: Forwarded unchanged to ``build_2d_cnn``; presumably
                selects loading pretrained weights (verify in builder).
        """
        super().__init__()
        self.cfg = cfg
        # build_2d_cnn returns a 2-tuple: (network, feature dimensions).
        network, dims = build_2d_cnn(cfg, pretrained)
        self.backbone = network
        self.feat_dims = dims

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged.

        Input (assumed, per the original docstring):
            x: Tensor of shape [B, C, H, W].

        Output (assumed, per the original docstring; depends on
        ``build_2d_cnn``):
            A list of three feature maps shaped [B, C1, H1, W1],
            [B, C2, H2, W2] and [B, C3, H3, W3].
        """
        return self.backbone(x)