"""
DEIMv2: Real-Time Object Detection Meets DINOv3
Copyright (c) 2025 The DEIMv2 Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from DINOv3 (https://github.com/facebookresearch/dinov3)
Copyright (c) Meta Platforms, Inc. and affiliates.
This software may be used and distributed in accordance with
the terms of the DINOv3 License Agreement.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from functools import partial
from ..core import register
from .vit_tiny import VisionTransformer
from .dinov3 import DinoVisionTransformer
class SpatialPriorModulev2(nn.Module):
    """Lightweight convolutional pyramid producing detail features at strides 8, 16 and 32.

    The stem downsamples the input by 4 (a stride-2 conv followed by a stride-2
    max-pool). Each following stage halves the resolution again with a stride-2
    3x3 conv. The stride-8 output is returned right after its BatchNorm. The
    stride-16 and stride-32 stages apply GELU to their input before convolving.

    Args:
        inplanes: Channel width of the stem. The three returned maps have
            ``2 * inplanes``, ``4 * inplanes`` and ``4 * inplanes`` channels.
    """

    def __init__(self, inplanes=16):
        super().__init__()
        width_s8 = 2 * inplanes
        width_s16 = 4 * inplanes

        # stride 4: conv(/2) -> BN -> GELU -> max-pool(/2)
        self.stem = nn.Sequential(
            nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(inplanes),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # stride 8: conv(/2) -> BN (activation is deferred to the next stage)
        self.conv2 = nn.Sequential(
            nn.Conv2d(inplanes, width_s8, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(width_s8),
        )
        # stride 16: GELU -> conv(/2) -> BN
        self.conv3 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(width_s8, width_s16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(width_s16),
        )
        # stride 32: GELU -> conv(/2) -> BN, channel count unchanged
        self.conv4 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(width_s16, width_s16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(width_s16),
        )

    def forward(self, x):
        """Run the pyramid and return the stride-8, stride-16 and stride-32 maps as a tuple."""
        feat = self.stem(x)
        pyramid = []
        for stage in (self.conv2, self.conv3, self.conv4):
            feat = stage(feat)
            pyramid.append(feat)
        return tuple(pyramid)
@register()
class DINOv3STAs(nn.Module):
    """Transformer backbone fused with a lightweight spatial-prior (STA) branch.

    Semantic features from selected transformer layers are reshaped to a 2-D token
    grid and resized to strides 8/16/32 of the input. They are then optionally
    concatenated with the convolutional detail features of ``SpatialPriorModulev2``.
    Finally, a 1x1 conv plus SyncBatchNorm projects each scale to ``hidden_dim``
    channels.

    Args:
        name: Backbone identifier. Names containing ``'dinov3'`` build a
            ``DinoVisionTransformer``. Anything else, including ``None``, builds the
            ViT-Tiny ``VisionTransformer``.
        weights_path: Optional checkpoint path. The state dict is loaded onto CPU,
            then copied into the backbone. If the path is ``None`` or missing, the
            backbone keeps its fresh initialisation.
        interaction_indexes: Indexes of the backbone layers to extract. ``None``
            means ``[]``.
        finetune: If False, the backbone is frozen (``requires_grad_(False)``) and
            kept in eval mode, including after later ``.train()`` calls.
        embed_dim: ViT-Tiny width. It is only passed to the ViT-Tiny constructor;
            the width actually used is read back from the built backbone.
        num_heads: ViT-Tiny head count. It is only passed to the ViT-Tiny constructor.
        patch_size: Backbone patch size, used to recover the token grid.
        use_sta: Whether to build and fuse the spatial-prior branch.
        conv_inplane: Stem width of the spatial-prior branch.
        hidden_dim: Output channels for every scale. Defaults to the backbone width.
    """

    def __init__(
        self,
        name=None,
        weights_path=None,
        interaction_indexes=None,
        finetune=True,
        embed_dim=192,
        num_heads=3,
        patch_size=16,
        use_sta=True,
        conv_inplane=16,
        hidden_dim=None,
    ):
        super().__init__()
        # Avoid a shared mutable default: this list is stored on the module and
        # also handed to the ViT-Tiny backbone.
        interaction_indexes = [] if interaction_indexes is None else interaction_indexes

        if name is not None and 'dinov3' in name:
            self.dinov3 = DinoVisionTransformer(name=name)
            state_dict = self._load_checkpoint(weights_path)
            if state_dict is not None:
                self.dinov3.load_state_dict(state_dict)
            else:
                print('Training DINOv3 from scratch...')
        else:
            self.dinov3 = VisionTransformer(
                embed_dim=embed_dim, num_heads=num_heads, return_layers=interaction_indexes
            )
            state_dict = self._load_checkpoint(weights_path)
            if state_dict is not None:
                # ViT-Tiny checkpoints are loaded into the wrapper's inner ``_model``.
                self.dinov3._model.load_state_dict(state_dict)
            else:
                print('Training ViT-Tiny from scratch...')

        embed_dim = self.dinov3.embed_dim
        self.interaction_indexes = interaction_indexes
        self.patch_size = patch_size
        self.finetune = finetune
        if not finetune:
            self.dinov3.eval()
            self.dinov3.requires_grad_(False)

        # init the feature pyramid
        self.use_sta = use_sta
        if use_sta:
            print(f"Using Lite Spatial Prior Module with inplanes={conv_inplane}")
            self.sta = SpatialPriorModulev2(inplanes=conv_inplane)
        else:
            conv_inplane = 0  # no detail channels to concatenate

        # Linear projections. SpatialPriorModulev2 contributes 2x, 4x and 4x
        # conv_inplane channels at strides 8, 16 and 32 respectively.
        hidden_dim = embed_dim if hidden_dim is None else hidden_dim
        self.convs = nn.ModuleList([
            nn.Conv2d(embed_dim + conv_inplane * mult, hidden_dim,
                      kernel_size=1, stride=1, padding=0, bias=False)
            for mult in (2, 4, 4)
        ])
        self.norms = nn.ModuleList([nn.SyncBatchNorm(hidden_dim) for _ in range(3)])

    @staticmethod
    def _load_checkpoint(weights_path):
        """Return the state dict stored at ``weights_path`` mapped to CPU, or None if unavailable."""
        if weights_path is None:
            return None
        if not os.path.exists(weights_path):
            # A path was configured but is wrong: say so instead of silently
            # falling back to random initialisation.
            print(f'Warning: checkpoint not found at {weights_path}.')
            return None
        print(f'Loading ckpt from {weights_path}...')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto the module's parameters.
        return torch.load(weights_path, map_location='cpu')

    def train(self, mode=True):
        """Set train/eval mode while keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override,
        it would switch a ``finetune=False`` backbone back to training mode.
        """
        super().train(mode)
        if not getattr(self, 'finetune', True):
            self.dinov3.eval()
        return self

    def forward(self, x):
        """Return fused feature maps ``(c2, c3, c4)``, each with ``hidden_dim`` channels.

        Args:
            x: Image batch indexed as ``[B, C, H, W]`` (``x.shape[0]``, ``[2]`` and
                ``[3]`` are read).

        Returns:
            Three tensors at spatial sizes derived from ``(H // 16, W // 16)``,
            scaled by 2, 1 and 0.5 (strides 8, 16 and 32) when three layers are used.
        """
        bs = x.shape[0]
        # Output sizes are anchored at stride 16 so they line up with the STA
        # branch; the token grid itself follows the backbone patch size.
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        H_toks, W_toks = x.shape[2] // self.patch_size, x.shape[3] // self.patch_size

        if self.interaction_indexes and not isinstance(self.dinov3, VisionTransformer):
            all_layers = self.dinov3.get_intermediate_layers(
                x, n=self.interaction_indexes, return_class_token=True
            )
        else:
            all_layers = self.dinov3(x)

        if len(all_layers) == 1:
            # Only one layer available: reuse it for all three scales.
            all_layers = [all_layers[0]] * 3

        # With three layers the scale factors are 2, 1 and 0.5.
        num_scales = len(all_layers) - 2
        sem_feats = []
        for i, (feat, _cls_token) in enumerate(all_layers):
            # The view requires feat to hold exactly H_toks * W_toks patch tokens
            # along dim 1 (presumably [B, N, D]) -> [B, D, H_toks, W_toks].
            sem_feat = feat.transpose(1, 2).view(bs, -1, H_toks, W_toks).contiguous()
            scale = 2 ** (num_scales - i)
            size = [int(H_c * scale), int(W_c * scale)]
            sem_feats.append(
                F.interpolate(sem_feat, size=size, mode="bilinear", align_corners=False)
            )

        # Fusion: concatenate semantic and detail features along channels.
        if self.use_sta:
            detail_feats = self.sta(x)
            fused_feats = [
                torch.cat([sem, det], dim=1) for sem, det in zip(sem_feats, detail_feats)
            ]
        else:
            fused_feats = sem_feats

        c2 = self.norms[0](self.convs[0](fused_feats[0]))
        c3 = self.norms[1](self.convs[1](fused_feats[1]))
        c4 = self.norms[2](self.convs[2](fused_feats[2]))
        return c2, c3, c4