# jepa_model.py — JEPA (Joint-Embedding Predictive Architecture) model definition.
import torch
import torch.nn as nn
class SimpleTransformer(nn.Module):
    """Minimal stack of standard Transformer encoder layers applied in order.

    NOTE(review): layers use the ``nn.TransformerEncoderLayer`` default
    ``batch_first=False``, so inputs are presumably (seq, batch, dim) —
    confirm against callers.
    """

    def __init__(self, dim, num_heads, depth):
        super().__init__()
        # One encoder layer per level of depth; registered in a ModuleList
        # so parameters are tracked by the parent module.
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
            for _ in range(depth)
        )

    def forward(self, x):
        # Thread the activations through every layer sequentially.
        for block in self.layers:
            x = block(x)
        return x
class JEPA(nn.Module):
    """Joint-Embedding Predictive Architecture.

    A context encoder embeds the visible (context) patches, a separate
    target encoder embeds the target patches, and a predictor maps the
    context embeddings into the target-embedding space. The training
    loss (computed by the caller) compares ``pred_emb`` with ``tgt_emb``.

    Args:
        patch_dim: width of each input patch vector (also the encoder
            model dimension, and the width of both returned embeddings).
        embed_dim: hidden width of the predictor MLP.
        num_heads: attention heads per encoder layer.
        depth: number of encoder layers in each encoder.
    """

    def __init__(self, patch_dim, embed_dim=256, num_heads=8, depth=4):
        super().__init__()
        self.ctx_encoder = SimpleTransformer(patch_dim, num_heads, depth)
        self.tgt_encoder = SimpleTransformer(patch_dim, num_heads, depth)
        # BUG FIX: the predictor previously ended in Linear(embed_dim, embed_dim),
        # so pred_emb had width embed_dim while tgt_emb (output of tgt_encoder)
        # has width patch_dim — any pred-vs-target loss failed whenever
        # patch_dim != embed_dim. The final projection now maps back to
        # patch_dim so the two returned tensors are directly comparable.
        self.pred_head = nn.Sequential(
            nn.Linear(patch_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, patch_dim),
        )

    def forward(self, context_patches, target_patches):
        """Return ``(pred_emb, tgt_emb)``, both of width ``patch_dim``.

        NOTE(review): JEPA variants typically stop gradients through the
        target encoder (EMA weights or ``tgt_emb.detach()``); that is
        deliberately left to the caller to preserve existing training
        behavior.
        """
        ctx_emb = self.ctx_encoder(context_patches)
        tgt_emb = self.tgt_encoder(target_patches)
        pred_emb = self.pred_head(ctx_emb)
        return pred_emb, tgt_emb