# core-jepa/src/backbone.py
# Author: Gajesh Ladhar
# Commit c71037b: "initial src and benchmark added"
import torch
from torch import nn
from mapminer import models
from torchvision.ops import MLP
class LeJEPA(nn.Module):
    """Multi-view encoder: a DINOv3 backbone followed by an MLP projection head.

    ``forward`` folds the views of each sample into the batch, encodes every
    image with the backbone, projects the features with the MLP head, and
    returns both the backbone features and the projections.
    """

    def __init__(
        self,
        out_dims=1024,
        *,
        architecture="vit-l-sat",
        pretrained=True,
        embed_dim=1024,
        hidden_dim=2048,
    ):
        """Build the backbone and the projection head.

        Args:
            out_dims: Output width of the projection head.
            architecture: Architecture name forwarded to ``models.DINOv3``.
            pretrained: Forwarded to ``models.DINOv3``.
            embed_dim: Width of the backbone features fed to the head. It must
                match ``architecture``; the default 1024 is the value paired
                with ``"vit-l-sat"``.
            hidden_dim: Width of each of the head's two hidden layers.
        """
        super().__init__()
        self.encoder = models.DINOv3(architecture=architecture, pretrained=pretrained)
        # torchvision MLP: (Linear -> BatchNorm1d -> ReLU -> Dropout(p=0)) for each
        # hidden width, then a final Linear to out_dims. In training mode
        # BatchNorm1d needs more than one row per batch.
        self.mlp = MLP(
            embed_dim,
            [hidden_dim, hidden_dim, out_dims],
            norm_layer=nn.BatchNorm1d,
        )

    def forward(self, x):
        """Encode and project a batch of multi-view images.

        Args:
            x: Tensor of shape ``(N, V, C, H, W)``, i.e. N samples with V views each.

        Returns:
            ``(features, proj)``. The backbone's ``x_norm_clstoken`` output
            gives one vector per image, so ``features`` is ``(N, V, D)`` and
            ``proj`` is ``(N, V, out_dims)``. If the backbone output is instead
            a channel-first map ``(N*V, D, h, w)``, the head is applied at every
            location. In that case ``features`` is ``(N, V, D, h, w)`` and
            ``proj`` is ``(N, V, out_dims, h, w)``.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (N, V, C, H, W), got {tuple(x.shape)}"
            )
        N, V, C, H, W = x.shape

        # Fold the views into the batch so the backbone sees ordinary images.
        feats = self.encoder.model.forward_features(x.reshape(N * V, C, H, W))
        feats = feats["x_norm_clstoken"]

        if feats.dim() == 2:
            # One feature vector per image: project directly.
            proj = self.mlp(feats)
            return feats.reshape(N, V, -1), proj.reshape(N, V, -1)

        # Channel-first spatial map (B, D, h, w). Move channels last so that each
        # row handed to the MLP is one location's D-dim feature vector.
        # (Flattening (B, D, h, w) directly with reshape(-1, D) would mix
        # spatial positions into the feature rows.)
        B, D, h, w = feats.shape
        tokens = feats.permute(0, 2, 3, 1).reshape(-1, D)
        proj = self.mlp(tokens)
        # Restore channel-first order: (B*h*w, Dp) -> (B, Dp, h, w).
        proj = proj.reshape(B, h, w, -1).permute(0, 3, 1, 2)
        return feats.reshape(N, V, D, h, w), proj.reshape(N, V, -1, h, w)