""" DEIMv2: Real-Time Object Detection Meets DINOv3 Copyright (c) 2025 The DEIMv2 Authors. All Rights Reserved. --------------------------------------------------------------------------------- Modified from DINOv3 (https://github.com/facebookresearch/dinov3) Copyright (c) Meta Platforms, Inc. and affiliates. This software may be used and distributed in accordance with the terms of the DINOv3 License Agreement. """ import os import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp from functools import partial from ..core import register from .vit_tiny import VisionTransformer from .dinov3 import DinoVisionTransformer class SpatialPriorModulev2(nn.Module): def __init__(self, inplanes=16): super().__init__() # 1/4 self.stem = nn.Sequential( *[ nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), nn.SyncBatchNorm(inplanes), nn.GELU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1), ] ) # 1/8 self.conv2 = nn.Sequential( *[ nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), nn.SyncBatchNorm(2 * inplanes), ] ) # 1/16 self.conv3 = nn.Sequential( *[ nn.GELU(), nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), nn.SyncBatchNorm(4 * inplanes), ] ) # 1/32 self.conv4 = nn.Sequential( *[ nn.GELU(), nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), nn.SyncBatchNorm(4 * inplanes), ] ) def forward(self, x): c1 = self.stem(x) c2 = self.conv2(c1) # 1/8 c3 = self.conv3(c2) # 1/16 c4 = self.conv4(c3) # 1/32 return c2, c3, c4 @register() class DINOv3STAs(nn.Module): def __init__( self, name=None, weights_path=None, interaction_indexes=[], finetune=True, embed_dim=192, num_heads=3, patch_size=16, use_sta=True, conv_inplane=16, hidden_dim=None, ): super(DINOv3STAs, self).__init__() if 'dinov3' in name: self.dinov3 = DinoVisionTransformer(name=name) if weights_path is not None and os.path.exists(weights_path): print(f'Loading ckpt from {weights_path}...') self.dinov3.load_state_dict(torch.load(weights_path)) else: print('Training DINOv3 from scratch...') else: self.dinov3 = VisionTransformer(embed_dim=embed_dim, num_heads=num_heads, return_layers=interaction_indexes) if weights_path is not None and os.path.exists(weights_path): print(f'Loading ckpt from {weights_path}...') self.dinov3._model.load_state_dict(torch.load(weights_path)) else: print('Training ViT-Tiny from scratch...') embed_dim = self.dinov3.embed_dim self.interaction_indexes = interaction_indexes self.patch_size = patch_size if not finetune: self.dinov3.eval() self.dinov3.requires_grad_(False) # init the feature pyramid self.use_sta = use_sta if use_sta: print(f"Using Lite Spatial Prior Module with inplanes={conv_inplane}") self.sta = SpatialPriorModulev2(inplanes=conv_inplane) else: conv_inplane = 0 # linear projection hidden_dim = hidden_dim if hidden_dim is not None else embed_dim self.convs = nn.ModuleList([ nn.Conv2d(embed_dim + conv_inplane*2, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), nn.Conv2d(embed_dim + conv_inplane*4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), nn.Conv2d(embed_dim + conv_inplane*4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False) ]) # norm self.norms = nn.ModuleList([ nn.SyncBatchNorm(hidden_dim), nn.SyncBatchNorm(hidden_dim), nn.SyncBatchNorm(hidden_dim) ]) def forward(self, x): # Code for matching with oss H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 H_toks, W_toks = x.shape[2] // self.patch_size, 
    def forward(self, x):
        # Grid sizes kept for parity with the OSS implementation; H_toks/W_toks
        # are currently unused below.
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        H_toks, W_toks = x.shape[2] // self.patch_size, x.shape[3] // self.patch_size
        bs, C, h, w = x.shape

        if len(self.interaction_indexes) > 0 and not isinstance(self.dinov3, VisionTransformer):
            all_layers = self.dinov3.get_intermediate_layers(
                x, n=self.interaction_indexes, return_class_token=True
            )
        else:
            all_layers = self.dinov3(x)

        if len(all_layers) == 1:
            # Repeat the single layer for all three scales.
            all_layers = [all_layers[0], all_layers[0], all_layers[0]]

        sem_feats = []
        num_scales = len(all_layers) - 2
        for i, sem_feat in enumerate(all_layers):
            feat, _ = sem_feat  # (patch tokens, class token)
            sem_feat = feat.transpose(1, 2).view(bs, -1, H_c, W_c).contiguous()  # [B, D, H, W]
            resize_H, resize_W = int(H_c * 2 ** (num_scales - i)), int(W_c * 2 ** (num_scales - i))
            sem_feat = F.interpolate(sem_feat, size=[resize_H, resize_W], mode="bilinear", align_corners=False)
            sem_feats.append(sem_feat)

        # Fuse the semantic (ViT) features with the detail (CNN) priors per scale.
        fused_feats = []
        if self.use_sta:
            detail_feats = self.sta(x)
            for sem_feat, detail_feat in zip(sem_feats, detail_feats):
                fused_feats.append(torch.cat([sem_feat, detail_feat], dim=1))
        else:
            fused_feats = sem_feats

        c2 = self.norms[0](self.convs[0](fused_feats[0]))
        c3 = self.norms[1](self.convs[1](fused_feats[1]))
        c4 = self.norms[2](self.convs[2](fused_feats[2]))
        return c2, c3, c4
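

# ---------------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It exercises
# only SpatialPriorModulev2. SyncBatchNorm only synchronises under
# torch.distributed, and some PyTorch builds refuse CPU input without a process
# group, so `_to_local_bn` (a hypothetical helper introduced here for
# illustration) swaps it for plain BatchNorm2d to allow a quick local check.
def _to_local_bn(module: nn.Module) -> nn.Module:
    """Recursively replace SyncBatchNorm with BatchNorm2d (local testing only)."""
    for child_name, child in module.named_children():
        if isinstance(child, nn.SyncBatchNorm):
            bn = nn.BatchNorm2d(child.num_features, eps=child.eps, momentum=child.momentum)
            bn.load_state_dict(child.state_dict())  # parameter/buffer keys are compatible
            setattr(module, child_name, bn)
        else:
            _to_local_bn(child)
    return module


if __name__ == "__main__":
    spm = _to_local_bn(SpatialPriorModulev2(inplanes=16)).eval()
    dummy = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        c2, c3, c4 = spm(dummy)
    # Strides 1/8, 1/16, 1/32 with 2x/4x/4x inplanes channels:
    # [1, 32, 80, 80], [1, 64, 40, 40], [1, 64, 20, 20]
    print(c2.shape, c3.shape, c4.shape)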