File size: 5,215 Bytes
e332759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Action-Conditioned Predictor for Grid-JEPA."""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiscreteActionEmbed(nn.Module):
    """Embed a discrete key action, optionally combined with a grid cell.

    Two lookup tables are learned: ``key_embed`` indexed by key-action id and
    ``pos_embed`` indexed by cell index in ``[0, grid_size * grid_size)``.
    When a cell index is supplied the two vectors are summed.
    """

    def __init__(self, num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        cells = grid_size * grid_size
        self.num_key_actions = num_key_actions
        self.grid_size = grid_size
        self.num_positions = cells
        # Key table is built first so seeded initialisation order is preserved.
        self.key_embed = nn.Embedding(num_key_actions, action_embed_dim)
        self.pos_embed = nn.Embedding(cells, action_embed_dim)

    def forward(self, key_action, cell_position=None):
        """Return the action embedding.

        Args:
            key_action: integer tensor of key-action ids.
            cell_position: optional integer tensor of cell indices; its
                embedding is added to the key embedding, so the two lookups
                must be broadcast-compatible.

        Returns:
            ``key_embed(key_action)``, plus ``pos_embed(cell_position)`` when a
            position is given.
        """
        key_vec = self.key_embed(key_action)
        if cell_position is None:
            return key_vec
        return key_vec + self.pos_embed(cell_position)

class PredictorBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an MLP.

    Computes ``x = x + proj(attn(norm1(x)))`` then ``x = x + mlp(norm2(x))``.
    No attention mask is applied, so every token attends to every token.

    Args:
        dim: token feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads (positive).
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        # Fail fast here instead of with an opaque reshape error in forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, dim); returns the same shape."""
        B, N, D = x.shape
        xn = self.norm1(x)
        # (B, N, 3*D) -> (B, N, 3, H, Dh) -> (3, B, H, N, Dh)
        qkv = self.qkv(xn).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        # (B, H, N, Dh) -> (B, N, H, Dh) -> (B, N, D)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = x + self.proj(out)
        x = x + self.mlp(self.norm2(x))
        return x

class ActionConditionedPredictor(nn.Module):
    """Predict target-token representations from context tokens and an action.

    For each target position a learned mask token plus the learned positional
    embedding of that position is concatenated with the action embedding and
    projected back to ``embed_dim`` by ``action_mixer``. These target tokens are
    appended after the context tokens, the whole sequence is run through
    ``predictor_depth`` ``PredictorBlock`` layers and a final LayerNorm, and the
    outputs at the target slots are returned.

    NOTE(review): ``context_repr`` is used as-is (no positional embedding is
    added here); presumably it is already position-encoded upstream — confirm.
    """

    def __init__(self, num_patches=4096, embed_dim=384, action_embed_dim=64,
                 predictor_depth=12, num_heads=6):
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.action_mixer = nn.Linear(embed_dim + action_embed_dim, embed_dim)
        self.blocks = nn.ModuleList([PredictorBlock(embed_dim, num_heads) for _ in range(predictor_depth)])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, context_repr, action_emb, target_positions):
        """Return predicted representations for ``target_positions``.

        Args:
            context_repr: (B, N_ctx, embed_dim) context tokens.
            action_emb: (B, action_embed_dim) action embedding, shared by all
                target tokens of a sample.
            target_positions: integer indices into ``pos_embed`` (values in
                ``[0, num_patches)``), either (N_tgt,) shared across the batch
                or (B, N_tgt) per sample.

        Returns:
            (B, N_tgt, embed_dim) tensor; N_tgt may be 0.

        Raises:
            ValueError: if ``target_positions`` is neither 1-D nor 2-D.
        """
        B, N_ctx = context_repr.shape[0], context_repr.shape[1]
        if target_positions.dim() == 1:
            tgt_pos = target_positions.unsqueeze(0).expand(B, -1)
        elif target_positions.dim() == 2:
            tgt_pos = target_positions
        else:
            raise ValueError(
                "target_positions must be 1-D or 2-D, got shape "
                f"{tuple(target_positions.shape)}"
            )
        N_tgt = tgt_pos.shape[1]
        # Look up the positional embedding of every target index: (B, N_tgt, D).
        tgt_pos_emb = torch.gather(
            self.pos_embed.expand(B, -1, -1), 1,
            tgt_pos.unsqueeze(-1).expand(-1, -1, self.embed_dim),
        )
        mask_tokens = self.mask_token.expand(B, N_tgt, -1) + tgt_pos_emb
        act_tokens = action_emb.unsqueeze(1).expand(-1, N_tgt, -1)
        tokens = self.action_mixer(torch.cat([mask_tokens, act_tokens], dim=-1))
        x = torch.cat([context_repr, tokens], dim=1)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Slice by context length: ``x[:, -N_tgt:]`` would return every token
        # (context included) when N_tgt == 0, since -0 == 0.
        return x[:, N_ctx:, :]

class GridWorldPredictor(nn.Module):
    """Predict latents for every cell of a square grid after an action.

    Combines ``DiscreteActionEmbed`` (action encoding), an
    ``ActionConditionedPredictor`` queried at all ``grid_size * grid_size``
    positions, and a two-layer MLP ``decoder`` that maps each latent to
    ``num_colors`` outputs.
    """

    def __init__(self, num_colors=16, embed_dim=384, predictor_depth=12, num_heads=6,
                 num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        n_cells = grid_size * grid_size
        self.num_patches = n_cells
        # Submodules are created in the original order so seeded init matches.
        self.action_embed = DiscreteActionEmbed(num_key_actions, grid_size, action_embed_dim)
        self.predictor = ActionConditionedPredictor(
            n_cells, embed_dim, action_embed_dim, predictor_depth, num_heads
        )
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_colors),
        )

    def predict_next_state(self, current_repr, action_key, action_pos=None):
        """Predict latents for all grid cells given context and an action.

        Args:
            current_repr: context tokens, batch-first (B, N_ctx, embed_dim).
            action_key: key-action ids passed to ``action_embed``.
            action_pos: optional cell indices passed to ``action_embed``.

        Returns:
            (B, grid_size * grid_size, embed_dim) predicted representations.
        """
        batch = current_repr.size(0)
        every_cell = torch.arange(self.num_patches, device=current_repr.device)
        query_positions = every_cell.unsqueeze(0).expand(batch, -1)
        conditioning = self.action_embed(action_key, action_pos)
        return self.predictor(current_repr, conditioning, query_positions)

    def decode_to_grid(self, repr):
        """Map latents (..., embed_dim) to (..., num_colors) via ``decoder``.

        The outputs are presumably per-cell colour logits — confirm with the
        training loss.
        """
        # ``repr`` shadows the builtin but is kept so keyword callers still work.
        return self.decoder(repr)

if __name__ == "__main__":
    # Smoke test on a tiny 10x10 grid: print the output shape of each stage.
    B, N_ctx, D, N_tgt = 2, 60, 128, 40
    context = torch.randn(B, N_ctx, D)
    action_key = torch.randint(0, 6, (B,))
    action_pos = torch.randint(0, 100, (B,))
    target_pos = torch.randint(0, 100, (N_tgt,))

    embed = DiscreteActionEmbed(6, 10, 64)
    action_emb = embed(action_key, action_pos)
    print(f"Action emb: {action_emb.shape}")

    pred = ActionConditionedPredictor(100, D, 64, 4, 4)
    print(f"Pred: {pred(context, action_emb, target_pos).shape}")

    wp = GridWorldPredictor(10, D, 4, 4, 6, 10, 64)
    next_state = wp.predict_next_state(context, action_key, action_pos)
    print(f"Next state: {next_state.shape}")
    print(f"Decode: {wp.decode_to_grid(next_state).shape}")