# Source: Prior2DSM/src/dinov3/eval/detection/models/backbone.py
# Uploaded by osherr in "Upload 222 files" (commit bc90483, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
# ------------------------------------------------------------------------
# Plain-DETR
# Copyright (c) 2023 Xi'an Jiaotong University & Microsoft Research Asia.
# Licensed under The MIT License [see LICENSE for details]
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
import logging
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..util.misc import NestedTensor
from .position_encoding import build_position_encoding
from .utils import LayerNorm2D
from .windows import WindowsWrapper
# Module-level logger bound to the shared "dinov3" logger name; build_backbone uses it
# to report whether the backbone is wrapped with windows.
logger = logging.getLogger("dinov3")
class DINOBackbone(nn.Module):
    """Expose intermediate outputs of a DINO-style model as a single detection feature map.

    The wrapped model must provide ``patch_size``, ``n_blocks``, ``get_intermediate_layers``
    and either a per-block ``embed_dims`` sequence or a single ``embed_dim``. The selected
    intermediate outputs are optionally normalised with ``LayerNorm2D`` and concatenated
    along dim 1 into one feature map whose channel count is ``num_channels[0]``.

    Args:
        backbone_model: the model to wrap.
        train_backbone: when False, every backbone parameter is frozen.
        blocks_to_train: optional block identifiers. When non-empty, a parameter stays
            trainable only if its name contains ``.{b}.`` for some entry ``b``.
            Parameters whose name contains ``mask_token`` are always frozen.
        layers_to_use: an int selects the last ``layers_to_use`` blocks. A list gives
            explicit block indices. The value is forwarded as ``n`` to
            ``get_intermediate_layers``.
        use_layernorm: apply one ``LayerNorm2D`` per selected output before concatenation.
    """

    def __init__(
        self,
        backbone_model: nn.Module,
        train_backbone: bool,
        blocks_to_train: Optional[List[str]] = None,
        layers_to_use: Union[int, List] = 1,
        use_layernorm: bool = True,
    ):
        super().__init__()
        self.backbone = backbone_model
        self.blocks_to_train = blocks_to_train
        self.patch_size = self.backbone.patch_size
        self.use_layernorm = use_layernorm

        for name, parameter in self.backbone.named_parameters():
            # Without a block filter, every parameter passes the block condition.
            train_condition = any(f".{b}." in name for b in self.blocks_to_train) if self.blocks_to_train else True
            if (not train_backbone) or "mask_token" in name or (not train_condition):
                parameter.requires_grad_(False)

        self.strides = [self.backbone.patch_size]

        # Work out the embedding dim of each intermediate output that will be taken.
        n_all_layers = self.backbone.n_blocks
        blocks_to_take = (
            range(n_all_layers - layers_to_use, n_all_layers) if isinstance(layers_to_use, int) else layers_to_use
        )
        # Models without per-block ``embed_dims`` reuse ``embed_dim`` for every block. The
        # fallback is resolved lazily: a ``getattr`` default is evaluated eagerly and would raise
        # AttributeError for models that define ``embed_dims`` but not ``embed_dim``.
        embed_dims = getattr(self.backbone, "embed_dims", None)
        if embed_dims is None:
            embed_dims = [self.backbone.embed_dim] * n_all_layers
        # Keep only the selected blocks, in increasing block-index order.
        embed_dims = [embed_dims[i] for i in range(n_all_layers) if i in blocks_to_take]

        if self.use_layernorm:
            self.layer_norms = nn.ModuleList([LayerNorm2D(embed_dim) for embed_dim in embed_dims])
        self.num_channels = [sum(embed_dims)]
        self.layers_to_use = layers_to_use

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and return a one-element list holding the concatenated features.

        The padding mask of ``tensor_list`` is resized (``F.interpolate`` default mode) to the
        feature map's last two spatial dims and attached to the output ``NestedTensor``.
        """
        # NOTE(review): with reshape=True the outputs are presumably (B, C, h, w); dim 1 is
        # treated as the channel axis, consistent with LayerNorm2D(embed_dim) above.
        xs = self.backbone.get_intermediate_layers(tensor_list.tensors, n=self.layers_to_use, reshape=True)
        if self.use_layernorm:
            xs = [ln(x).contiguous() for ln, x in zip(self.layer_norms, xs)]
        xs = [torch.cat(xs, dim=1)]

        m = tensor_list.mask
        assert m is not None
        out: List[NestedTensor] = []
        for x in xs:
            # m[None] adds a leading axis so interpolate resizes only the trailing spatial dims;
            # [0] removes it again. Presumably m is (B, H, W) — confirm against callers.
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out.append(NestedTensor(x, mask))
        return out
class BackboneWithPositionEncoding(nn.Sequential):
    """Join a feature backbone (child 0) with a position-embedding module (child 1).

    The backbone's ``strides`` and ``num_channels`` are copied onto this module so callers
    can query them without reaching into the children.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, encodings)`` for every feature level produced by the backbone.

        ``encodings[k]`` comes from ``self[1][k]`` applied to ``features[k]`` and is cast to
        the dtype of that level's tensors.
        """
        backbone, position_embedding = self[0], self[1]
        features: List[NestedTensor] = list(backbone(tensor_list))
        encodings = []
        for level, feature in enumerate(features):
            encoding = position_embedding[level](feature)
            encodings.append(encoding.to(feature.tensors.dtype))
        return features, encodings
def build_backbone(backbone_model, args):
    """Build the frozen DINO backbone, optionally window-wrapped, joined with position encoding.

    Reads ``blocks_to_train``, ``layers_to_use``, ``backbone_use_layernorm`` and
    ``n_windows_sqrt`` from ``args``. ``args`` is also passed to
    ``build_position_encoding``. When ``n_windows_sqrt`` is positive, the backbone is wrapped
    in a ``WindowsWrapper`` with that many windows along each axis.
    """
    pos_encoder = build_position_encoding(args)
    # The backbone is never trained here; DINOBackbone freezes all its parameters when
    # train_backbone is False.
    feature_extractor = DINOBackbone(
        backbone_model=backbone_model,
        train_backbone=False,
        blocks_to_train=args.blocks_to_train,
        layers_to_use=args.layers_to_use,
        use_layernorm=args.backbone_use_layernorm,
    )
    n_windows = args.n_windows_sqrt
    if n_windows > 0:
        logger.info(f"Wrapping with {n_windows} x {n_windows} windows")
        feature_extractor = WindowsWrapper(
            feature_extractor,
            n_windows_w=n_windows,
            n_windows_h=n_windows,
            patch_size=feature_extractor.patch_size,
        )
    else:
        logger.info("Not wrapping with windows")
    return BackboneWithPositionEncoding(feature_extractor, pos_encoder)