# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict
import os

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from util.misc import NestedTensor, is_main_process, clean_state_dict
from .position_encoding import build_position_encoding
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from .swin_transformer import build_swin_transformer


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # Buffers (not parameters): they are saved/loaded but never optimised.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # This module has no 'num_batches_tracked' buffer; drop the key so it
        # is not reported as unexpected when loading a regular BatchNorm state.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):
    """Wrap a backbone so that it returns intermediate feature maps.

    The returned layers are currently hard-coded to ``layer2``, ``layer3`` and
    ``layer4`` (output keys ``"1"``, ``"2"``, ``"3"``); ``return_interm_layers``
    is accepted for interface compatibility but is not consulted.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            # Freeze everything when the backbone is not trained; otherwise
            # freeze only parameters outside layer2/layer3/layer4.
            if not train_backbone or (
                    'layer2' not in name and 'layer3' not in name and 'layer4' not in name):
                parameter.requires_grad_(False)
        # if return_interm_layers:
        #     return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        # else:
        #     return_layers = {'layer4': "0"}

        # swin
        return_interm_indices = [1, 2, 3]
        # Maps e.g. {"layer2": "1", "layer3": "2", "layer4": "3"}.
        return_layers = {
            "layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)
            for idx, layer_index in enumerate(return_interm_indices)
        }
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and pair every feature map with a resized padding mask."""
        xs = self.body(tensor_list.tensors)
        # xs = OrderedDict()
        # xs['0'] = self.sam_encoder(tensor_list.tensors)
        m = tensor_list.mask
        assert m is not None
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            # Resize the mask to the spatial size of this feature map.
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        """Build the torchvision ResNet called ``name``.

        Raises:
            NotImplementedError: if ``name`` is not one of resnet18/34/50/101.
        """
        if name not in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
            # Previously an unknown name fell through with `backbone` unbound.
            raise NotImplementedError("Unknown backbone {}".format(name))
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    """Chain a backbone (index 0) with a position-embedding module (index 1)."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, positional_encodings)`` as two parallel lists."""
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for _, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


# def build_backbone(args):
#     position_embedding = build_position_encoding(args)
#     # sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
#     # position_embedding = sam.prompt_encoder.get_dense_pe()  # (bs, 256, 64, 64)
#     train_backbone = args.lr_backbone > 0
#     return_interm_layers = args.masks
#     backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
#     model = Joiner(backbone, position_embedding)
#     model.num_channels = backbone.num_channels
#     return model


# Output channels of layer1..layer4 for the bottleneck ResNets (resnet50/101).
_RESNET_BOTTLENECK_CHANNELS = [256, 512, 1024, 2048]

# Checkpoint file expected inside ``args.backbone_dir`` for each Swin variant.
_SWIN_PRETRAINED_FILES = {
    'swin_T_224_1k': 'swin_tiny_patch4_window7_224.pth',
    'swin_B_224_22k': 'swin_base_patch4_window7_224_22kto1k.pth',
    'swin_B_384_22k': 'swin_base_patch4_window12_384.pth',
    'swin_L_384_22k': 'swin_large_patch4_window12_384_22k.pth',
}


def _load_swin_pretrained(backbone, backbone_name, pretrained_dir, dilation):
    """Load pretrained Swin weights from ``pretrained_dir`` into ``backbone``.

    Keys containing ``'head'`` are skipped, as are ``'layers.3'`` keys when
    ``dilation`` is truthy. Loading is non-strict and the result is printed.

    Raises:
        KeyError: if no checkpoint file name is registered for ``backbone_name``.
    """
    try:
        filename = _SWIN_PRETRAINED_FILES[backbone_name]
    except KeyError:
        raise KeyError(
            "No pretrained checkpoint file is registered for backbone {} "
            "(known: {})".format(backbone_name, sorted(_SWIN_PRETRAINED_FILES))
        ) from None
    pretrainedpath = os.path.join(pretrained_dir, filename)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # point backbone_dir at checkpoints from a trusted source.
    checkpoint = torch.load(pretrainedpath, map_location='cpu', weights_only=False)['model']

    def key_select_function(keyname):
        if 'head' in keyname:
            return False
        if dilation and 'layers.3' in keyname:
            return False
        return True

    _tmp_st = OrderedDict(
        (k, v) for k, v in clean_state_dict(checkpoint).items() if key_select_function(k))
    _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
    print(str(_tmp_st_output))


def build_backbone(args):
    """Build the backbone joined with its position encoding.

    Useful args:
        - backbone: backbone name (resnet50/101 or one of the swin_* variants)
        - lr_backbone: must be > 0
        - dilation
        - backbone_dir (optional): directory holding Swin pretrained checkpoints

    The following are currently hard-coded rather than read from ``args``:
        - return_interm_indices: [1, 2, 3] (available: [0,1,2,3], [1,2,3], [3])
        - backbone_freeze_keywords: None
        - use_checkpoint: True (for swin only for now)

    Returns:
        A ``Joiner`` whose ``num_channels`` attribute is the list of channel
        counts of the returned feature maps.

    Raises:
        ValueError: if ``args.lr_backbone`` is not > 0.
        NotImplementedError: if ``args.backbone`` is not supported.
        KeyError: if a Swin checkpoint is requested but none is registered.
    """
    backbone_freeze_keywords = None
    use_checkpoint = True

    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    if not train_backbone:
        raise ValueError("Please set lr_backbone > 0")
    return_interm_indices = [1, 2, 3]
    assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

    if args.backbone in ['resnet50', 'resnet101']:
        # Backbone takes (name, train_backbone, return_interm_layers, dilation);
        # the previous call passed a non-existent `batch_norm` keyword and
        # swapped the positional arguments.
        backbone = Backbone(args.backbone, train_backbone,
                            return_interm_layers=True, dilation=args.dilation)
        # BackboneBase returns layer2..layer4, so report one channel count per
        # returned stage instead of the single int stored on the backbone.
        bb_num_channels = _RESNET_BOTTLENECK_CHANNELS[4 - len(return_interm_indices):]
    elif args.backbone in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k',
                           'swin_L_224_22k', 'swin_L_384_22k']:
        # The image size is the second-to-last token, e.g. swin_B_384_22k -> 384.
        pretrain_img_size = int(args.backbone.split('_')[-2])
        backbone = build_swin_transformer(args.backbone,
                                          pretrain_img_size=pretrain_img_size,
                                          out_indices=tuple(return_interm_indices),
                                          dilation=args.dilation,
                                          use_checkpoint=use_checkpoint)

        # freeze some layers
        if backbone_freeze_keywords is not None:
            for name, parameter in backbone.named_parameters():
                if any(keyword in name for keyword in backbone_freeze_keywords):
                    parameter.requires_grad_(False)

        # A missing attribute and an explicit None both mean "no checkpoint".
        pretrained_dir = getattr(args, "backbone_dir", None)
        if pretrained_dir is not None:
            _load_swin_pretrained(backbone, args.backbone, pretrained_dir, args.dilation)
        bb_num_channels = backbone.num_features[4 - len(return_interm_indices):]
    # elif args.backbone in ['convnext_xlarge_22k']:
    #     backbone = build_convnext(modelname=args.backbone, pretrained=True, out_indices=tuple(return_interm_indices),backbone_dir=args.backbone_dir)
    #     bb_num_channels = backbone.dims[4 - len(return_interm_indices):]
    else:
        raise NotImplementedError("Unknown backbone {}".format(args.backbone))

    # Check the type first so a non-list fails with this message, not in len().
    assert isinstance(bb_num_channels, list), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
    assert len(bb_num_channels) == len(return_interm_indices), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

    model = Joiner(backbone, position_embedding)
    model.num_channels = bb_num_channels
    return model