# q1_continuous / grp_model.py
# Author: NMThuan032k
# Upload folder using huggingface_hub
# Commit: b7b35a5 (verified)
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def get_patches_fast(images, cfg):
    """Split channels-last images into flattened, non-overlapping square patches.

    Args:
        images: array/tensor of shape (batch, height, width, channels); height and
            width must be multiples of ``cfg.patch_size``.
        cfg: config providing ``patch_size`` and, for channels > 3,
            ``policy.obs_stacking``.

    Returns:
        For channels <= 3: (batch, n_patches, patch_size * patch_size * channels),
        built from the first 3 channels only.
        For channels > 3 (stacked history, observations only -- not goal images):
        (batch, n_patches * obs_stacking, patch_size * patch_size * (channels // obs_stacking)).
    """
    from einops import rearrange
    _, _, _, channels = images.shape
    patch_size = cfg.patch_size  # side length of each square patch, in pixels
    if channels > 3:
        # History stacking in the channel dimension: every patch yields one token per
        # history step. Only this rearrange is computed (the 3-channel one would be
        # discarded).
        # NOTE(review): '(c hs)' means channel index = c * obs_stacking + hs, i.e.
        # history frames interleaved per colour channel. If frames are concatenated
        # frame-by-frame along the channel axis the layout would be '(hs c)' --
        # confirm against the dataset's stacking code.
        return rearrange(
            images,
            'b (h p1) (w p2) (c hs) -> b (h w hs) (p1 p2 c)',
            p1=patch_size,
            p2=patch_size,
            hs=cfg.policy.obs_stacking,
        )
    return rearrange(
        images[:, :, :, :3],
        'b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
        p1=patch_size,
        p2=patch_size,
    )
def calc_positional_embeddings(sequence_length, d):
    """Return fixed sinusoidal positional embeddings of shape (sequence_length, d).

    Entry (i, j) is ``sin(i / 10000 ** (j / d))`` for even j and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd j, so each sin/cos column pair
    shares one frequency.

    Vectorised with torch ops instead of a nested Python loop (which cost
    O(sequence_length * d) interpreter steps on every forward pass). Angles are
    computed in float64, as the previous NumPy scalar math was, and the table is
    cast to torch's default dtype, matching the previous ``torch.ones`` buffer.
    """
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)  # (L, 1)
    cols = torch.arange(d)
    # Odd columns reuse the exponent of the preceding even column.
    exponents = (cols - cols % 2).to(torch.float64) / d
    angles = positions / torch.pow(10000.0, exponents)  # (L, d)
    table = torch.where(cols % 2 == 0, torch.sin(angles), torch.cos(angles))
    return table.to(torch.get_default_dtype())
class Head(nn.Module):
    """One head of (bidirectional) self-attention with an optional attention mask."""

    def __init__(self, head_size, n_embd, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence ``x`` of shape (B, T, C).

        Args:
            x: input tokens (B, T, C) with C == n_embd.
            mask: optional tensor broadcastable to the (B, T, T) score matrix;
                entries equal to 0 get a score of -inf and therefore zero attention
                weight (e.g. a (B, 1, T) key mask). ``None`` attends everywhere.
                A query row whose keys are all masked yields NaN.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # NOTE(review): scores are scaled by C (= n_embd), not head_size as in standard
        # scaled dot-product attention. Kept so existing checkpoints behave identically.
        wei = q @ k.transpose(-2, -1) * C**-0.5
        if mask is not None:
            # Block masking: zeroed positions may not be attended to.
            wei = wei.masked_fill(mask == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel, concatenated and projected back to n_embd."""

    def __init__(self, num_heads, head_size, n_embd, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)])
        # The projection consumes the concatenated head outputs, whose width is
        # num_heads * head_size. That equals n_embd only when n_embd is divisible by
        # num_heads; sizing it from the heads avoids a shape mismatch otherwise and
        # is identical (same parameter shapes) when it divides evenly.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply every head to ``x`` (B, T, n_embd) with the same ``mask``; returns (B, T, n_embd)."""
        with torch.profiler.record_function("Self-Attention"):
            out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
            out = self.dropout(self.proj(out))
        return out
class FeedFoward(nn.Module):
    """Position-wise MLP: widen 4x, ReLU, project back to n_embd, then dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        # Layer order (and hence indices net.0 .. net.3) is part of the state_dict layout.
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP independently to every position of ``x``."""
        y = self.net(x)
        return y
class Block(nn.Module):
    """Pre-LayerNorm transformer encoder block.

    Computes ``h = x + MHA(LN1(x), mask)`` followed by ``h + FFN(LN2(h))``.
    """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        per_head = n_embd // n_head
        # Registration order is preserved: it fixes state_dict key order and the
        # order in which parameters are visited by Module.apply.
        self.sa = MultiHeadAttention(n_head, per_head, n_embd=n_embd, dropout=dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, mask=None):
        """Run attention then the MLP, each as a residual branch on normalised input."""
        attended = self.sa(self.ln1(x), mask)
        h = x + attended
        return h + self.ffwd(self.ln2(h))
class GRP(nn.Module):
    """Goal-conditioned transformer policy.

    ``forward`` builds the token sequence::

        [CLS] [observation patches] [GOAL] [goal-image patches] [goal-text tokens]

    runs it through the transformer blocks and predicts actions from the final
    hidden state of the CLS token: continuous values (MSE loss) or per-dimension
    bins (cross-entropy loss), depending on ``cfg.action_space``.
    """

    # Number of bins per action dimension used by the discrete action head.
    _N_ACTION_BINS = 14

    def __init__(self, cfg, mlp_ratio=4):
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: configuration object. Reads dataset.chars_list, patch_size, n_embd,
                n_head, n_blocks, dropout, action_space, action_dim and
                policy.action_stacking. Sets ``cfg.vocab_size`` as a side effect.
            mlp_ratio: hidden-width multiplier of the action-head MLP.

        Raises:
            ValueError: if ``cfg.action_space`` is neither "continuous" nor "discrete".
        """
        super().__init__()
        self._cfg = cfg
        chars = cfg.dataset.chars_list
        cfg.vocab_size = len(chars)
        # 1) Patch embedding: a flattened RGB patch -> n_embd.
        self.patch_embedding = nn.Linear(cfg.patch_size * cfg.patch_size * 3, cfg.n_embd)
        # 2) Learnable CLS token (read by the action head) and goal-image separator token.
        self.class_token = nn.Parameter(torch.zeros(1, 1, cfg.n_embd))
        self.goal_token = nn.Parameter(torch.zeros(1, 1, cfg.n_embd))
        # 3) Character embedding table for text goals (only used when not encoding with T5).
        self.token_embedding_table = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.dropout = nn.Dropout(cfg.dropout)
        # 4) Transformer encoder blocks.
        self.blocks = nn.ModuleList([Block(cfg.n_embd, cfg.n_head, cfg.dropout) for _ in range(cfg.n_blocks)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        # 5) Action MLP head applied to the CLS token.
        n_actions = cfg.action_dim * cfg.policy.action_stacking
        if cfg.action_space == "continuous":
            head_out = n_actions
        elif cfg.action_space == "discrete":
            head_out = n_actions * self._N_ACTION_BINS  # one logit per bin per action value
        else:
            # Fail early instead of with an UnboundLocalError inside forward().
            raise ValueError(
                f"Unsupported action_space {cfg.action_space!r}; expected 'continuous' or 'discrete'"
            )
        self.action_head = nn.Sequential(
            nn.Linear(cfg.n_embd, cfg.n_embd * mlp_ratio),
            nn.ReLU(),
            nn.Linear(cfg.n_embd * mlp_ratio, head_out),
        )
        # Weight initialization of Linear / Embedding layers.
        self.apply(self._init_weights)
        # Initialize learnable tokens with small random values instead of zeros
        # (_init_weights only visits Linear/Embedding modules, not bare Parameters).
        torch.nn.init.normal_(self.class_token, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.goal_token, mean=0.0, std=0.02)

    def _init_weights(self, module):
        """N(0, 0.02) weights and zero biases for Linear layers; N(0, 0.02) for Embeddings."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, images, goals_txt, goal_imgs, targets=None, pose=None, mask_=True):
        """Predict actions and, when ``targets`` is given, a training loss.

        Args:
            images: channels-last observations (B, H, W, C); C > 3 is treated as
                stacked history by ``get_patches_fast``.
            goals_txt: T5 embeddings (B, T, n_embd) when ``cfg.dataset.encode_with_t5``,
                otherwise character ids (B, max_block_size).
            goal_imgs: channels-last goal images, presumably (B, H, W, 3).
            targets: optional target actions; the discrete head clamps them to [-1, 1]
                (presumably (B, action_dim * action_stacking) -- TODO confirm).
            pose: unused.
            mask_: if True, hide the text goal (rand < 0.33) or the image goal
                (rand > 0.66) from attention, drawn independently per sample.
                NOTE(review): this also happens in eval mode; pass mask_=False for
                deterministic inference.

        Returns:
            ``(out, loss)``. Continuous head: ``out`` is the raw head output. Discrete
            head: ``out`` holds bin indices when ``targets`` is given and actions
            decoded back to [-1, 1] otherwise. ``loss`` is a 0-dim zero tensor when
            ``targets`` is None.
        """
        cfg = self._cfg
        n = images.shape[0]
        obs_patches = get_patches_fast(images, cfg)
        patches_g = get_patches_fast(goal_imgs, cfg)
        if cfg.dataset.encode_with_t5:
            goals_e = goals_txt
            B, T, _ = goals_txt.shape
        else:
            goals_e = self.token_embedding_table(goals_txt)
            B, _ = goals_txt.shape
            T = cfg.max_block_size

        # Map the vector of every patch to the hidden size.
        obs_tokens = self.patch_embedding(obs_patches)  # (B, n_obs_tokens, n_embd)
        goal_img_tokens = self.patch_embedding(patches_g)  # (B, n_patches, n_embd)
        cls_token = self.class_token.expand(B, -1, -1)  # (B, 1, n_embd)
        goal_token = self.goal_token.expand(B, -1, -1)  # (B, 1, n_embd)
        x = torch.cat([cls_token, obs_tokens, goal_token, goal_img_tokens, goals_e], dim=1)
        seq_len = x.shape[1]

        # Fixed sinusoidal positional embedding over the whole sequence.
        pos_emb = calc_positional_embeddings(seq_len, cfg.n_embd).to(x.device)
        x = x + pos_emb.unsqueeze(0)
        x = self.dropout(x)

        # Key mask (1 = attend, 0 = ignore). All ones unless goal masking is enabled;
        # always defined so that mask_=False works.
        att_mask = torch.ones((B, seq_len), device=x.device)
        if mask_:
            total_patches = (cfg.image_shape[0] // cfg.patch_size) * (cfg.image_shape[1] // cfg.patch_size)
            obs_start = 1  # position 0 is the CLS token
            obs_end = obs_start + total_patches * cfg.policy.obs_stacking
            goal_img_start = obs_end + 1  # skip the GOAL separator token
            goal_img_end = goal_img_start + total_patches
            goal_text_start = goal_img_end
            goal_text_end = goal_text_start + T
            assert seq_len == goal_text_end
            # Randomly drop the text goal (~1/3 of samples) or the image goal (~1/3).
            rand_val = torch.rand(n, device=x.device)
            mask_text = (rand_val < 0.33).unsqueeze(1)  # (B, 1)
            mask_image = (rand_val > 0.66).unsqueeze(1)  # (B, 1)
            att_mask[:, goal_text_start:goal_text_end].masked_fill_(mask_text, 0)
            att_mask[:, goal_img_start:goal_img_end].masked_fill_(mask_image, 0)
        block_mask = att_mask.unsqueeze(1)  # (B, 1, seq_len): broadcast over queries

        for block in self.blocks:
            x = block(x, mask=block_mask)
        x = self.ln_f(x)
        cls_hidden = x[:, 0, :]

        if cfg.action_space == "continuous":
            out = self.action_head(cls_hidden)  # (B, action_dim * action_stacking)
            if targets is not None:
                loss = F.mse_loss(out, targets)
            else:
                loss = torch.tensor(0.0, device=out.device)
        elif cfg.action_space == "discrete":
            n_bins = self._N_ACTION_BINS
            logits = self.action_head(cls_hidden).view(B, -1, n_bins)  # (B, n_actions, n_bins)
            if targets is not None:
                # Map [-1, 1] to the nearest of n_bins evenly spaced centres; rounding
                # (rather than flooring) is the exact inverse of the decoding below.
                targets_clamped = torch.clamp(targets, -1, 1)
                targets_bins = torch.round((targets_clamped + 1) / 2 * (n_bins - 1)).long()
                loss = F.cross_entropy(logits.permute(0, 2, 1), targets_bins)
                out = logits.argmax(dim=-1)
            else:
                bin_idxs = logits.argmax(dim=-1).float()
                # Convert bin indices back to continuous values in [-1, 1].
                out = (bin_idxs / (n_bins - 1)) * 2.0 - 1.0
                loss = torch.tensor(0.0, device=out.device)
        else:
            raise ValueError(
                f"Unsupported action_space {cfg.action_space!r}; expected 'continuous' or 'discrete'"
            )
        return (out, loss)

    def resize_image(self, image):
        """Convert ``image`` to a float32 NumPy array and resize it with OpenCV.

        The ``dsize`` passed to ``cv2.resize`` is
        ``(cfg.image_shape[0], cfg.image_shape[1])``.
        NOTE(review): OpenCV reads dsize as (width, height); if image_shape is
        (height, width, ...) non-square targets come out transposed -- confirm.
        """
        import cv2
        import numpy as _np
        img = _np.array(image, dtype=_np.float32)
        img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1]))
        return img

    def normalize_state(self, image):
        """Map pixel values from [0, 255] to [-1, 1] via ``image / 255 * 2 - 1``."""
        return ((image / 255.0) * 2.0) - 1.0

    def preprocess_state(self, image):
        """Resize (``resize_image``) then normalise (``normalize_state``) an observation image."""
        img = self.resize_image(image)
        img = self.normalize_state(img)
        return img

    def preprocess_goal_image(self, image):
        """Goal images use exactly the same preprocessing as observations."""
        return self.preprocess_state(image)

    def _fit_text_embedding(self, embedding):
        """Truncate or zero-pad a (batch, L, E) embedding along L to ``cfg.max_block_size``."""
        max_len = self._cfg.max_block_size
        embedding = embedding[:, :max_len, :]
        missing = max_len - embedding.shape[1]
        if missing > 0:
            # F.pad lists trailing dims first: (E_left, E_right, L_left, L_right).
            embedding = F.pad(embedding, (0, 0, 0, missing))
        return embedding

    def encode_text_goal(self, goal, tokenizer=None, text_model=None):
        """Encode a text goal into the tensor ``forward`` expects as ``goals_txt``.

        With ``cfg.dataset.encode_with_t5`` the goal is run through the T5 encoder and
        the (1, L, E) hidden states are truncated / zero-padded along L to
        ``cfg.max_block_size``. This is the same layout ``process_text_embedding_for_buffer``
        stores, so inference inputs match the training buffer.
        Otherwise the string is truncated / space-padded to ``cfg.max_block_size``
        characters and mapped to ids via ``cfg.dataset.chars_list`` (unknown chars -> 0),
        giving a (1, max_block_size) long tensor on ``cfg.device``.

        Raises:
            ValueError: if T5 encoding is enabled but tokenizer or text_model is missing.
        """
        import numpy as _np
        import torch as _torch
        if self._cfg.dataset.encode_with_t5:
            if tokenizer is None or text_model is None:
                raise ValueError("tokenizer and text_model must be provided when using T5 encoding")
            device = text_model.device
            tokens = tokenizer(goal, return_tensors="pt").input_ids.to(device)
            with _torch.no_grad():
                embedding = text_model.encoder(tokens).last_hidden_state
            return self._fit_text_embedding(embedding).to(self._cfg.device)
        max_len = self._cfg.max_block_size
        pad = " " * max_len
        goal_ = goal[:max_len] + pad[len(goal):max_len]
        try:
            stoi = {c: i for i, c in enumerate(self._cfg.dataset.chars_list)}
            ids = [stoi.get(c, 0) for c in goal_]
        except Exception:
            # Deliberate best-effort fallback: an all-zero id sequence.
            ids = [0] * max_len
        return _torch.tensor(_np.expand_dims(_np.array(ids, dtype=_np.int64), axis=0), dtype=_torch.long, device=self._cfg.device)

    def process_text_embedding_for_buffer(self, goal, tokenizer=None, text_model=None):
        """
        Process text goal embedding for storing in the circular buffer.
        Returns a numpy array of shape (max_block_size, n_embd) without batch dimension.

        The T5 hidden states are truncated to ``max_block_size`` rows and zero-padded
        below. Token ids are moved to ``text_model.device`` (when it has one) and the
        encoder runs without building an autograd graph.

        Raises:
            ValueError: if tokenizer or text_model is missing.
        """
        import numpy as _np
        if tokenizer is None or text_model is None:
            raise ValueError("tokenizer and text_model must be provided when using T5 encoding")
        max_len = self._cfg.max_block_size
        goal_ = _np.zeros((max_len, self._cfg.n_embd), dtype=_np.float32)
        input_ids = tokenizer(goal, return_tensors="pt").input_ids
        device = getattr(text_model, "device", None)
        if device is not None:
            input_ids = input_ids.to(device)
        with torch.no_grad():
            goal_t = text_model.encoder(input_ids).last_hidden_state.detach().cpu().numpy()
        rows = goal_t[0][:max_len]
        goal_[:len(rows), :] = rows
        return goal_

    def decode_action(self, action_tensor):
        """Undo action normalisation: ``action * std + mean``.

        ``action_tensor`` is expected to have a last dimension of
        ``action_dim * action_stacking``; ``cfg.dataset.action_mean`` / ``action_std``
        are expanded to that width with ``np.repeat``.
        NOTE(review): ``np.repeat`` yields [m0, m0, ..., m1, m1, ...] (each dimension
        repeated ``action_stacking`` times). If stacked actions are laid out
        timestep-major ([a0_t0, a1_t0, ..., a0_t1, ...]) this should be ``np.tile`` --
        confirm against the dataset's stacking code.
        """
        import torch as _torch
        action_mean = _torch.tensor(np.repeat(self._cfg.dataset.action_mean, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
        action_std = _torch.tensor(np.repeat(self._cfg.dataset.action_std, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device)
        return (action_tensor * action_std) + action_mean

    def encode_action(self, action_float):
        """Normalise raw actions: ``(action - mean) / std``.

        Unlike ``decode_action`` the statistics are not expanded over
        ``action_stacking``, so the last dimension of ``action_float`` must broadcast
        against ``cfg.dataset.action_mean`` (presumably a single, un-stacked action --
        TODO confirm).
        """
        import torch as _torch
        action_mean = _torch.tensor(self._cfg.dataset.action_mean, dtype=action_float.dtype, device=action_float.device)
        action_std = _torch.tensor(self._cfg.dataset.action_std, dtype=action_float.dtype, device=action_float.device)
        return (action_float - action_mean) / action_std
@torch.no_grad()
def estimate_loss(model, dataset):
    """Estimate the mean train and val loss over ``model._cfg.eval_iters`` batches each.

    Runs without gradients and with the model in eval mode. The model's previous
    train/eval mode is restored afterwards, also if an exception is raised, instead
    of unconditionally switching it to train mode.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    cfg = model._cfg
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(cfg.eval_iters)
            for k in range(cfg.eval_iters):
                X, x_pose, x_goal, x_goal_img, Y = dataset.get_batch_grp(split, cfg, cfg.batch_size)
                _, loss = model(X, x_goal, x_goal_img, Y, pose=x_pose)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out