import torch
from torch import nn
from mapminer import models
from torchvision.ops import MLP


class LeJEPA(nn.Module):
    """DINOv3 encoder followed by an MLP projection head.

    Args:
        out_dims: Width of the final MLP layer, i.e. the size of the
            projected embedding.
    """

    def __init__(self, out_dims=1024):
        super().__init__()
        self.encoder = models.DINOv3(architecture="vit-l-sat", pretrained=True)
        # The MLP input width (1024) must equal the encoder's feature width.
        # NOTE(review): presumably 1024 for "vit-l-sat" -- confirm if the
        # architecture is ever changed.
        self.mlp = MLP(1024, [2048, 2048, out_dims], norm_layer=nn.BatchNorm1d)

    def forward(self, x):
        """Encode every view and project the resulting features.

        Args:
            x: Tensor of shape (N, V, C, H, W); the N and V axes are folded
                into a single batch axis before the encoder is called.

        Returns:
            Tuple ``(feats, proj)``:
              * 2-D encoder output (the case for the ``'x_norm_clstoken'``
                key used here): both are reshaped to ``(N, V, -1)``.
              * 4-D encoder output ``(N*V, D, h, w)``: ``feats`` is
                ``(N, V, D, h, w)`` and ``proj`` is ``(N, V, Dp, h, w)``,
                with the MLP applied independently at each spatial location.
        """
        N, V, C, H, W = x.shape
        x = x.reshape(N * V, C, H, W)
        feats = self.encoder.model.forward_features(x)['x_norm_clstoken']

        if feats.dim() > 2:
            _, feat_dim, feat_h, feat_w = feats.shape
            # Move channels last BEFORE flattening so that every MLP input
            # row is the feature vector of exactly one spatial location.
            # Flattening (B, D, h, w) directly to (-1, D) would interleave
            # channel and spatial values.
            tokens = feats.permute(0, 2, 3, 1).reshape(-1, feat_dim)
            proj = self.mlp(tokens)
            # Undo the flattening: rows -> (B, h, w, Dp) -> (B, Dp, h, w).
            proj = proj.reshape(N * V, feat_h, feat_w, -1).permute(0, 3, 1, 2)
            proj = proj.reshape(N, V, proj.shape[1], feat_h, feat_w)
            feats = feats.reshape(N, V, feat_dim, feat_h, feat_w)
        else:
            # One feature vector per view; restore the (N, V) leading axes.
            proj = self.mlp(feats.reshape(-1, feats.shape[1]))
            proj = proj.reshape(N, V, -1)
            feats = feats.reshape(N, V, -1)
        return feats, proj