|
|
import torch
import torch.nn as nn
|
|
class SimpleTransformer(nn.Module):
    """A plain stack of ``nn.TransformerEncoderLayer`` blocks applied in order.

    NOTE(review): TransformerEncoderLayer defaults to ``batch_first=False``,
    i.e. input layout (seq, batch, dim) — confirm callers pass that shape.
    """

    def __init__(self, dim, num_heads, depth):
        super().__init__()
        # Build one encoder layer per level of depth; ModuleList registers
        # each layer's parameters with this module.
        encoder_layers = [
            nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(encoder_layers)

    def forward(self, x):
        # Thread the activation through every layer in registration order.
        out = x
        for blk in self.layers:
            out = blk(out)
        return out
|
|
class JEPA(nn.Module):
    """Joint-Embedding Predictive Architecture over patch sequences.

    A context encoder embeds visible patches, a target encoder embeds the
    target patches, and a predictor maps context embeddings into the target
    embedding space. The caller computes a loss (e.g. MSE) between the two
    returned tensors.

    Args:
        patch_dim: feature dimension of each input patch (and of both
            encoders' outputs).
        embed_dim: hidden width of the predictor MLP.
        num_heads: attention heads per encoder layer (must divide patch_dim).
        depth: number of encoder layers in each encoder.
    """

    def __init__(self, patch_dim, embed_dim=256, num_heads=8, depth=4):
        super().__init__()
        self.ctx_encoder = SimpleTransformer(patch_dim, num_heads, depth)
        self.tgt_encoder = SimpleTransformer(patch_dim, num_heads, depth)
        # BUG FIX: the final projection now returns to patch_dim so that
        # pred_emb lives in the same space as tgt_emb. The original ended at
        # embed_dim, making the two outputs incomparable whenever
        # embed_dim != patch_dim.
        self.pred_head = nn.Sequential(
            nn.Linear(patch_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, patch_dim),
        )

    def forward(self, context_patches, target_patches):
        """Return (pred_emb, tgt_emb), both shaped like the encoder output.

        tgt_emb is computed under torch.no_grad(): in JEPA the target branch
        is a stop-gradient — letting gradients flow into the target encoder
        admits trivial collapsed solutions.
        """
        ctx_emb = self.ctx_encoder(context_patches)
        with torch.no_grad():
            tgt_emb = self.tgt_encoder(target_patches)
        pred_emb = self.pred_head(ctx_emb)
        return pred_emb, tgt_emb
|
|