| """Complete Grid-JEPA system.""" |
| import random |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
|
|
| try: |
| from .encoder import GridPatchEmbed, ViTEncoder |
| from .predictor import DiscreteActionEmbed, ActionConditionedPredictor |
| except ImportError: |
| from encoder import GridPatchEmbed, ViTEncoder |
| from predictor import DiscreteActionEmbed, ActionConditionedPredictor |
|
|
class GridJEPA(nn.Module):
    """Joint-embedding predictive model over discrete colour grids.

    A shared patch embedding feeds two ViT encoders: a trainable context
    encoder that sees the grid with non-context positions zeroed, and a
    frozen target encoder (initialised as a copy of the context encoder and
    refreshed via :meth:`update_ema`) that sees the full grid.  A predictor,
    optionally conditioned on an action embedding, predicts the target
    encoder's representations at the target positions.
    """

    def __init__(self, num_colors=16, embed_dim=384, encoder_depth=12, predictor_depth=12,
                 num_heads=6, mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996,
                 num_key_actions=6, action_embed_dim=64, use_action_conditioning=True):
        """Build the embedding, both encoders, the action embedding and the predictor.

        Args:
            num_colors: Forwarded to ``GridPatchEmbed``.
            embed_dim: Token width shared by the encoders and the predictor.
            encoder_depth: Depth passed to both ``ViTEncoder`` instances.
            predictor_depth: Depth passed to the predictor.
            num_heads: Attention head count for encoders and predictor.
            mlp_ratio: MLP ratio passed to both encoders.
            max_grid_size: Grid side length; ``num_patches`` is its square.
            ema_decay: Decay factor used by :meth:`update_ema`.
            num_key_actions: Forwarded to ``DiscreteActionEmbed``.
            action_embed_dim: Width of the action embedding; also the width
                of the all-zero placeholder used when no action is supplied.
            use_action_conditioning: When False, supplied actions are ignored
                and the zero placeholder is always used.
        """
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.max_grid_size = max_grid_size
        self.num_patches = max_grid_size * max_grid_size
        self.use_action_conditioning = use_action_conditioning
        self.ema_decay = ema_decay
        # Remembered so forward() can size the "no action" placeholder
        # correctly instead of assuming the default width.
        self.action_embed_dim = action_embed_dim

        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.context_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        self.target_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        # The target encoder starts as an exact copy of the context encoder
        # and is never trained by gradient descent, only by update_ema().
        self.target_encoder.load_state_dict(self.context_encoder.state_dict())
        for param in self.target_encoder.parameters():
            param.requires_grad = False

        self.action_embed = DiscreteActionEmbed(num_key_actions, max_grid_size, action_embed_dim)
        self.predictor = ActionConditionedPredictor(
            self.num_patches, embed_dim, action_embed_dim, predictor_depth, num_heads
        )

    def forward(self, grid, context_mask, target_mask, action_key=None, action_pos=None):
        """Compute the JEPA prediction loss for one batch.

        Args:
            grid: Batch of grids; the first dimension is the batch size.
            context_mask: Boolean mask over embedded positions; positions
                where it is False are zeroed before the context encoder.
                (The module's self-test passes shape ``(B, num_patches)``.)
            target_mask: Boolean mask, indexed per sample, selecting the
                positions whose target representations must be predicted.
            action_key: Optional action ids for ``DiscreteActionEmbed``.
            action_pos: Optional action positions for ``DiscreteActionEmbed``.

        Returns:
            ``(loss, metrics)``.  ``loss`` is the batch mean of the per-sample
            mean squared L2 distance between predicted and target
            representations.  ``metrics`` holds ``"loss"`` (float) and
            ``"num_targets"`` (mean targets per sample).  If no sample has any
            target position, ``loss`` is a zero that is still attached to the
            autograd graph and ``metrics`` is empty.
        """
        batch_size = grid.shape[0]
        x = self.patch_embed(grid)

        # Hide everything outside the context from the context encoder.
        ctx_x = x.clone()
        ctx_x[~context_mask] = 0
        context_repr = self.context_encoder(ctx_x)

        # Targets come from the frozen encoder on the full grid; no gradient.
        with torch.no_grad():
            target_repr = self.target_encoder(x)

        target_indices = [
            target_mask[b].nonzero(as_tuple=True)[0] for b in range(batch_size)
        ]

        if self.use_action_conditioning and action_key is not None:
            action_emb = self.action_embed(action_key, action_pos)
        else:
            # Placeholder must match the width the predictor was built with
            # (previously hard-coded to 64) and the dtype of the embeddings.
            action_emb = torch.zeros(
                batch_size, self.action_embed_dim, device=grid.device, dtype=x.dtype
            )

        # Samples can have different numbers of targets, so each one is
        # predicted separately with a batch dimension of one.
        losses = []
        for b, tgt_idx in enumerate(target_indices):
            if len(tgt_idx) == 0:
                continue
            pred = self.predictor(
                context_repr[b:b + 1], action_emb[b:b + 1], tgt_idx.unsqueeze(0)
            )
            tgt = target_repr[b, tgt_idx].unsqueeze(0)
            losses.append((pred - tgt).pow(2).sum(dim=-1).mean())

        if not losses:
            # Keep the zero loss connected to the graph so a training loop
            # can still call backward() on it without raising.
            return context_repr.sum() * 0.0, {}

        total = torch.stack(losses).mean()
        metrics = {
            "loss": total.item(),
            "num_targets": sum(len(t) for t in target_indices) / batch_size,
        }
        return total, metrics

    def update_ema(self):
        """Move target-encoder parameters toward the context encoder.

        Applies ``target = decay * target + (1 - decay) * context`` in place
        to every parameter pair, with ``decay = self.ema_decay``.
        """
        decay = self.ema_decay
        with torch.no_grad():
            pairs = zip(self.target_encoder.parameters(), self.context_encoder.parameters())
            for target_param, context_param in pairs:
                target_param.mul_(decay).add_(context_param.detach(), alpha=1 - decay)

    def encode(self, grid):
        """Return context-encoder representations of the unmasked ``grid``."""
        return self.context_encoder(self.patch_embed(grid))

    @torch.no_grad()
    def encode_target(self, grid):
        """Return target-encoder representations of ``grid`` without gradients."""
        return self.target_encoder(self.patch_embed(grid))
|
|
if __name__ == "__main__":
    # Smoke test: build a small model, run one forward pass on random data
    # with complementary context/target masks, then exercise the EMA update
    # and the plain encoder path.
    batch, side, colors = 2, 10, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = GridJEPA(colors, 128, 4, 4, 4, max_grid_size=side).to(device)
    grid = torch.randint(0, colors, (batch, side, side), device=device)

    cells = side * side
    half = cells // 2
    ctx_mask = torch.zeros(batch, cells, dtype=torch.bool, device=device)
    tgt_mask = torch.zeros_like(ctx_mask)
    for row in range(batch):
        # Random permutation of positions: first half is context, rest is target.
        order = list(range(cells))
        random.shuffle(order)
        ctx_mask[row, order[:half]] = True
        tgt_mask[row, order[half:]] = True

    action_key = torch.randint(0, 6, (batch,), device=device)
    action_pos = torch.randint(0, cells, (batch,), device=device)

    loss, metrics = model(grid, ctx_mask, tgt_mask, action_key, action_pos)
    print(f"Loss: {loss.item():.2f}, Metrics: {metrics}")

    model.update_ema()
    print("EMA OK")

    print(f"Encode: {model.encode(grid).shape}")
|
|