import os
from contextlib import nullcontext

import torch
from mmengine.model import BaseModule  # noqa: F401 -- kept for import-site compatibility
from torch import nn

from mmseg.registry import MODELS

# Project root (two levels up from this file) and the local torch.hub
# directory that holds the vendored facebookresearch/dinov2 repo snapshot.
_DINOV2_MMSEG_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..'))
_DINOV2_TORCHHUB_DIR = os.path.join(
    _DINOV2_MMSEG_ROOT, 'torchhub', 'facebookresearch_dinov2_main')


@MODELS.register_module()
class DINOv2(nn.Module):
    """Backbone wrapper around a locally-vendored DINOv2 ViT.

    Instantiates ``dinov2_vitl14`` from the local torch.hub snapshot and
    exposes its last 4 intermediate transformer layers as 2D feature maps.

    Args:
        version (str): Model size. Only ``'large'`` is implemented.
        freeze (bool): If True, run the backbone under ``torch.no_grad()``
            in :meth:`forward` (no gradients flow into DINOv2).
        load_from (str | None): Optional checkpoint path. Two layouts are
            supported (see :meth:`_load_checkpoint`).

    Raises:
        NotImplementedError: If ``version`` is not ``'large'``.
    """

    def __init__(self, version='large', freeze=False, load_from=None):
        super().__init__()
        if version == 'large':
            # pretrained=False: weights come from `load_from` (or stay random).
            self.dinov2 = torch.hub.load(
                _DINOV2_TORCHHUB_DIR, 'dinov2_vitl14',
                source='local', pretrained=False)
        else:
            raise NotImplementedError(
                f"DINOv2 version '{version}' is not supported; "
                "only 'large' is implemented")
        if load_from is not None:
            self._load_checkpoint(load_from)
        self.freeze = freeze

    def _load_checkpoint(self, load_from):
        """Load DINOv2 weights from one of two known checkpoint layouts.

        * ``depth_anything_vitl14.pth``: a flat state dict whose backbone
          keys carry a ``pretrained.`` prefix.
        * anything else: an mmseg checkpoint with a ``state_dict`` entry
          whose backbone keys carry a ``backbone.dinov2.`` prefix.

        In both cases only the prefixed keys are kept, the prefix is
        stripped, and the result is loaded into ``self.dinov2``.
        """
        print(load_from)
        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary code via pickle; consider weights_only=True where the
        # installed torch version supports it.
        checkpoint = torch.load(load_from, map_location='cpu')
        if load_from.split('/')[-1] == 'depth_anything_vitl14.pth':
            state, prefix = checkpoint, 'pretrained.'
        else:
            state, prefix = checkpoint['state_dict'], 'backbone.dinov2.'
        # Same substring filter as before: 'pretrained' / 'backbone.dinov2'.
        marker = prefix.rstrip('.')
        filtered = {
            key.replace(prefix, ''): value
            for key, value in state.items() if marker in key
        }
        self.dinov2.load_state_dict(filtered)

    def forward(self, inputs):
        """Extract 4 intermediate feature maps from the DINOv2 backbone.

        Args:
            inputs (Tensor): Batch of images, shape ``(B, 3, h, w)``;
                ``h`` and ``w`` are assumed divisible by the ViT patch
                size 14.

        Returns:
            list[Tensor]: 4 feature maps, each ``(B, C, h // 14, w // 14)``.
        """
        B, _, h, w = inputs.shape
        # nullcontext (not set_grad_enabled(True)) so an outer no_grad()
        # scope is respected when freeze is False.
        grad_ctx = torch.no_grad() if self.freeze else nullcontext()
        with grad_ctx:
            features = self.dinov2.get_intermediate_layers(inputs, 4)
        outs = []
        for feature in features:
            C = feature.shape[-1]
            # (B, N, C) token sequence -> (B, C, h/14, w/14) feature map.
            feature = feature.permute(0, 2, 1).reshape(
                B, C, h // 14, w // 14).contiguous()
            outs.append(feature)
        return outs