Spaces:
Sleeping
Sleeping
| import warnings | |
| from typing import Callable, Dict, List, Optional, Union | |
| from torch import nn, Tensor | |
| from torchvision.ops import misc as misc_nn_ops | |
| from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool | |
| from torchvision.models import resnet | |
| from torchvision.models._api import _get_enum_from_fn, WeightsEnum | |
| from torchvision.models._utils import handle_legacy_interface, IntermediateLayerGetter | |
class BackboneWithFPN(nn.Module):
    """Attach a FeaturePyramidNetwork on top of an arbitrary backbone.

    Intermediate activations are harvested from ``backbone`` with an
    ``IntermediateLayerGetter`` (driven by ``return_layers``) and then fed
    through an FPN, so ``forward`` yields a dict of named feature maps that
    all share the same channel count.

    Args:
        backbone: network to extract features from.
        return_layers: maps a module name inside ``backbone`` to the key
            under which its activation is returned.
        in_channels_list: channel count of each returned activation, in the
            same order as ``return_layers``.
        out_channels: channel count of every FPN output map.
        extra_blocks: optional block producing extra pyramid level(s);
            defaults to ``LastLevelMaxPool`` when ``None``.
        norm_layer: optional normalization layer for the FPN convolutions.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            # Fall back to an extra max-pool level when no block was supplied.
            extra_blocks=LastLevelMaxPool() if extra_blocks is None else extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone, then the FPN, returning the named feature maps."""
        features = self.body(x)
        return self.fpn(features)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """Build a ResNet-FPN backbone by name.

    Args:
        backbone_name: name of a constructor on ``torchvision.models.resnet``
            (e.g. ``"resnet50"``).
        weights: pretrained weights to load into the ResNet, or ``None``.
        norm_layer: normalization layer for the ResNet; defaults to
            ``FrozenBatchNorm2d`` (common for detection fine-tuning).
        trainable_layers: number of ResNet stages to leave unfrozen, counted
            from the top (0-5).
        returned_layers: which ``layerN`` stages feed the FPN; defaults to all.
        extra_blocks: optional extra FPN level(s).

    Returns:
        The frozen-as-requested ResNet wrapped in a :class:`BackboneWithFPN`.
    """
    # Resolve the constructor by name, e.g. "resnet50" -> resnet.resnet50.
    builder = resnet.__dict__[backbone_name]
    net = builder(weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(net, trainable_layers, returned_layers, extra_blocks)
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze part of a ResNet and wrap it with an FPN.

    Args:
        backbone: the ResNet to adapt (modified in place: some parameters
            get ``requires_grad`` set to ``False``).
        trainable_layers: how many stages stay trainable, counted from the
            top (``layer4`` first); 5 also unfreezes the stem (``conv1``/``bn1``).
        returned_layers: which ``layerN`` outputs feed the FPN (subset of 1-4);
            defaults to all four.
        extra_blocks: optional extra FPN level(s); defaults to ``LastLevelMaxPool``.
        norm_layer: optional normalization layer for the FPN convolutions.

    Returns:
        A :class:`BackboneWithFPN` with 256 output channels per level.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` is out of range.
    """
    if not 0 <= trainable_layers <= 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    # Stages are unfrozen from the top (layer4) downwards.
    stages = ["layer4", "layer3", "layer2", "layer1", "conv1"]
    layers_to_train = stages[:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        # Freeze every parameter that does not belong to a trainable stage.
        if not any(name.startswith(layer) for layer in layers_to_train):
            parameter.requires_grad_(False)

    extra_blocks = extra_blocks if extra_blocks is not None else LastLevelMaxPool()

    returned_layers = returned_layers if returned_layers is not None else [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(idx) for idx, k in enumerate(returned_layers)}

    # layer1 produces inplanes // 8 channels; every later stage doubles that.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
| def _validate_trainable_layers( | |
| is_trained: bool, | |
| trainable_backbone_layers: Optional[int], | |
| max_value: int, | |
| default_value: int, | |
| ) -> int: | |
| # don't freeze any layers if pretrained model or backbone is not used | |
| if not is_trained: | |
| if trainable_backbone_layers is not None: | |
| warnings.warn( | |
| "Changing trainable_backbone_layers has no effect if " | |
| "neither pretrained nor pretrained_backbone have been set to True, " | |
| f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable" | |
| ) | |
| trainable_backbone_layers = max_value | |
| # by default freeze first blocks | |
| if trainable_backbone_layers is None: | |
| trainable_backbone_layers = default_value | |
| if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value: | |
| raise ValueError( | |
| f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} " | |
| ) | |
| return trainable_backbone_layers |