| import torch | |
| from torch import nn | |
| from mapminer import models | |
| from torchvision.ops import MLP | |
class LeJEPA(nn.Module):
    """DINOv3 encoder followed by an MLP projection head.

    Each input sample carries several views; every view is encoded
    independently and the encoder features are then passed through the
    projection head.
    """

    def __init__(self, out_dims=1024):
        """Create the pretrained encoder and the projection head.

        Args:
            out_dims: Width of the last layer of the projection head.
        """
        super().__init__()
        self.encoder = models.DINOv3(architecture="vit-l-sat", pretrained=True)
        # The head's input width is fixed at 1024 and must equal the feature
        # width produced by the encoder (presumably the ViT-L embedding size;
        # confirm if the architecture above is ever changed).
        self.mlp = MLP(1024, [2048, 2048, out_dims], norm_layer=nn.BatchNorm1d)

    def forward(self, x):
        """Encode all views and project the resulting features.

        Args:
            x: Tensor of shape (N, V, C, H, W): N samples with V views each.

        Returns:
            A tuple ``(feats, proj)``.

            * When the encoder output is 2-D, i.e. (N*V, Dx) -- the expected
              case for the ``'x_norm_clstoken'`` entry -- ``feats`` has shape
              (N, V, Dx) and ``proj`` has shape (N, V, Dp).
            * When the encoder output is a channel-first map
              (N*V, Dx, Hf, Wf), ``feats`` has shape (N, V, Dx, Hf, Wf) and
              ``proj`` has shape (N, V, Dp, Hf, Wf); the head is applied to
              each spatial location independently.

            Dp is the output width of the projection head.
        """
        N, V, C, H, W = x.shape
        # Fold the view axis into the batch axis so the encoder sees a plain
        # image batch.
        flat_views = x.reshape(N * V, C, H, W)
        feats = self.encoder.model.forward_features(flat_views)['x_norm_clstoken']

        if feats.dim() > 2:
            _, Dx, Hf, Wf = feats.shape
            # Channels must be the trailing axis before flattening; otherwise
            # each row handed to the head would mix channels with spatial
            # positions.
            tokens = feats.permute(0, 2, 3, 1).reshape(-1, Dx)
            proj = self.mlp(tokens)
            # Rows are ordered (batch, Hf, Wf); restore that layout and move
            # channels back in front of the spatial axes.
            proj = proj.reshape(N, V, Hf, Wf, -1).permute(0, 1, 4, 2, 3)
            return feats.reshape(N, V, Dx, Hf, Wf), proj

        proj = self.mlp(feats)
        return feats.reshape(N, V, -1), proj.reshape(N, V, -1)