Spaces:
Runtime error
Runtime error
| from torch import nn, Tensor | |
| from einops import rearrange | |
| import torch | |
| from genie.attention import SelfAttention | |
| import numpy as np | |
| from typing import Optional | |
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: Size of the input and output feature dimension.
        mlp_ratio: Hidden width as a multiple of ``d_model`` (truncated to int).
        mlp_bias: Whether both linear layers carry a bias term.
        mlp_drop: Dropout probability; the same dropout module is applied
            after the activation and after the output projection.
    """

    def __init__(
        self,
        d_model: int,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.0,
    ) -> None:
        super().__init__()
        inner_width = int(d_model * mlp_ratio)

        # Expansion and contraction projections (created in this order so the
        # parameter initialisation sequence is unchanged).
        self.fc1 = nn.Linear(d_model, inner_width, bias=mlp_bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner_width, d_model, bias=mlp_bias)
        self.drop = nn.Dropout(mlp_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feed-forward network over the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)

        out = self.fc2(hidden)
        return self.drop(out)
class STBlock(nn.Module):
    """Spatio-temporal transformer block.

    See Figure 4 of https://arxiv.org/pdf/2402.15391.pdf

    The block runs, in order: spatial self-attention over the tokens of each
    frame, optional action conditioning, causal temporal self-attention over
    the frames of each spatial position, and a feed-forward ``Mlp``. Every
    stage is added residually.

    Args:
        num_heads: Number of attention heads, forwarded to ``SelfAttention``.
        d_model: Feature (channel) dimension.
        qkv_bias: Forwarded to ``SelfAttention``.
        proj_bias: Forwarded to ``SelfAttention``.
        qk_norm: Forwarded to ``SelfAttention``. When true, ``norm1``/``norm2``
            are ``nn.Identity``; otherwise they are ``nn.LayerNorm``.
        use_mup: Forwarded to ``SelfAttention``.
        attn_drop: Forwarded to ``SelfAttention``.
        mlp_ratio: Hidden-width multiplier of the ``Mlp``.
        mlp_bias: Whether the ``Mlp`` linear layers use a bias.
        mlp_drop: Dropout probability inside the ``Mlp``.
        action_processing: Selects the conditioning mode by substring match:
            ``"mlp"``, ``"cross_attention"`` or ``"modulate"``.
        jointly_predict_actions: Stored as ``self.action_prediction``; not
            otherwise used inside this block.
        mask_token_id: Accepted for interface compatibility; not used inside
            this block.
    """

    def __init__(
        self,
        num_heads: int,
        d_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.05,  # add dropout
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.05,
        # action relevant
        action_processing: str = "mlp",
        jointly_predict_actions: bool = False,
        mask_token_id: int = 0
    ) -> None:
        super().__init__()
        self.norm1 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)

        # sequence dim is over each frame's 16x16 patch tokens
        self.spatial_attn = SelfAttention(
            num_heads=num_heads,
            d_model=d_model,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            use_mup=use_mup,
            attn_drop=attn_drop,
        )

        # sequence dim is over time sequence (16)
        self.temporal_attn = SelfAttention(
            num_heads=num_heads,
            d_model=d_model,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            use_mup=use_mup,
            attn_drop=attn_drop,
        )

        self.action_prediction = jointly_predict_actions
        self.action_processing = action_processing
        self.norm2 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)
        self.mlp = Mlp(d_model=d_model, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, mlp_drop=mlp_drop)

        # Set at run-time by external code. forward() indexes it by `domain`
        # and calls the result, so it is presumably a dict / nn.ModuleDict of
        # per-domain projector modules.
        self.action_projectors = None

    def forward(self, x_TSC: Tensor, action_ids: Optional[Tensor] = None, domain=None) -> Tensor:
        """
        The main forward pass of the STBlock. It does (bidirectional) spatial
        attention, optional action conditioning, (causal) temporal attention
        and the feed-forward MLP.

        Args:
            x_TSC: Input of shape ``[B, T, S, C]`` (batch, time, space, channel).
            action_ids: Optional action input; per the original note it is
                ``[B, T, D]``. Conditioning is skipped when this is ``None``.
            domain: Key selecting the entry of ``self.action_projectors``.
                Conditioning is skipped when this is ``None``.

        Returns:
            Tensor of shape ``[B, T, S, C]``.

        Raises:
            Any exception raised by the selected action projector is
            propagated to the caller.
        """
        T, S = x_TSC.size(1), x_TSC.size(2)

        # Spatial attention: every frame is an independent sequence of S tokens.
        x_SC = rearrange(x_TSC, 'B T S C -> (B T) S C')
        x_SC = x_SC + self.spatial_attn(self.norm1(x_SC))

        # Process attention temporally: every spatial position is an
        # independent sequence of T steps.
        x_TC = rearrange(x_SC, '(B T) S C -> (B S) T C', T=T)

        if action_ids is not None and domain is not None and self.action_projectors is not None:
            # action_ids: [B, T, D]. Only apply to video parts
            if "mlp" in self.action_processing:
                action_ids = self.action_projectors[domain](action_ids)  # does not depend on x_TC
                x_TC = rearrange(x_TC, '(B S) T C -> B S T C', S=S)
                # Truncate actions to the T available steps and broadcast
                # them across the spatial axis.
                x_TC = x_TC + action_ids[:, None, :x_TC.shape[2]]  # expand across spatial
                x_TC = rearrange(x_TC, 'B S T C -> (B S) T C', S=S)
            elif "cross_attention" in self.action_processing:
                x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids, action_ids)
            elif "modulate" in self.action_processing:
                # Errors propagate: silently continuing without conditioning
                # (or opening an interactive shell) would hide real failures.
                x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids)

        # Apply the Causal Transformer
        x_TC = x_TC + self.temporal_attn(x_TC, causal=True)  # e.g. [256, 16, 256]
        x_TC = x_TC + self.mlp(self.norm2(x_TC))

        x_TSC = rearrange(x_TC, '(B S) T C -> B T S C', S=S)
        return x_TSC
class STTransformerDecoder(nn.Module):
    """Stack of ``num_layers`` ``STBlock`` layers applied one after another.

    Args:
        num_layers: Number of ``STBlock`` layers to build.
        num_heads, d_model, qkv_bias, proj_bias, qk_norm, use_mup, attn_drop,
        mlp_ratio, mlp_bias, mlp_drop, action_processing,
        jointly_predict_actions, mask_token_id: Forwarded unchanged to every
            ``STBlock``.
        random_dummy_action: Accepted for interface compatibility; not used
            inside this class.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        d_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.0,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.0,
        # action relevant
        action_processing: str = "mlp",
        jointly_predict_actions: bool = False,
        random_dummy_action: bool = True,
        mask_token_id: int = 0,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([STBlock(
            num_heads=num_heads,
            d_model=d_model,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            use_mup=use_mup,
            attn_drop=attn_drop,
            mlp_ratio=mlp_ratio,
            mlp_bias=mlp_bias,
            mlp_drop=mlp_drop,
            action_processing=action_processing,
            jointly_predict_actions=jointly_predict_actions,
            mask_token_id=mask_token_id
        ) for _ in range(num_layers)])
        # Re-initialise every submodule created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Weight initialization for transformer.

        ``nn.Linear``: Xavier-uniform weights with gain 0.1, zero bias.
        ``nn.LayerNorm``: unit weight, zero bias. Parameters that are ``None``
        (no bias, or a non-affine LayerNorm) are skipped.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm can be built without affine parameters / without bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, tgt, action_ids=None, domain=""):
        """Run ``tgt`` through every layer, passing ``action_ids`` and ``domain`` to each."""
        x = tgt
        for layer in self.layers:
            x = layer(x, action_ids=action_ids, domain=domain)
        return x