| """Action-Conditioned Predictor for Grid-JEPA.""" |
| from typing import Optional |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class DiscreteActionEmbed(nn.Module):
    """Embed a discrete key action, optionally combined with a grid-cell index.

    Two lookup tables of width ``action_embed_dim`` are held: ``key_embed``
    with ``num_key_actions`` rows and ``pos_embed`` with
    ``grid_size * grid_size`` rows.

    Args:
        num_key_actions: Number of rows in the key-action table.
        grid_size: Side length used to size the position table
            (``grid_size * grid_size`` rows).
        action_embed_dim: Width of every embedding vector.
    """

    def __init__(self, num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        self.num_key_actions = num_key_actions
        self.grid_size = grid_size
        self.num_positions = grid_size * grid_size
        # Creation order (key table first, then position table) is kept so that
        # seeded initialisation yields the same weights.
        self.key_embed = nn.Embedding(num_key_actions, action_embed_dim)
        self.pos_embed = nn.Embedding(self.num_positions, action_embed_dim)

    def forward(self, key_action, cell_position=None):
        """Return the action embedding.

        Args:
            key_action: Integer tensor of indices into ``key_embed``.
            cell_position: Optional integer tensor of indices into
                ``pos_embed``; when given, its embedding is added to the key
                embedding.

        Returns:
            The key embedding, plus the position embedding when
            ``cell_position`` is not ``None``.
        """
        key_vec = self.key_embed(key_action)
        if cell_position is None:
            return key_vec
        return key_vec + self.pos_embed(cell_position)
|
|
class PredictorBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention, then an MLP.

    Both sub-layers are applied to a LayerNorm-ed copy of the input and added
    back to it as a residual.

    Args:
        dim: Width of each token; must be a positive multiple of ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        # Fail at construction time: with a non-divisible ``dim`` the floor
        # division below would silently drop channels and ``forward`` would
        # only fail later with an opaque reshape error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(B, N, D)`` and return the same shape."""
        B, N, D = x.shape
        xn = self.norm1(x)
        # (B, N, 3*D) -> (3, B, heads, N, head_dim) so q, k, v split on axis 0.
        qkv = self.qkv(xn).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Dense attention: the score tensor is (B, heads, N, N).
        attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        # Merge heads back into the channel axis: (B, heads, N, hd) -> (B, N, D).
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = x + self.proj(out)
        x = x + self.mlp(self.norm2(x))
        return x
|
|
class ActionConditionedPredictor(nn.Module):
    """Predict representations at target positions from context tokens and an action.

    For every requested target position a query token is built from a learned
    mask token plus that position's learned positional embedding, concatenated
    with the action embedding and projected back to ``embed_dim``. The query
    tokens are appended to the context tokens, passed through the
    ``PredictorBlock`` stack and a final LayerNorm, and only the query-token
    outputs are returned.

    Args:
        num_patches: Number of rows in the learned positional table.
        embed_dim: Token width.
        action_embed_dim: Width of the action embedding fed to ``forward``.
        predictor_depth: Number of ``PredictorBlock`` layers.
        num_heads: Attention heads per block.
    """

    def __init__(self, num_patches=4096, embed_dim=384, action_embed_dim=64,
                 predictor_depth=12, num_heads=6):
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Fuses [query token ; action embedding] back down to ``embed_dim``.
        self.action_mixer = nn.Linear(embed_dim + action_embed_dim, embed_dim)
        self.blocks = nn.ModuleList(
            [PredictorBlock(embed_dim, num_heads) for _ in range(predictor_depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, context_repr, action_emb, target_positions):
        """Return predicted representations for ``target_positions``.

        Args:
            context_repr: Context tokens, ``(B, N_ctx, embed_dim)``.
            action_emb: One action embedding per sample,
                ``(B, action_embed_dim)``.
            target_positions: Integer indices into the positional table,
                either ``(N_tgt,)`` (shared by the whole batch) or
                ``(B, N_tgt)``.

        Returns:
            Tensor of shape ``(B, N_tgt, embed_dim)``; empty along the token
            axis when ``N_tgt`` is zero.
        """
        B, n_ctx = context_repr.shape[0], context_repr.shape[1]
        if target_positions.dim() == 1:
            tgt_pos = target_positions.unsqueeze(0).expand(B, -1)
        else:
            tgt_pos = target_positions
        N_tgt = tgt_pos.shape[1]

        # Look up the positional embedding of every target index.
        tgt_pos_emb = torch.gather(
            self.pos_embed.expand(B, -1, -1), 1,
            tgt_pos.unsqueeze(-1).expand(-1, -1, self.embed_dim)
        )
        mask_tokens = self.mask_token.expand(B, N_tgt, -1) + tgt_pos_emb
        # Broadcast the per-sample action to every query token.
        act_tokens = action_emb.unsqueeze(1).expand(-1, N_tgt, -1)
        tokens = self.action_mixer(torch.cat([mask_tokens, act_tokens], dim=-1))

        x = torch.cat([context_repr, tokens], dim=1)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Slice from the context length rather than with ``-N_tgt:``: when
        # N_tgt == 0, ``-0`` is ``0`` and the negative slice would return the
        # whole context sequence instead of an empty result.
        return x[:, n_ctx:, :]
|
|
class GridWorldPredictor(nn.Module):
    """Bundle the action embedder, the predictor and a per-token colour decoder.

    Args:
        num_colors: Output width of the decoder head.
        embed_dim: Token width used by the predictor and decoder.
        predictor_depth: Number of predictor blocks.
        num_heads: Attention heads per predictor block.
        num_key_actions: Number of discrete key actions.
        grid_size: Grid side length; the predictor covers
            ``grid_size * grid_size`` positions.
        action_embed_dim: Width of the action embedding.
    """

    def __init__(self, num_colors=16, embed_dim=384, predictor_depth=12, num_heads=6,
                 num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        self.num_patches = grid_size * grid_size
        self.action_embed = DiscreteActionEmbed(
            num_key_actions=num_key_actions,
            grid_size=grid_size,
            action_embed_dim=action_embed_dim,
        )
        self.predictor = ActionConditionedPredictor(
            num_patches=self.num_patches,
            embed_dim=embed_dim,
            action_embed_dim=action_embed_dim,
            predictor_depth=predictor_depth,
            num_heads=num_heads,
        )
        # Two-layer head mapping each token to ``num_colors`` outputs.
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_colors),
        )

    def predict_next_state(self, current_repr, action_key, action_pos=None):
        """Run the predictor for every position ``0 .. num_patches - 1``.

        Args:
            current_repr: Context representation; its first axis is the batch.
            action_key: Key-action indices passed to the action embedder.
            action_pos: Optional cell indices passed to the action embedder.

        Returns:
            Whatever ``self.predictor`` returns when queried at all positions.
        """
        batch = current_repr.shape[0]
        action_vec = self.action_embed(action_key, action_pos)
        every_pos = torch.arange(self.num_patches, device=current_repr.device)
        every_pos = every_pos[None, :].expand(batch, -1)
        return self.predictor(current_repr, action_vec, every_pos)

    def decode_to_grid(self, repr):
        """Apply the decoder head to ``repr`` and return its output."""
        return self.decoder(repr)
|
|
def _run_smoke_test():
    """Build each module at a small size and print the shapes it produces."""
    batch, n_ctx, dim, n_tgt = 2, 60, 128, 40
    grid_size, num_keys, act_dim = 10, 6, 64
    num_cells = grid_size * grid_size  # valid position indices are 0 .. num_cells - 1
    num_colors, depth, heads = 10, 4, 4

    context = torch.randn(batch, n_ctx, dim)
    action_key = torch.randint(0, num_keys, (batch,))
    action_pos = torch.randint(0, num_cells, (batch,))
    target_pos = torch.randint(0, num_cells, (n_tgt,))

    embed = DiscreteActionEmbed(num_keys, grid_size, act_dim)
    pred = ActionConditionedPredictor(num_cells, dim, act_dim, depth, heads)
    wp = GridWorldPredictor(num_colors, dim, depth, heads, num_keys, grid_size, act_dim)

    # Only shapes are inspected, so skip autograd bookkeeping and reuse each
    # forward result instead of recomputing it for the next print.
    with torch.no_grad():
        action_emb = embed(action_key, action_pos)
        print(f"Action emb: {action_emb.shape}")
        print(f"Pred: {pred(context, action_emb, target_pos).shape}")
        next_state = wp.predict_next_state(context, action_key, action_pos)
        print(f"Next state: {next_state.shape}")
        print(f"Decode: {wp.decode_to_grid(next_state).shape}")


if __name__ == "__main__":
    _run_smoke_test()
|
|