"""Complete Grid-JEPA system."""

import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

try:
    from .encoder import GridPatchEmbed, ViTEncoder
    from .predictor import DiscreteActionEmbed, ActionConditionedPredictor
except ImportError:
    from encoder import GridPatchEmbed, ViTEncoder
    from predictor import DiscreteActionEmbed, ActionConditionedPredictor


class GridJEPA(nn.Module):
    """Joint-embedding predictive architecture over discrete color grids.

    A shared ``GridPatchEmbed`` feeds two ``ViTEncoder`` copies:

    * a trainable context encoder;
    * a target encoder that is frozen (``requires_grad=False``) and only
      follows the context encoder through an exponential moving average
      (see :meth:`update_ema`).

    An ``ActionConditionedPredictor`` maps the context representation to
    predictions of the target representation at selected token positions.
    It is optionally conditioned on an embedded discrete action.
    """

    def __init__(self, num_colors=16, embed_dim=384, encoder_depth=12,
                 predictor_depth=12, num_heads=6, mlp_ratio=4.0,
                 max_grid_size=64, ema_decay=0.996, num_key_actions=6,
                 action_embed_dim=64, use_action_conditioning=True):
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.max_grid_size = max_grid_size
        # One token per cell of a max_grid_size x max_grid_size grid.
        self.num_patches = max_grid_size * max_grid_size
        self.use_action_conditioning = use_action_conditioning
        self.ema_decay = ema_decay
        # Kept so the "no action" fallback in forward() has the same width
        # the predictor was constructed with.
        self.action_embed_dim = action_embed_dim

        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.context_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        self.target_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        # The target encoder starts as an exact copy and is frozen; it only
        # changes via update_ema().
        self.target_encoder.load_state_dict(self.context_encoder.state_dict())
        for p in self.target_encoder.parameters():
            p.requires_grad = False

        self.action_embed = DiscreteActionEmbed(num_key_actions, max_grid_size, action_embed_dim)
        self.predictor = ActionConditionedPredictor(
            self.num_patches, embed_dim, action_embed_dim, predictor_depth, num_heads
        )

    def forward(self, grid, context_mask, target_mask, action_key=None, action_pos=None):
        """Compute the JEPA latent-prediction loss for a batch.

        Args:
            grid: batch of grids; ``grid.shape[0]`` is the batch size ``B``.
                Passed unchanged to ``patch_embed``.
            context_mask: boolean mask indexed as ``[B, token]``. Tokens where it
                is False have their embeddings zeroed before the context encoder.
            target_mask: boolean mask indexed as ``[B, token]``. True entries
                are the positions whose target representation is predicted.
            action_key, action_pos: optional action inputs forwarded to
                ``action_embed``. They are used only when action conditioning
                is enabled and ``action_key`` is not None.

        Returns:
            ``(loss, metrics)``. ``loss`` is the mean over samples that have
            at least one target of the per-token squared L2 error. ``metrics``
            has keys ``"loss"`` (float) and ``"num_targets"``, the average
            number of target tokens per sample.
        """
        B = grid.shape[0]
        x = self.patch_embed(grid)

        # Context view: tokens stay in place, but the embeddings outside the
        # context mask are zeroed.
        ctx_x = x.clone()
        ctx_x[~context_mask] = 0
        context_repr = self.context_encoder(ctx_x)

        # Targets: the target encoder run on the unmasked embeddings,
        # without gradient.
        with torch.no_grad():
            target_repr = self.target_encoder(x)

        target_indices = [target_mask[b].nonzero(as_tuple=True)[0] for b in range(B)]

        if self.use_action_conditioning and action_key is not None:
            action_emb = self.action_embed(action_key, action_pos)
        else:
            # Neutral (all-zero) action with the configured width and the same
            # dtype/device as the embeddings. The width was previously
            # hard-coded to 64, which broke any other action_embed_dim.
            action_emb = x.new_zeros(B, self.action_embed_dim)

        losses = []
        # Samples may have different numbers of targets, so predict per sample.
        for b, tgt_idx in enumerate(target_indices):
            if tgt_idx.numel() == 0:
                continue
            pred = self.predictor(
                context_repr[b:b + 1], action_emb[b:b + 1], tgt_idx.unsqueeze(0)
            )
            tgt = target_repr[b, tgt_idx].unsqueeze(0)
            # Squared L2 distance per target token, averaged over tokens.
            losses.append((pred - tgt).pow(2).sum(dim=-1).mean())

        if not losses:
            # No sample had any target. Return a zero that is still attached
            # to the graph, so callers can call loss.backward() unconditionally.
            # Metric keys stay the same as in the normal path.
            total = context_repr.sum() * 0.0
            return total, {"loss": 0.0, "num_targets": 0.0}

        total = torch.stack(losses).mean()
        metrics = {
            "loss": total.item(),
            "num_targets": sum(len(t) for t in target_indices) / B,
        }
        return total, metrics

    @torch.no_grad()
    def update_ema(self, decay=None):
        """Move each target-encoder parameter toward its context-encoder twin.

        Applies ``target = decay * target + (1 - decay) * context`` in place
        to every parameter pair. Buffers are not touched.

        Args:
            decay: optional per-call override, e.g. for a momentum schedule.
                Defaults to ``self.ema_decay``.
        """
        decay = self.ema_decay if decay is None else decay
        for pt, pc in zip(self.target_encoder.parameters(),
                          self.context_encoder.parameters()):
            pt.mul_(decay).add_(pc, alpha=1 - decay)

    def encode(self, grid):
        """Encode ``grid`` with the (trainable) context encoder, no masking."""
        return self.context_encoder(self.patch_embed(grid))

    @torch.no_grad()
    def encode_target(self, grid):
        """Encode ``grid`` with the EMA target encoder, without gradients."""
        return self.target_encoder(self.patch_embed(grid))


if __name__ == "__main__":
    # Smoke test: random grids, random disjoint 50/50 context/target split.
    B, G, C = 2, 10, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GridJEPA(C, 128, 4, 4, 4, max_grid_size=G).to(device)
    grid = torch.randint(0, C, (B, G, G), device=device)
    N = G * G
    ctx_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
    tgt_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
    for b in range(B):
        idx = list(range(N))
        random.shuffle(idx)
        split = N // 2
        ctx_mask[b, idx[:split]] = True
        tgt_mask[b, idx[split:]] = True
    action_key = torch.randint(0, 6, (B,), device=device)
    action_pos = torch.randint(0, N, (B,), device=device)
    loss, metrics = model(grid, ctx_mask, tgt_mask, action_key, action_pos)
    print(f"Loss: {loss.item():.2f}, Metrics: {metrics}")
    model.update_ema()
    print("EMA OK")
    print(f"Encode: {model.encode(grid).shape}")