| | |
| | |
| | |
| | import os |
| | current_dir_path = os.path.dirname(__file__) |
| | import torch |
| | from torch import nn |
| |
|
class Dinov2Backbone(nn.Module):
    """Thin ``nn.Module`` wrapper around a DINOv2 encoder loaded via torch.hub.

    The encoder is loaded from a local ``dinov2`` checkout located next to
    this file's parent directory (``<this dir>/../dinov2``).

    Attributes set in ``__init__``:
        name: hub entry-point name used to build the encoder.
        encoder: the module returned by ``torch.hub.load``.
        patch_size: copied from ``encoder.patch_size``.
        embed_dim: copied from ``encoder.embed_dim``.
    """

    def __init__(self, name: str = 'dinov2_vitb14', pretrained: bool = False, *args, **kwargs):
        """Build the backbone.

        Args:
            name: entry point in the local dinov2 ``hubconf`` to instantiate.
            pretrained: forwarded to the hub entry point as ``pretrained``.
            *args, **kwargs: accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.name = name
        # Use os.path.join rather than string concatenation: when
        # ``current_dir_path`` is empty (relative ``__file__``), plain
        # concatenation would yield '/../dinov2', i.e. a path at the
        # filesystem root instead of one relative to the working directory.
        hub_dir = os.path.join(current_dir_path, os.pardir, 'dinov2')
        self.encoder = torch.hub.load(hub_dir, self.name, pretrained=pretrained, source='local')
        self.patch_size = self.encoder.patch_size
        self.embed_dim = self.encoder.embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a RGB image using a ViT-backbone
        Args:
            - x: torch.Tensor of shape [bs,3,w,h]
        Return:
            - y: torch.Tensor of shape [bs,k,d] - image in patchified mode
        Raises:
            ValueError: if ``x`` is not a 4-dimensional tensor.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if x.ndim != 4:
            raise ValueError(
                f"expected a 4-D tensor of shape [bs,3,w,h], got shape {tuple(x.shape)}"
            )
        # get_intermediate_layers returns a sequence of outputs; only the
        # first entry is used here.
        y = self.encoder.get_intermediate_layers(x)[0]
        return y
| |
|
| |
|