import warnings
from typing import Callable, Dict, List, Optional, Union

from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.models import resnet
from torchvision.models._api import _get_enum_from_fn, WeightsEnum
from torchvision.models._utils import handle_legacy_interface, IntermediateLayerGetter


class BackboneWithFPN(nn.Module):
    """Adds a Feature Pyramid Network on top of a backbone model.

    The backbone is wrapped in an ``IntermediateLayerGetter`` so that the
    intermediate activations named in ``return_layers`` are extracted and fed
    into a ``FeaturePyramidNetwork``.

    Args:
        backbone: the network whose intermediate layers are tapped.
        return_layers: maps module names in ``backbone`` to the keys under
            which their activations are returned.
        in_channels_list: number of channels of each tapped feature map
            (one entry per returned layer), forwarded to the FPN.
        out_channels: number of channels of every FPN output map.
        extra_blocks: optional extra FPN operations; defaults to
            ``LastLevelMaxPool()`` when ``None``.
        norm_layer: optional normalization-layer factory forwarded to the FPN.

    Attributes:
        out_channels (int): number of channels of each FPN output map.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Extract the intermediate feature maps and run them through the FPN."""
        x = self.body(x)
        x = self.fpn(x)
        return x


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """Construct a ResNet-FPN backbone.

    Args:
        backbone_name: name of a resnet constructor, e.g. ``"resnet50"``,
            looked up in ``torchvision.models.resnet``.
        weights: pretrained weights enum for the resnet, or ``None``.
        norm_layer: normalization layer used inside the resnet; frozen batch
            norm by default (common for detection backbones).
        trainable_layers: number of resnet stages left trainable, counted
            from the last block; must be in ``[0, 5]``.
        returned_layers: which resnet stages feed the FPN; defaults to
            ``[1, 2, 3, 4]`` inside ``_resnet_fpn_extractor``.
        extra_blocks: optional extra FPN operations; defaults to
            ``LastLevelMaxPool()``.

    Returns:
        The backbone wrapped with an FPN.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze the requested resnet stages and wrap the remainder with an FPN.

    Args:
        backbone: the resnet to adapt; parameters outside the trainable
            stages get ``requires_grad_(False)``.
        trainable_layers: number of trainable stages in ``[0, 5]``; 5 also
            unfreezes ``bn1``.
        returned_layers: stage indices in ``[1, 4]`` to return; defaults to
            ``[1, 2, 3, 4]``.
        extra_blocks: optional extra FPN operations; defaults to
            ``LastLevelMaxPool()``.
        norm_layer: optional normalization-layer factory for the FPN.

    Raises:
        ValueError: if ``trainable_layers`` is outside ``[0, 5]`` or any
            entry of ``returned_layers`` is outside ``[1, 4]``.
    """
    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        # freeze every parameter that does not belong to a trainable stage
        # (generator form instead of all([...]) — no throwaway list)
        if all(not name.startswith(layer) for layer in layers_to_train):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    # e.g. returned_layers=[1, 2] -> {"layer1": "0", "layer2": "1"}
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    # stage "layer1" outputs inplanes // 8 channels; each later stage doubles it
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )


def _validate_trainable_layers(
    is_trained: bool,
    trainable_backbone_layers: Optional[int],
    max_value: int,
    default_value: int,
) -> int:
    """Resolve and validate a ``trainable_backbone_layers`` argument.

    Args:
        is_trained: whether pretrained weights are in use. When ``False``,
            all layers are made trainable and any explicit user value is
            overridden (with a warning).
        trainable_backbone_layers: requested number of trainable layers, or
            ``None`` to fall back to ``default_value``.
        max_value: maximum allowed number of trainable layers.
        default_value: value used when the caller passed ``None``.

    Returns:
        The validated number of trainable layers.

    Raises:
        ValueError: if the resolved value is outside ``[0, max_value]``.
    """
    # don't freeze any layers if pretrained model or backbone is not used
    if not is_trained:
        if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has no effect if "
                "neither pretrained nor pretrained_backbone have been set to True, "
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
            )
        trainable_backbone_layers = max_value

    # by default freeze first blocks
    if trainable_backbone_layers is None:
        trainable_backbone_layers = default_value
    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
        raise ValueError(
            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
        )
    return trainable_backbone_layers