"""Action-Conditioned Predictor for Grid-JEPA.""" from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F class DiscreteActionEmbed(nn.Module): def __init__(self, num_key_actions=6, grid_size=64, action_embed_dim=64): super().__init__() self.num_key_actions = num_key_actions self.grid_size = grid_size self.num_positions = grid_size * grid_size self.key_embed = nn.Embedding(num_key_actions, action_embed_dim) self.pos_embed = nn.Embedding(self.num_positions, action_embed_dim) def forward(self, key_action, cell_position=None): emb = self.key_embed(key_action) if cell_position is not None: emb = emb + self.pos_embed(cell_position) return emb class PredictorBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4.0): super().__init__() self.norm1 = nn.LayerNorm(dim, eps=1e-6) self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3) self.proj = nn.Linear(dim, dim) self.norm2 = nn.LayerNorm(dim, eps=1e-6) hidden = int(dim * mlp_ratio) self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)) def forward(self, x): B, N, D = x.shape xn = self.norm1(x) qkv = self.qkv(xn).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, N, D) x = x + self.proj(out) x = x + self.mlp(self.norm2(x)) return x class ActionConditionedPredictor(nn.Module): def __init__(self, num_patches=4096, embed_dim=384, action_embed_dim=64, predictor_depth=12, num_heads=6): super().__init__() self.num_patches = num_patches self.embed_dim = embed_dim self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) nn.init.trunc_normal_(self.mask_token, std=0.02) nn.init.trunc_normal_(self.pos_embed, std=0.02) self.action_mixer = nn.Linear(embed_dim + action_embed_dim, embed_dim) self.blocks = nn.ModuleList([PredictorBlock(embed_dim, num_heads) for _ in range(predictor_depth)]) self.norm = nn.LayerNorm(embed_dim, eps=1e-6) def forward(self, context_repr, action_emb, target_positions): B = context_repr.shape[0] if target_positions.dim() == 1: N_tgt = target_positions.shape[0] tgt_pos = target_positions.unsqueeze(0).expand(B, -1) else: N_tgt = target_positions.shape[1] tgt_pos = target_positions tgt_pos_emb = torch.gather( self.pos_embed.expand(B, -1, -1), 1, tgt_pos.unsqueeze(-1).expand(-1, -1, self.embed_dim) ) mask_tokens = self.mask_token.expand(B, N_tgt, -1) + tgt_pos_emb act_tokens = action_emb.unsqueeze(1).expand(-1, N_tgt, -1) tokens = self.action_mixer(torch.cat([mask_tokens, act_tokens], dim=-1)) x = torch.cat([context_repr, tokens], dim=1) for block in self.blocks: x = block(x) x = self.norm(x) return x[:, -N_tgt:, :] class GridWorldPredictor(nn.Module): def __init__(self, num_colors=16, embed_dim=384, predictor_depth=12, num_heads=6, num_key_actions=6, grid_size=64, action_embed_dim=64): super().__init__() self.num_patches = grid_size * grid_size self.action_embed = DiscreteActionEmbed(num_key_actions, grid_size, action_embed_dim) self.predictor = ActionConditionedPredictor(self.num_patches, embed_dim, action_embed_dim, predictor_depth, num_heads) self.decoder = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, num_colors)) def predict_next_state(self, current_repr, action_key, action_pos=None): B = current_repr.shape[0] action_emb = self.action_embed(action_key, action_pos) all_pos = torch.arange(self.num_patches, device=current_repr.device).unsqueeze(0).expand(B, -1) return self.predictor(current_repr, action_emb, all_pos) def decode_to_grid(self, repr): return self.decoder(repr) if __name__ == "__main__": B, N_ctx, D, N_tgt = 2, 60, 128, 40 context = torch.randn(B, N_ctx, D) action_key = torch.randint(0, 6, (B,)) action_pos = torch.randint(0, 100, (B,)) target_pos = torch.randint(0, 100, (N_tgt,)) embed = DiscreteActionEmbed(6, 10, 64) print(f"Action emb: {embed(action_key, action_pos).shape}") pred = ActionConditionedPredictor(100, D, 64, 4, 4) print(f"Pred: {pred(context, embed(action_key, action_pos), target_pos).shape}") wp = GridWorldPredictor(10, D, 4, 4, 6, 10, 64) print(f"Next state: {wp.predict_next_state(context, action_key, action_pos).shape}") print(f"Decode: {wp.decode_to_grid(wp.predict_next_state(context, action_key, action_pos)).shape}")