import pickle
from inspect import isfunction
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

import smplx
from smplx.lbs import vertices2joints
from smplx.utils import MANOOutput, to_tensor
from smplx.vertex_ids import vertex_ids

from lib.core.config import cfg
from lib.utils.human_models import mano


# This function is from HaMeR (https://github.com/geopavlakos/hamer).
def exists(val):
    return val is not None


# This function is from HaMeR (https://github.com/geopavlakos/hamer).
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class CrossAttention(nn.Module):
    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x, context=None):
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q = self.to_q(x)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
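
# Shape reference for the attention blocks above (illustrative values, not
# taken from any config in this repo): Attention maps (B, N, dim) ->
# (B, N, dim); CrossAttention attends from queries of shape (B, N, dim) over
# a context of shape (B, M, context_dim) and returns (B, N, dim):
#   ca = CrossAttention(dim=64, context_dim=128, heads=4, dim_head=16)
#   out = ca(torch.randn(2, 10, 64), context=torch.randn(2, 7, 128))  # (2, 10, 64)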

# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class Transformer(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            ff = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
                    ]
                )
            )

    def forward(self, x: torch.Tensor, *args):
        for attn, ff in self.layers:
            x = attn(x, *args) + x
            x = ff(x, *args) + x
        return x


class AdaptiveLayerNorm1D(torch.nn.Module):
    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        self.norm = torch.nn.LayerNorm(data_dim)  # TODO: Check if elementwise_affine=True is correct
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        torch.nn.init.zeros_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., data_dim)
        # t: (batch, norm_cond_dim)
        # return: (batch, ..., data_dim)
        x = self.norm(x)
        alpha, beta = self.linear(t).chunk(2, dim=-1)

        # Add singleton dimensions to alpha and beta so they broadcast over
        # any intermediate dimensions of x
        if x.dim() > 2:
            alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
            beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])

        return x * (1 + alpha) + beta


def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    elif norm == "layer":
        return torch.nn.LayerNorm(dim)
    elif norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    elif norm is None:
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown norm: {norm}")


class PreNorm(nn.Module):
    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        if isinstance(self.norm, AdaptiveLayerNorm1D):
            return self.fn(self.norm(x, *args), **kwargs)
        else:
            return self.fn(self.norm(x), **kwargs)
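
# Usage sketch (illustrative values, not from this repo's configs): with
# norm="ada", PreNorm routes the extra positional argument of
# Transformer.forward into AdaptiveLayerNorm1D as the conditioning tensor t:
#   blk = Transformer(dim=64, depth=2, heads=4, dim_head=16, mlp_dim=128,
#                     norm="ada", norm_cond_dim=32)
#   out = blk(torch.randn(2, 10, 64), torch.randn(2, 32))  # -> (2, 10, 64)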

# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class TransformerCrossAttn(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            ca = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            ff = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
                    ]
                )
            )

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        if context_list is None:
            context_list = [context] * len(self.layers)
        if len(context_list) != len(self.layers):
            raise ValueError(
                f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})"
            )

        for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=context_list[i]) + x
            x = ff(x, *args) + x
        return x


# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class DropTokenDropout(nn.Module):
    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            # The same tokens are dropped for every sample in the batch.
            zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
            # TODO: permutation idx for each batch using torch.argsort
            if zero_mask.any():
                x = x[:, ~zero_mask, :]
        return x


# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class ZeroTokenDropout(nn.Module):
    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Zero-out the masked tokens (note: modifies x in place)
            x[zero_mask, :] = 0
        return x
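
# Note: aa_to_rotmat (defined below) calls quat_to_rotmat, which this file
# neither defines nor imports. The standard quaternion-to-rotation-matrix
# conversion used by HaMeR/SPIN is reproduced here so the module is
# self-contained; verify it against the upstream geometry utilities.
def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert a quaternion to a rotation matrix.
    Args:
        quat (torch.Tensor): (B, 4) batch of quaternions in (w, x, y, z) order.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm_quat = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotmat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
                          2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
                          2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2],
                         dim=1).view(B, 3, 3)
    return rotmat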

# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        x = self.to_token_embedding(inp)
        b, n, _ = x.shape
        x = self.dropout(x)
        x = x + self.pos_embedding[:, :n]
        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x


def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)


def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle representation to rotation matrix.
    Works by first converting it to a quaternion.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)


class MANO(smplx.MANOLayer):
    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """
        Extension of the official MANO implementation to support more joints.
        Args:
            Same as MANOLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
        """
        super(MANO, self).__init__(*args, **kwargs)
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

        # 2, 3, 5, 4, 1
        if joint_regressor_extra is not None:
            self.register_buffer(
                'joint_regressor_extra',
                torch.tensor(
                    pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'),
                    dtype=torch.float32,
                ),
            )
        self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
        self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> MANOOutput:
        """
        Run forward pass. Same as MANO and also append an extra set of joints
        if joint_regressor_extra is specified.
""" mano_output = super(MANO, self).forward(*args, **kwargs) extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs) joints = torch.cat([mano_output.joints, extra_joints], dim=1) joints = joints[:, self.joint_map, :] if hasattr(self, 'joint_regressor_extra'): extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices) joints = torch.cat([joints, extra_joints], dim=1) mano_output.joints = joints return mano_output class MANOTransformerDecoderHead(nn.Module): """ Cross-attention based MANO Transformer decoder """ def __init__(self): super().__init__() # self.cfg = cfg self.joint_rep_type = '6d' #cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d') self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] npose = self.joint_rep_dim * (cfg.MODEL.hamer_mano_num_hand_joints + 1) self.npose = npose self.input_is_mean_shape = False #cfg.MODEL.MANO_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape' transformer_args = dict( num_tokens=1, token_dim=1, dim=1024, ) if cfg.MODEL.backbone_type in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']: context_dim = 2048 elif cfg.MODEL.backbone_type in ['vit-l-16']: context_dim = 1024 elif cfg.MODEL.backbone_type in ['vit-b-16']: context_dim = 768 elif cfg.MODEL.backbone_type in ['resnet-18', 'resnet-34']: context_dim = 512 elif cfg.MODEL.backbone_type in ['vit-s-16']: context_dim = 384 elif cfg.MODEL.backbone_type in ['handoccnet']: context_dim = 256 else: context_dim = 1280 # transformer_args = (transformer_args | {'context_dim': 1280, 'depth': 6, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'heads': 8, 'mlp_dim': 1024, 'norm': 'layer'}) transformer_args = {**transformer_args, 'context_dim': context_dim, 'depth': 6, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'heads': 8, 'mlp_dim': 1024, 'norm': 'layer'} self.transformer = TransformerDecoder( **transformer_args ) dim=transformer_args['dim'] self.decpose = nn.Linear(dim, npose) self.decshape = nn.Linear(dim, 10) self.deccam = nn.Linear(dim, 3) mean_params = np.load(cfg.MODEL.hamer_mano_mean_params) init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0) init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0) init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0) self.register_buffer('init_hand_pose', init_hand_pose) self.register_buffer('init_betas', init_betas) self.register_buffer('init_cam', init_cam) def forward(self, x, **kwargs): batch_size = x.shape[0] # vit pretrained backbone is channel-first. 
        x = rearrange(x, 'b c h w -> b (h w) c')

        init_hand_pose = self.init_hand_pose.expand(batch_size, -1)
        init_betas = self.init_betas.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        # TODO: Convert init_hand_pose to aa rep if needed
        if self.joint_rep_type == 'aa':
            raise NotImplementedError

        pred_hand_pose = init_hand_pose
        pred_betas = init_betas
        pred_cam = init_cam
        pred_hand_pose_list = []
        pred_betas_list = []
        pred_cam_list = []

        # Input token to the transformer is the mean-shape token if
        # input_is_mean_shape, otherwise a zero token
        if self.input_is_mean_shape:
            token = torch.cat([pred_hand_pose, pred_betas, pred_cam], dim=1)[:, None, :]
        else:
            token = torch.zeros(batch_size, 1, 1).to(x.device)

        # Pass through transformer
        token_out = self.transformer(token, context=x)
        token_out = token_out.squeeze(1)  # (B, C)

        # Readout from token_out
        pred_hand_pose = self.decpose(token_out) + pred_hand_pose
        pred_betas = self.decshape(token_out) + pred_betas
        pred_cam = self.deccam(token_out) + pred_cam
        pred_hand_pose_list.append(pred_hand_pose)
        pred_betas_list.append(pred_betas)
        pred_cam_list.append(pred_cam)

        # Convert self.joint_rep_type -> rotmat
        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()),
        }[self.joint_rep_type]

        pred_mano_params_list = {}
        pred_mano_params_list['hand_pose'] = torch.cat(
            [joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_hand_pose_list],
            dim=0,
        )
        pred_mano_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
        pred_mano_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
        pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(
            batch_size, cfg.MODEL.hamer_mano_num_hand_joints + 1, 3, 3
        )

        pred_mano_params = {
            'global_orient': pred_hand_pose[:, [0]],
            'hand_pose': pred_hand_pose[:, 1:],
            'betas': pred_betas,
        }
        return pred_mano_params, pred_cam, pred_mano_params_list


def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the perspective projection of a set of 3D points.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    batch_size = points.shape[0]
    if rotation is None:
        rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)

    # Populate intrinsic camera matrix K.
    K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    return projected_points[:, :, :-1]


# This module is modified from MANOTransformerDecoderHead of HaMeR
# (https://github.com/geopavlakos/hamer). All cfg are directly initialized.
class ContactTransformerDecoderHead(nn.Module):
    """ Cross-attention based contact Transformer decoder """

    def __init__(self):
        super().__init__()
        transformer_args = dict(
            num_tokens=1,
            token_dim=1,
            dim=1024,
        )
        if cfg.MODEL.backbone_type in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']:
            context_dim = 2048
        elif cfg.MODEL.backbone_type in ['vit-l-16']:
            context_dim = 1024
        elif cfg.MODEL.backbone_type in ['vit-b-16']:
            context_dim = 768
        elif cfg.MODEL.backbone_type in ['resnet-18', 'resnet-34']:
            context_dim = 512
        elif cfg.MODEL.backbone_type in ['vit-s-16']:
            context_dim = 384
        elif cfg.MODEL.backbone_type in ['handoccnet']:
            context_dim = 256
        else:
            context_dim = 1280

        MANO_HEAD_TRANSFORMER_DECODER_CONFIG = {
            'depth': 6,
            'heads': 8,
            'mlp_dim': 1024,
            'dim_head': 64,
            'dropout': 0.0,
            'emb_dropout': 0.0,
            'norm': 'layer',
            'context_dim': context_dim,
        }
        transformer_args.update(dict(MANO_HEAD_TRANSFORMER_DECODER_CONFIG))
        self.transformer = TransformerDecoder(**transformer_args)
        self.deccontact = nn.Linear(1024, 778)

        CONTACT_MEAN_DIR = cfg.MODEL.contact_means_path  # TODO: REPLACE THIS WITH CONTACT MEAN OF ENTIRE DATASETS
        # Note: registered as a buffer, so init_contact is not updated by the
        # optimizer even though it is created as an nn.Parameter.
        init_contact = nn.Parameter(torch.randn(1, 778, requires_grad=True))
        self.register_buffer('init_contact', init_contact)

    def forward(self, x, **kwargs):
        # x: [b, 1280, 16, 12] (if resnet-50, x: [b, 2048, 8, 8]; resnet-34: [b, 512, 8, 8]; hrnet-w32: [b, 2048, 8, 8])
        batch_size = x.shape[0]
        device = x.device

        # The ViT pretrained backbone is channel-first; change to token-first.
        x = rearrange(x, 'b c h w -> b (h w) c')

        init_contact = self.init_contact.expand(batch_size, -1)
        pred_contact = init_contact

        token = torch.zeros(batch_size, 1, 1).to(x.device)

        # Pass through transformer
        token_out = self.transformer(token, context=x)  # context x: [b, 192, 1280] for the ViT backbone
        token_out = token_out[:, 0]  # (B, C)

        # Readout from token_out
        pred_contact = self.deccontact(token_out) + pred_contact
        # pred_contact = pred_contact.sigmoid()
        return pred_contact
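
if __name__ == "__main__":
    # Minimal shape sanity checks (illustrative only; the decoder heads are not
    # constructed here because they require a populated lib.core.config.cfg).
    pose_6d = torch.randn(4, 6)
    R = rot6d_to_rotmat(pose_6d)
    assert R.shape == (4, 3, 3)
    # The resulting matrices should be orthonormal: R @ R^T == I.
    assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(4, 3, 3), atol=1e-5)

    points = torch.rand(2, 21, 3)
    translation = torch.tensor([[0.0, 0.0, 2.0], [0.0, 0.0, 3.0]])
    focal_length = torch.full((2, 2), 5000.0)
    camera_center = torch.full((2, 2), 128.0)
    projected = perspective_projection(points, translation, focal_length, camera_center=camera_center)
    assert projected.shape == (2, 21, 2)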