# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Copied from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
"""
Backbone modules.
"""
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoModelForCausalLM, AutoConfig, AutoBackbone
from peft import LoraConfig, get_peft_model, PeftModel

from rfdetr.util.misc import NestedTensor, is_main_process
from rfdetr.models.backbone.base import BackboneBase
from rfdetr.models.backbone.projector import MultiScaleProjector
from rfdetr.models.backbone.dinov2 import DinoV2

__all__ = ["Backbone"]


class Backbone(BackboneBase):
    """DINOv2 encoder combined with a multi-scale projector."""

    def __init__(
        self,
        name: str,
        pretrained_encoder: str = None,
        window_block_indexes: list = None,
        drop_path=0.0,
        out_channels=256,
        out_feature_indexes: list = None,
        projector_scale: list = None,
        use_cls_token: bool = False,
        freeze_encoder: bool = False,
        layer_norm: bool = False,
        target_shape: tuple[int, int] = (640, 640),
        rms_norm: bool = False,
        backbone_lora: bool = False,
        gradient_checkpointing: bool = False,
        load_dinov2_weights: bool = True,
        patch_size: int = 14,
        num_windows: int = 4,
        positional_encoding_size: bool = False,
    ):
        super().__init__()
        # An example name here would be "dinov2_base" or "dinov2_registers_windowed_base".
        # If "registers" is in the name, use_registers is set to True; otherwise it is False.
        # Similarly, if "windowed" is in the name, use_windowed_attn is set to True; otherwise False.
        # The last part of the name is the size, and the first part must be "dinov2".
        name_parts = name.split("_")
        assert name_parts[0] == "dinov2"
        size = name_parts[-1]

        use_registers = False
        if "registers" in name_parts:
            use_registers = True
            name_parts.remove("registers")

        use_windowed_attn = False
        if "windowed" in name_parts:
            use_windowed_attn = True
            name_parts.remove("windowed")

        assert (
            len(name_parts) == 2
        ), "name should be dinov2, then either registers, windowed, both, or none, then the size"

        self.encoder = DinoV2(
            size=name_parts[-1],
            out_feature_indexes=out_feature_indexes,
            shape=target_shape,
            use_registers=use_registers,
            use_windowed_attn=use_windowed_attn,
            gradient_checkpointing=gradient_checkpointing,
            load_dinov2_weights=load_dinov2_weights,
            patch_size=patch_size,
            num_windows=num_windows,
            positional_encoding_size=positional_encoding_size,
        )

        # build encoder + projector as backbone module
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.projector_scale = projector_scale
        assert len(self.projector_scale) > 0
        assert (
            sorted(self.projector_scale) == self.projector_scale
        ), "only support projector scale P3/P4/P5/P6 in ascending order."
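        # Map each pyramid level to the scale factor applied to the encoder's
        # native feature resolution: P3 upsamples by 2x, P4 keeps it, and
        # P5/P6 downsample by 2x/4x (assuming standard FPN level semantics).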
        level2scalefactor = dict(
            P3=2.0,
            P4=1.0,
            P5=0.5,
            P6=0.25,
        )
        scale_factors = [level2scalefactor[lvl] for lvl in self.projector_scale]

        self.projector = MultiScaleProjector(
            in_channels=self.encoder._out_feature_channels,
            out_channels=out_channels,
            scale_factors=scale_factors,
            layer_norm=layer_norm,
            rms_norm=rms_norm,
        )

        self._export = False

    def export(self):
        self._export = True
        self._forward_origin = self.forward
        self.forward = self.forward_export
        if isinstance(self.encoder, PeftModel):
            print("Merging and unloading LoRA weights")
            # merge_and_unload returns the base model with the LoRA weights
            # folded in, so the result must be assigned back to self.encoder
            self.encoder = self.encoder.merge_and_unload()

    def forward(self, tensor_list: NestedTensor):
        """Encode the batch and return one NestedTensor per projector scale."""
        # encoder output is a list of (B, C, H, W) feature maps
        feats = self.encoder(tensor_list.tensors)
        feats = self.projector(feats)
        out = []
        for feat in feats:
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=feat.shape[-2:]).to(torch.bool)[
                0
            ]
            out.append(NestedTensor(feat, mask))
        return out

    def forward_export(self, tensors: torch.Tensor):
        feats = self.encoder(tensors)
        feats = self.projector(feats)
        out_feats = []
        out_masks = []
        for feat in feats:
            # feat: (B, C, H, W)
            b, _, h, w = feat.shape
            out_masks.append(
                torch.zeros((b, h, w), dtype=torch.bool, device=feat.device)
            )
            out_feats.append(feat)
        return out_feats, out_masks

    def get_named_param_lr_pairs(self, args, prefix: str = "backbone.0"):
        num_layers = args.out_feature_indexes[-1] + 1
        backbone_key = "backbone.0.encoder"
        named_param_lr_pairs = {}
        for n, p in self.named_parameters():
            n = prefix + "." + n
            if backbone_key in n and p.requires_grad:
                lr = (
                    args.lr_encoder
                    * get_dinov2_lr_decay_rate(
                        n,
                        lr_decay_rate=args.lr_vit_layer_decay,
                        num_layers=num_layers,
                    )
                    * args.lr_component_decay**2
                )
                wd = args.weight_decay * get_dinov2_weight_decay_rate(n)
                named_param_lr_pairs[n] = {
                    "params": p,
                    "lr": lr,
                    "weight_decay": wd,
                }
        return named_param_lr_pairs


def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if "embeddings" in name:
            layer_id = 0
        elif ".layer." in name and ".residual." not in name:
            layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)


def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
    if (
        ("gamma" in name)
        or ("pos_embed" in name)
        or ("rel_pos" in name)
        or ("bias" in name)
        or ("norm" in name)
        or ("embeddings" in name)
    ):
        weight_decay_rate = 0.0
    return weight_decay_rate
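

# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): prints the
# layer-wise lr multipliers that get_dinov2_lr_decay_rate produces for a
# 12-block ViT. The parameter names below are hypothetical stand-ins for
# what named_parameters() yields; real names depend on the DinoV2 wrapper.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    example_names = [
        "backbone.0.encoder.encoder.embeddings.patch_embeddings.projection.weight",
        "backbone.0.encoder.encoder.layer.0.attention.attention.query.weight",
        "backbone.0.encoder.encoder.layer.11.mlp.fc1.weight",
        "backbone.0.projector.stages.0.weight",
    ]
    for n in example_names:
        rate = get_dinov2_lr_decay_rate(n, lr_decay_rate=0.8, num_layers=12)
        # embeddings decay the most (layer 0); the last block decays the least
        print(f"{rate:.4f}  {n}")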