import torch.nn as nn
from .cnn_2d import build_2d_cnn
class Backbone2D(nn.Module):
    """Thin ``nn.Module`` wrapper around the 2D CNN built by ``build_2d_cnn``.

    The constructor delegates network construction to ``build_2d_cnn`` and
    ``forward`` simply runs the resulting network on the input.
    """

    def __init__(self, cfg, pretrained: bool = False) -> None:
        """
        Input:
            cfg: configuration object, stored on the instance and passed
                 unchanged to ``build_2d_cnn``.
            pretrained: (bool) passed unchanged to ``build_2d_cnn``;
                 presumably selects pretrained weights -- confirm there.
        """
        super().__init__()
        self.cfg = cfg
        # build_2d_cnn returns a pair: the network itself and its feature
        # dims (presumably the channel counts of the returned feature maps
        # -- TODO confirm against build_2d_cnn).
        self.backbone, self.feat_dims = build_2d_cnn(cfg, pretrained)

    def forward(self, x):
        """
        Input:
            x: (Tensor) -> [B, C, H, W]
        Output:
            y: (List) -> [
                (Tensor) -> [B, C1, H1, W1],
                (Tensor) -> [B, C2, H2, W2],
                (Tensor) -> [B, C3, H3, W3]
            ]
        """
        # The wrapped network's output is returned as-is; no post-processing.
        return self.backbone(x)
|