File size: 956 Bytes
c71037b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch import nn
from mapminer import models
from torchvision.ops import MLP

class LeJEPA(nn.Module):
    """DINOv3 encoder followed by an MLP projection head.

    Every view of every sample is encoded, the encoder's
    ``x_norm_clstoken`` features are passed through the projection MLP,
    and both the raw features and their projection are returned regrouped
    per sample and per view.
    """

    def __init__(self, out_dims=1024):
        """Build the encoder and the projection head.

        Args:
            out_dims: Width of the last layer of the projection MLP.
        """
        super().__init__()
        self.encoder = models.DINOv3(architecture="vit-l-sat", pretrained=True)
        # The head's input width (1024) has to equal the encoder's feature
        # width -- presumably the ViT-L embedding size; confirm before
        # switching to another architecture.
        self.mlp = MLP(1024, [2048, 2048, out_dims], norm_layer=nn.BatchNorm1d)

    def forward(self, x):
        """Encode every view and project the resulting features.

        Args:
            x: Tensor of shape ``(N, V, C, H, W)`` -- ``N`` samples with
                ``V`` views each.

        Returns:
            A tuple ``(features, proj)``:

            * If the encoder features are 2-D ``(N*V, D)`` (expected for the
              CLS token selected here -- confirm against the encoder),
              ``features`` is ``(N, V, D)`` and ``proj`` is
              ``(N, V, out_dims)``.
            * If the encoder features are a channel-first 4-D map
              ``(N*V, D, h, w)``, the MLP is applied independently at every
              spatial location; ``features`` is ``(N, V, D, h, w)`` and
              ``proj`` is ``(N, V, out_dims, h, w)``.

        Raises:
            ValueError: If ``x`` is not 5-D or the encoder features are
                neither 2-D nor 4-D.
        """
        N, V, C, H, W = x.shape
        # Fold the view axis into the batch axis so the encoder sees a
        # plain image batch.
        x = x.reshape(N * V, C, H, W)
        x = self.encoder.model.forward_features(x)['x_norm_clstoken']

        if x.ndim == 2:
            # One embedding per view: project, then unfold (N*V) -> (N, V).
            proj = self.mlp(x)
            return x.reshape(N, V, -1), proj.reshape(N, V, -1)

        if x.ndim != 4:
            raise ValueError(
                f"expected 2-D or 4-D encoder features, got shape {tuple(x.shape)}"
            )

        # Dense, channel-first feature map. The channel axis must be moved
        # last before flattening; otherwise each MLP input row would mix
        # channel and spatial values.
        _, Dx, h, w = x.shape
        tokens = x.permute(0, 2, 3, 1).reshape(-1, Dx)
        proj = self.mlp(tokens)
        # Restore (N, V, h, w, Dp) and move channels back in front of the
        # spatial axes to mirror the layout of the returned features.
        proj = proj.reshape(N, V, h, w, -1).permute(0, 1, 4, 2, 3)
        x = x.reshape(N, V, Dx, h, w)
        return x, proj