| |
| |
| |
| import os |
# Directory that contains this module; Dinov2Backbone locates the local
# dinov2 hub checkout relative to it.
current_dir_path = os.path.split(__file__)[0]
| import torch |
| from torch import nn |
|
|
class Dinov2Backbone(nn.Module):
    """DINOv2 ViT loaded from a local ``torch.hub`` checkout, used as a patch-token encoder.

    The hub repository is resolved as ``<directory of this file>/../dinov2`` and
    loaded with ``source='local'``.

    Args:
        name: Hub entrypoint passed to ``torch.hub.load`` (default ``'dinov2_vitb14'``).
        pretrained: Forwarded to the hub entrypoint as ``pretrained=``.
        *args, **kwargs: Accepted so callers may pass extra options, but ignored.

    Attributes:
        name: The hub entrypoint name.
        encoder: The module returned by ``torch.hub.load``.
        patch_size: Copied from ``encoder.patch_size``.
        embed_dim: Copied from ``encoder.embed_dim``.
    """

    def __init__(self, name='dinov2_vitb14', pretrained=False, *args, **kwargs):
        super().__init__()
        self.name = name
        # os.path.join rather than `dir + '/../dinov2'`: if dirname(__file__) is
        # '' the concatenation yields the absolute path '/../dinov2' (== '/dinov2'),
        # whereas join correctly yields the relative '../dinov2'.
        repo_dir = os.path.join(current_dir_path, os.pardir, 'dinov2')
        self.encoder = torch.hub.load(repo_dir, self.name, pretrained=pretrained, source='local')
        self.patch_size = self.encoder.patch_size
        self.embed_dim = self.encoder.embed_dim

    def forward(self, x):
        """Encode an RGB image batch into patch tokens with the ViT backbone.

        Args:
            x: 4-D tensor, expected [bs, 3, h, w]. Only the rank is checked here;
                the spatial size presumably has to be compatible with
                ``self.patch_size`` (TODO confirm against the dinov2 patch embed).

        Returns:
            Tensor of shape [bs, k, d] (image in patchified form): the first
            element of ``self.encoder.get_intermediate_layers(x)``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if x.ndim != 4:
            raise ValueError(
                f"Dinov2Backbone expects a 4-D tensor [bs, 3, h, w], got shape {tuple(x.shape)}"
            )
        # get_intermediate_layers returns an indexable collection of outputs
        # (presumably one per requested layer); keep the first entry.
        return self.encoder.get_intermediate_layers(x)[0]
|
|
|
|