File size: 16,309 Bytes
b7b35a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def get_patches_fast(images, cfg):
    """Split channel-last images into flattened square patches.

    Args:
        images: 4-D array/tensor unpacked as (batch, height, width, channels).
            Height and width must be divisible by ``cfg.patch_size`` for the
            rearrange patterns below to apply.
        cfg: Config object providing ``patch_size`` and, when the input has
            more than 3 channels, ``policy.obs_stacking``.

    Returns:
        With at most 3 channels: ``(b, h*w, p*p*c)`` where ``p`` is the patch
        size and ``h``/``w`` count patches per axis.
        With more than 3 channels: the channel axis is split as ``(c hs)`` with
        ``hs = cfg.policy.obs_stacking`` and the result is
        ``(b, h*w*hs, p*p*c)``.
    """
    from einops import rearrange

    # Unpacking into four names keeps the implicit "input must be 4-D" check.
    _, _, _, channels = images.shape
    patch_size = cfg.patch_size  # side length of one square patch, in pixels

    if channels > 3:
        # History stacking in the channel dimension (observations only, not
        # goal images): every stacked frame contributes its own patch tokens.
        # NOTE(review): einops treats the left-most name in "(c hs)" as the
        # slow index, i.e. channel k is read as c = k // hs, hs = k % hs.
        # Verify this matches how the dataset concatenates frames.
        return rearrange(
            images,
            'b (h p1) (w p2) (c hs) -> b (h w hs) (p1 p2 c)',
            p1=patch_size,
            p2=patch_size,
            hs=cfg.policy.obs_stacking,
        )

    # Plain image: one token per spatial patch. The [:3] slice is a no-op for
    # inputs that reach this branch but is kept to mirror the original call.
    return rearrange(
        images[:, :, :, :3],
        'b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
        p1=patch_size,
        p2=patch_size,
    )


def calc_positional_embeddings(sequence_length, d):
    """Build the fixed sinusoidal positional-embedding table.

    Entry ``(i, j)`` is ``sin(i / 10000**(j / d))`` for even ``j`` and
    ``cos(i / 10000**((j - 1) / d))`` for odd ``j``.

    Args:
        sequence_length: Number of positions (rows).
        d: Embedding width (columns).

    Returns:
        A CPU tensor of shape ``(sequence_length, d)`` in torch's default
        floating dtype.
    """
    # Work in float64 (what the scalar formula used) and cast once at the end;
    # this replaces a Python-level double loop with a few tensor ops.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    columns = torch.arange(d)
    is_even = columns % 2 == 0
    # Odd columns share the frequency of the even column just before them.
    paired = (columns - columns % 2).to(torch.float64)
    angles = positions / (10000.0 ** (paired / d))
    table = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return table.to(torch.get_default_dtype())


class Head(nn.Module):
    """One head of self-attention with an optional attention mask."""

    def __init__(self, head_size, n_embd, dropout):
        """Create bias-free key/query/value projections from n_embd to head_size."""
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor unpacked as (B, T, C) with C equal to ``n_embd``.
            mask: Optional tensor broadcastable to the (B, T, T) score matrix.
                Entries equal to 0 are excluded from attention; ``None``
                (the default) attends to everything.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        # Three-way unpack doubles as a rank check on the input.
        _, _, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # NOTE(review): scores are scaled by the input width C, not by the
        # head size as in standard scaled dot-product attention. Left as-is
        # because changing it alters the behavior of trained checkpoints.
        wei = q @ k.transpose(-2, -1) * C**-0.5
        if mask is not None:
            # Masked positions get -inf so softmax assigns them zero weight.
            wei = wei.masked_fill(mask == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v


class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and project their concatenation."""

    def __init__(self, num_heads, head_size, n_embd, dropout):
        """Build ``num_heads`` heads of width ``head_size`` plus an output projection."""
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)])
        # The concatenated heads have width num_heads * head_size, which only
        # equals n_embd when n_embd is divisible by num_heads. Sizing the
        # projection from the actual concat width keeps the layer usable in
        # the non-divisible case and is identical otherwise.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply every head to ``x`` (sharing ``mask``), concatenate, project, drop out."""
        with torch.profiler.record_function("Self-Attention"):
            out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
            out = self.dropout(self.proj(out))
        return out


class FeedFoward(nn.Module):
    """Two-layer MLP: expand to 4x the embedding width, ReLU, project back, dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_dim = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed


class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual add."""

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        # Each head gets an equal integer share of the embedding width.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head, n_embd=n_embd, dropout=dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and feed-forward residual sub-layers."""
        attended = x + self.sa(self.ln1(x), mask)
        return attended + self.ffwd(self.ln2(attended))


class GRP(nn.Module):
    def __init__(self, cfg, mlp_ratio=4):
        """Build the patch/text embeddings, transformer blocks and action head.

        Args:
            cfg: Config object. Reads ``dataset.chars_list``, ``patch_size``,
                ``n_embd``, ``dropout``, ``n_head``, ``n_blocks``,
                ``action_space``, ``action_dim`` and ``policy.action_stacking``.
                Side effect: ``cfg.vocab_size`` is set to ``len(chars_list)``.
            mlp_ratio: Hidden-width multiplier of the action-head MLP.

        Raises:
            ValueError: If ``cfg.action_space`` is neither ``"continuous"`` nor
                ``"discrete"`` (previously the head was silently left
                undefined and the failure surfaced later in ``forward``).
        """
        super().__init__()
        self._cfg = cfg
        cfg.vocab_size = len(cfg.dataset.chars_list)

        # 1) Patch embedding: one flattened RGB patch -> n_embd.
        self.patch_embedding = nn.Linear(cfg.patch_size * cfg.patch_size * 3, cfg.n_embd)
        # 2) Learnable classification token and goal-image marker token.
        self.class_token = nn.Parameter(torch.zeros(1, 1, cfg.n_embd))
        self.goal_token = nn.Parameter(torch.zeros(1, 1, cfg.n_embd))
        # 3) Token embedding table for text goals (used when not encoding with T5).
        self.token_embedding_table = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.dropout = nn.Dropout(cfg.dropout)

        # 4) Transformer encoder blocks.
        self.blocks = nn.ModuleList([Block(cfg.n_embd, cfg.n_head, cfg.dropout) for _ in range(cfg.n_blocks)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)

        # 5) Action head. The discrete head emits 14 logits per action
        #    component; forward() reshapes with the same constant.
        n_actions = cfg.action_dim * cfg.policy.action_stacking
        if cfg.action_space == "continuous":
            head_out = n_actions
        elif cfg.action_space == "discrete":
            head_out = n_actions * 14
        else:
            raise ValueError(f"Unsupported action_space: {cfg.action_space!r}")
        self.action_head = nn.Sequential(
            nn.Linear(cfg.n_embd, cfg.n_embd * mlp_ratio),
            nn.ReLU(),
            nn.Linear(cfg.n_embd * mlp_ratio, head_out),
        )

        # Weight initialization for Linear / Embedding sub-modules.
        self.apply(self._init_weights)
        # Initialize learnable tokens with small random values instead of
        # zeros. _init_weights only visits modules, so bare Parameters must be
        # handled explicitly here.
        nn.init.normal_(self.class_token, mean=0.0, std=0.02)
        nn.init.normal_(self.goal_token, mean=0.0, std=0.02)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, images, goals_txt, goal_imgs, targets=None, pose=None, mask_=True):
        """Predict actions from observation images and text / image goals.

        The token sequence fed to the transformer is
        ``[class token, observation patches, goal token, goal-image patches,
        text-goal tokens]``; the action head reads the class-token position.

        Args:
            images: Observation batch, patchified by ``get_patches_fast``.
            goals_txt: When ``cfg.dataset.encode_with_t5`` is set, a tensor
                unpacked as (B, T, E) that is used directly as token
                embeddings. Otherwise a 2-D tensor of token ids looked up in
                ``token_embedding_table``; the text length is then taken to be
                ``cfg.max_block_size``.
            goal_imgs: Goal-image batch, patchified the same way as ``images``.
            targets: Optional action targets. When given, a loss is computed.
            pose: Unused; accepted for call compatibility.
            mask_: When True, each sample independently has its text goal
                hidden with probability ~0.33 or its goal-image patches hidden
                with probability ~0.34 (never both).

        Returns:
            ``(out, loss)``. Continuous: ``out`` is the raw head output and
            ``loss`` the MSE against ``targets``. Discrete: with ``targets``,
            ``out`` holds arg-max bin indices and ``loss`` the cross-entropy;
            without ``targets``, ``out`` holds the bin indices mapped back to
            values in [-1, 1]. ``loss`` is a zero scalar when ``targets`` is
            None.

        Raises:
            ValueError: If ``cfg.action_space`` is unsupported, or if the
                token sequence length does not match the layout the goal mask
                expects.
        """
        cfg = self._cfg
        n_bins = 14  # logits per action component; matches the head built in __init__
        action_space = cfg.action_space
        if action_space not in ("continuous", "discrete"):
            # Fail early with a clear message instead of an UnboundLocalError
            # at the return statement.
            raise ValueError(f"Unsupported action_space: {action_space!r}")

        obs_patches = get_patches_fast(images, cfg)
        patches_g = get_patches_fast(goal_imgs, cfg)
        if cfg.dataset.encode_with_t5:
            goals_e = goals_txt
            B, T, _ = goals_txt.shape
        else:
            goals_e = self.token_embedding_table(goals_txt)
            B, _ = goals_txt.shape
            T = cfg.max_block_size

        # Map the vector of each patch to the hidden size.
        obs_tokens = self.patch_embedding(obs_patches)
        goal_img_tokens = self.patch_embedding(patches_g)

        # Prepend the class token and insert the goal marker token.
        cls_token = self.class_token.expand(B, -1, -1)
        goal_token = self.goal_token.expand(B, -1, -1)
        x = torch.cat([cls_token, obs_tokens, goal_token, goal_img_tokens, goals_e], dim=1)

        # Fixed sinusoidal positional embedding, broadcast over the batch.
        pos_emb = calc_positional_embeddings(x.shape[1], cfg.n_embd).to(x.device)
        x = self.dropout(x + pos_emb.unsqueeze(0))

        # Key-side attention mask: 1 = attend, 0 = ignore.
        att_mask = torch.ones((B, x.shape[1]), device=x.device)
        if mask_:
            self._mask_random_goal(att_mask, T)
        block_mask = att_mask.unsqueeze(1)  # (B, 1, seq) broadcasts over queries

        for block in self.blocks:
            x = block(x, mask=block_mask)
        x = self.ln_f(x)

        head_out = self.action_head(x[:, 0, :])  # read out the class-token position
        if action_space == "continuous":
            out = head_out
            if targets is not None:
                loss = F.mse_loss(out, targets)
            else:
                loss = torch.tensor(0.0, device=out.device)
            return (out, loss)

        logits = head_out.view(B, -1, n_bins)  # (B, action components, n_bins)
        if targets is not None:
            # Quantize targets in [-1, 1] onto the n_bins-point grid that the
            # inference path decodes from. Rounding (rather than truncating)
            # picks the nearest grid point, so decoding is unbiased and the
            # top bin is reachable by values below exactly 1.0.
            targets_clamped = torch.clamp(targets, -1, 1)
            targets_bins = torch.round((targets_clamped + 1) / 2 * (n_bins - 1)).long()
            loss = F.cross_entropy(logits.permute(0, 2, 1), targets_bins)
            out = logits.argmax(dim=-1)
        else:
            bin_idxs = logits.argmax(dim=-1).float()
            # Convert bin indices back to continuous values in [-1, 1].
            out = (bin_idxs / (n_bins - 1)) * 2.0 - 1.0
            loss = torch.tensor(0.0, device=out.device)
        return (out, loss)

    def _mask_random_goal(self, att_mask, text_len):
        """Zero, in place, the text-goal or goal-image columns of randomly chosen rows.

        ``att_mask`` is (batch, seq_len); ``text_len`` is the number of
        text-goal tokens at the end of the sequence.
        """
        cfg = self._cfg
        batch, seq_len = att_mask.shape
        total_patches = (cfg.image_shape[0] // cfg.patch_size) * (cfg.image_shape[1] // cfg.patch_size)
        obs_end = 1 + total_patches * cfg.policy.obs_stacking  # index 0 is the class token
        goal_img_start = obs_end + 1  # skip the goal marker token so it stays visible
        goal_img_end = goal_img_start + total_patches
        goal_text_start = goal_img_end
        goal_text_end = goal_text_start + text_len
        if seq_len != goal_text_end:
            raise ValueError(
                f"Token sequence length {seq_len} does not match the expected layout length {goal_text_end}"
            )

        # One draw per sample decides which goal modality (if any) is hidden.
        rand_val = torch.rand(batch, device=att_mask.device)
        mask_text = (rand_val < 0.33).unsqueeze(1)
        mask_image = (rand_val > 0.66).unsqueeze(1)
        att_mask[:, goal_text_start:goal_text_end].masked_fill_(mask_text, 0)
        att_mask[:, goal_img_start:goal_img_end].masked_fill_(mask_image, 0)
    
    def resize_image(self, image):
        """Convert ``image`` to a float32 array and resize it with OpenCV.

        The target size passed to ``cv2.resize`` is
        ``(cfg.image_shape[0], cfg.image_shape[1])``.

        NOTE(review): ``cv2.resize`` reads its size argument as
        (width, height); confirm the ordering of ``cfg.image_shape`` if
        non-square images are ever used.
        """
        import cv2

        as_float = np.array(image, dtype=np.float32)
        target_size = (self._cfg.image_shape[0], self._cfg.image_shape[1])
        return cv2.resize(as_float, target_size)

    def normalize_state(self, image):
        """
        Docstring for preprocess_state
        
        :param self: Description
        :param image: Description
        self._encode_state = lambda af:   ((af/(255.0)*2.0)-1.0) # encoder: take a float, output an integer
        self._resize_state = lambda sf:   cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1]))  # resize state
        """
        # img = _np.array(image, dtype=_np.float32)
        # img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1]))
        enc = ((image / 255.0) * 2.0) - 1.0
        # t = _torch.tensor(enc, dtype=_torch.float32, device=self._cfg.device)
        return enc
    
    def preprocess_state(self, image):
        img = self.resize_image(image)
        img = self.normalize_state(img)
        return img

    def preprocess_goal_image(self, image):
        return self.preprocess_state(image)

    def encode_text_goal(self, goal, tokenizer=None, text_model=None):
        import numpy as _np
        import torch as _torch
        if self._cfg.dataset.encode_with_t5:
            if tokenizer is None or text_model is None:
                raise ValueError("tokenizer and text_model must be provided when using T5 encoding")
            # TODO:    
            ## Provide the logic converting text goal to T5 embedding tensor
            device = text_model.device
            tokens = tokenizer(goal, return_tensors="pt").input_ids.to(device)

            with _torch.no_grad():
                embedding = text_model.encoder(tokens).last_hidden_state
            
            return embedding.to(self._cfg.device)
        else:
            pad = " " * self._cfg.max_block_size
            goal_ = goal[:self._cfg.max_block_size] + pad[len(goal):self._cfg.max_block_size]
            try:
                stoi = {c: i for i, c in enumerate(self._cfg.dataset.chars_list)}
                ids = [stoi.get(c, 0) for c in goal_]
            except Exception:
                ids = [0] * self._cfg.max_block_size
            return _torch.tensor(_np.expand_dims(_np.array(ids, dtype=_np.int64), axis=0), dtype=_torch.long, device=self._cfg.device)

    def process_text_embedding_for_buffer(self, goal, tokenizer=None, text_model=None):
        """
        Process text goal embedding for storing in the circular buffer.
        Returns a numpy array of shape (max_block_size, n_embd) without batch dimension.
        """
        import numpy as _np
        if tokenizer is None or text_model is None:
            raise ValueError("tokenizer and text_model must be provided when using T5 encoding")
        
        goal_ = _np.zeros((self._cfg.max_block_size, self._cfg.n_embd), dtype=_np.float32)
        input_ids = tokenizer(goal, return_tensors="pt").input_ids
        goal_t = text_model.encoder(input_ids).last_hidden_state.detach().cpu().numpy()
        goal_[:len(goal_t[0]), :] = goal_t[0][:self._cfg.max_block_size]
        return goal_

    def decode_action(self, action_tensor):
        
        """
        Docstring for decode_action
        
        :param self: Description
        :param action_tensor: Description
        self._decode_action = lambda binN: (binN * action_std) + action_mean  # Undo mapping to [-1, 1]
        """
        import torch as _torch
        ## The action tensor is of shape (batch_size, action_dim * action_stacking) so we need to repeat the mean and std per action stacking
        action_mean = _torch.tensor(np.repeat(self._cfg.dataset.action_mean, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
        action_std = _torch.tensor(np.repeat(self._cfg.dataset.action_std, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
        return (action_tensor * action_std) + action_mean
    
    def encode_action(self, action_float):
        """
        Docstring for encode_action
        
        :param self: Description
        :param action_float: Description
        self._encode_action = lambda af:   (af - action_mean)/(action_std) # encoder: take a float, output an integer
        """
        import torch as _torch
        action_mean = _torch.tensor(self._cfg.dataset.action_mean, dtype=action_float.dtype, device=action_float.device)
        action_std = _torch.tensor(self._cfg.dataset.action_std, dtype=action_float.dtype, device=action_float.device)
        return (action_float - action_mean) / action_std


@torch.no_grad()
def estimate_loss(model, dataset):
    """Average the model's loss over ``eval_iters`` batches of each split.

    Args:
        model: Module exposing ``_cfg`` (with ``eval_iters`` and
            ``batch_size``) and returning ``(output, loss)`` when called as
            ``model(X, goal, goal_img, Y, pose=...)``.
        dataset: Object exposing ``get_batch_grp(split, cfg, batch_size)``
            that returns ``(X, pose, goal, goal_img, Y)``.

    Returns:
        Dict mapping ``'train'`` and ``'val'`` to the mean loss (0-dim tensor).

    The model is put in eval mode for the duration of the call and restored
    to whatever mode it was in beforehand, even if evaluation raises.
    """
    cfg = model._cfg
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(cfg.eval_iters)
            for k in range(cfg.eval_iters):
                X, x_pose, x_goal, x_goal_img, Y = dataset.get_batch_grp(split, cfg, cfg.batch_size)
                _, loss = model(X, x_goal, x_goal_img, Y, pose=x_pose)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode rather than unconditionally forcing
        # train mode.
        model.train(was_training)
    return out