# HACO/lib/models/decoder/decoder_hamer_style.py
# (Removed Hugging Face viewer chrome that was pasted into the source:
#  uploader line "dqj5182's picture", commit message "update", hash 5dff371.)
import pickle
import numpy as np
from einops import rearrange
from inspect import isfunction
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import smplx
from smplx.lbs import vertices2joints
from smplx.utils import MANOOutput, to_tensor
from smplx.vertex_ids import vertex_ids
from lib.core.config import cfg
from lib.utils.human_models import mano
# This function is from HaMeR (https://github.com/geopavlakos/hamer).
def exists(val):
    """Return True when ``val`` has been provided (i.e. is not None)."""
    return val is not None
# This function is from HaMeR (https://github.com/geopavlakos/hamer).
def default(val, d):
    """Return ``val`` when it is not None; otherwise fall back to ``d``.

    A function fallback is invoked lazily so the default is only computed
    when actually needed.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class Attention(nn.Module):
    """Multi-head scaled-dot-product self-attention.

    Projects the input to queries/keys/values with a single fused linear
    layer, attends, and projects back to ``dim`` when the concatenated
    heads do not already have that width.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # An output projection is only required when (heads, dim_head)
        # doesn't already produce a `dim`-wide tensor.
        needs_projection = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # Split the fused projection into per-head q, k, v tensors.
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.to_out(merged)
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys/values
    from a ``context`` sequence (defaulting to ``x`` itself, which makes
    this degenerate to self-attention)."""

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        # Keys/values may live in a different embedding space than queries.
        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, context=None):
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        q = self.to_q(x)
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.heads) for t in (q, k, v)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.to_out(merged)
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class FeedForward(nn.Module):
    """Position-wise two-layer MLP (dim -> hidden_dim -> dim) with GELU
    activation and dropout after each linear layer."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class Transformer(nn.Module):
    """Stack of pre-normalized self-attention / feed-forward blocks, each
    wrapped in a residual connection."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()

        def wrap(module):
            # Every sub-block runs behind its own (possibly conditional) norm.
            return PreNorm(dim, module, norm=norm, norm_cond_dim=norm_cond_dim)

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        wrap(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                        wrap(FeedForward(dim, mlp_dim, dropout=dropout)),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor, *args):
        # *args carries the conditioning tensor for adaptive norms, if any.
        for self_attn, feed_forward in self.layers:
            x = self_attn(x, *args) + x
            x = feed_forward(x, *args) + x
        return x
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    The conditioning vector ``t`` is mapped to ``(alpha, beta)`` and the
    normalized input is transformed as ``x * (1 + alpha) + beta``. The
    projection is zero-initialized, so at the start of training the module
    behaves exactly like a plain LayerNorm.
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        # TODO: Check if elementwise_affine=True is correct
        self.norm = torch.nn.LayerNorm(data_dim)
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        # Zero init => identity modulation initially.
        torch.nn.init.zeros_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., data_dim); t: (batch, norm_cond_dim)
        normed = self.norm(x)
        alpha, beta = self.linear(t).chunk(2, dim=-1)
        if normed.dim() > 2:
            # Insert singleton dims so (batch, data_dim) broadcasts over
            # any intermediate axes of x.
            shape = (alpha.shape[0],) + (1,) * (normed.dim() - 2) + (alpha.shape[-1],)
            alpha = alpha.view(shape)
            beta = beta.view(shape)
        return normed * (1 + alpha) + beta
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Factory for the normalization module used by PreNorm.

    Supported values for ``norm``: "batch", "layer", "ada" (conditional
    LayerNorm, requires ``norm_cond_dim`` > 0), or None (no-op).
    """
    if norm is None:
        return torch.nn.Identity()
    if norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    simple = {"batch": torch.nn.BatchNorm1d, "layer": torch.nn.LayerNorm}
    if norm in simple:
        return simple[norm](dim)
    raise ValueError(f"Unknown norm: {norm}")
class PreNorm(nn.Module):
    """Apply a normalization layer, then the wrapped callable ``fn``."""

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Adaptive norms consume the conditioning tensor from *args;
        # unconditional norms ignore it.
        if isinstance(self.norm, AdaptiveLayerNorm1D):
            normed = self.norm(x, *args)
        else:
            normed = self.norm(x)
        return self.fn(normed, **kwargs)
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class TransformerCrossAttn(nn.Module):
    """Transformer decoder stack: per layer, self-attention, cross-attention
    over a context sequence, then a feed-forward block — each pre-normalized
    and residual."""

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()

        def wrap(module):
            return PreNorm(dim, module, norm=norm, norm_cond_dim=norm_cond_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            blocks = nn.ModuleList(
                [
                    wrap(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                    wrap(
                        CrossAttention(
                            dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
                        )
                    ),
                    wrap(FeedForward(dim, mlp_dim, dropout=dropout)),
                ]
            )
            self.layers.append(blocks)

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        # Either one shared context or one context per layer may be given.
        if context_list is None:
            context_list = [context] * len(self.layers)
        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
        for layer_context, (self_attn, cross_attn, feed_forward) in zip(context_list, self.layers):
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=layer_context) + x
            x = feed_forward(x, *args) + x
        return x
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class DropTokenDropout(nn.Module):
    """Token dropout that *removes* dropped tokens from the sequence.

    During training, each token position is dropped with probability ``p``;
    the same positions are removed for every example in the batch, so the
    output sequence may be shorter than the input. Identity in eval mode.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if not self.training or self.p == 0:
            return x
        # One Bernoulli draw per token position, shared across the batch.
        drop = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
        # TODO: permutation idx for each batch using torch.argsort
        if drop.any():
            x = x[:, ~drop, :]
        return x
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class ZeroTokenDropout(nn.Module):
    """Token dropout that zeroes out whole tokens.

    During training, each (batch, token) position is independently zeroed
    with probability ``p``. Identity in eval mode or when ``p == 0``.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            # Independent keep/zero decision for every (batch, token) position.
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Fix: out-of-place masked_fill instead of `x[zero_mask, :] = 0`.
            # The original in-place write mutated the caller's tensor and can
            # break autograd when `x` is needed for the backward pass.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0.0)
        return x
# This class is from HaMeR (https://github.com/geopavlakos/hamer).
class TransformerDecoder(nn.Module):
    """Token embedding + positional encoding front end over TransformerCrossAttn.

    Input tokens of width ``token_dim`` are linearly embedded to ``dim``
    (unless ``skip_token_embedding``, which requires ``token_dim == dim``),
    given a learned positional embedding, optionally dropped out, and then
    decoded with cross-attention over a context sequence.

    Raises:
        ValueError: for an unknown ``emb_dropout_type`` or an inconsistent
            ``token_dim``/``dim`` when ``skip_token_embedding`` is set.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fix: the original silently left self.dropout unset for unknown
            # types, deferring the failure to an AttributeError in forward().
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Embed, positionally encode, and decode the input tokens.

        Args:
            inp: (batch, n, token_dim) input tokens with n <= num_tokens.
            context / context_list: cross-attention context, forwarded to
                TransformerCrossAttn.
        Returns:
            (batch, n, dim) decoded tokens.
        """
        x = self.to_token_embedding(inp)
        b, n, _ = x.shape
        # NOTE(review): with emb_dropout_type='drop' in training mode, the
        # dropout may shorten the sequence while `n` was captured before it,
        # which would make the pos-embedding add below shape-mismatch —
        # presumably only exercised with emb_dropout=0; confirm upstream.
        x = self.dropout(x)
        x += self.pos_embedding[:, :n]
        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    # Gram-Schmidt: remove the b1 component of a2, then normalize.
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    # Fix: pass dim=-1 explicitly. Without it, torch.cross uses the first
    # dimension of size 3 — for a (B, 3) tensor with B == 3 that is the
    # *batch* dimension, silently producing wrong rotations (and the
    # implicit-dim form is deprecated).
    b3 = torch.cross(b1, b2, dim=-1)
    # Columns of the rotation matrix are the orthonormal basis (b1, b2, b3).
    return torch.stack((b1, b2, b3), dim=-1)
def _quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """Convert quaternions in (w, x, y, z) order to rotation matrices.

    Args:
        quat (torch.Tensor): (B, 4) quaternions (normalized internally).
    Returns:
        torch.Tensor: (B, 3, 3) rotation matrices.
    """
    q = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    w2, x2, y2, z2 = w * w, x * x, y * y, z * z
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z
    # Standard quaternion-to-rotation-matrix formula.
    rotmat = torch.stack(
        [
            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
            2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
            2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2,
        ],
        dim=1,
    ).view(quat.shape[0], 3, 3)
    return rotmat

def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle representation to rotation matrix.
    Works by first converting it to a quaternion.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    Fix: the original called ``quat_to_rotmat``, which is not defined
    anywhere in this module (NameError at call time); the conversion is now
    provided by the private helper ``_quat_to_rotmat`` above.
    """
    # 1e-8 guards against division by zero for the zero rotation.
    norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return _quat_to_rotmat(quat)
class MANO(smplx.MANOLayer):
    """MANO layer extended with fingertip vertices and an optional extra
    joint regressor, with joints reordered to the OpenPose hand layout."""

    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """
        Extension of the official MANO implementation to support more joints.
        Args:
            Same as MANOLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
        """
        super().__init__(*args, **kwargs)
        # Permutation from MANO's native joint order to OpenPose hand order.
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        if joint_regressor_extra is not None:
            regressor = pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1')
            self.register_buffer('joint_regressor_extra', torch.tensor(regressor, dtype=torch.float32))
        # Vertex indices used as fingertip "joints".
        self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
        self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> MANOOutput:
        """
        Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
        """
        output = super().forward(*args, **kwargs)
        # Append fingertip vertices as joints, then reorder to OpenPose.
        tips = torch.index_select(output.vertices, 1, self.extra_joints_idxs)
        joints = torch.cat([output.joints, tips], dim=1)[:, self.joint_map, :]
        if hasattr(self, 'joint_regressor_extra'):
            regressed = vertices2joints(self.joint_regressor_extra, output.vertices)
            joints = torch.cat([joints, regressed], dim=1)
        output.joints = joints
        return output
class MANOTransformerDecoderHead(nn.Module):
    """ Cross-attention based MANO Transformer decoder
    """

    def __init__(self):
        super().__init__()
        # 6d continuous rotation representation (Zhou et al.) per joint.
        self.joint_rep_type = '6d'
        self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
        npose = self.joint_rep_dim * (cfg.MODEL.hamer_mano_num_hand_joints + 1)
        self.npose = npose
        # When True, the decoder token is the mean MANO parameters instead of zeros.
        self.input_is_mean_shape = False
        # Channel width of the backbone feature map used as the cross-attention
        # context; unlisted backbones fall back to 1280.
        backbone_context_dims = {
            'resnet-50': 2048, 'resnet-101': 2048, 'resnet-152': 2048,
            'hrnet-w32': 2048, 'hrnet-w48': 2048,
            'vit-l-16': 1024,
            'vit-b-16': 768,
            'resnet-18': 512, 'resnet-34': 512,
            'vit-s-16': 384,
            'handoccnet': 256,
        }
        context_dim = backbone_context_dims.get(cfg.MODEL.backbone_type, 1280)
        transformer_args = dict(
            num_tokens=1,
            token_dim=1,
            dim=1024,
            context_dim=context_dim,
            depth=6,
            dim_head=64,
            dropout=0.0,
            emb_dropout=0.0,
            heads=8,
            mlp_dim=1024,
            norm='layer',
        )
        self.transformer = TransformerDecoder(**transformer_args)
        dim = transformer_args['dim']
        # Residual readouts on top of the mean parameters.
        self.decpose = nn.Linear(dim, npose)
        self.decshape = nn.Linear(dim, 10)
        self.deccam = nn.Linear(dim, 3)
        # Mean MANO parameters serve as the initial estimate.
        mean_params = np.load(cfg.MODEL.hamer_mano_mean_params)
        self.register_buffer('init_hand_pose', torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0))
        self.register_buffer('init_betas', torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0))
        self.register_buffer('init_cam', torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0))

    def forward(self, x, **kwargs):
        """Predict MANO parameters and a camera from backbone features.

        Args:
            x: (B, C, H, W) backbone feature map.
        Returns:
            (pred_mano_params, pred_cam, pred_mano_params_list) where
            pred_mano_params holds 'global_orient'/'hand_pose'/'betas' as
            rotation matrices / shape coefficients.
        """
        batch_size = x.shape[0]
        # vit pretrained backbone is channel-first. Change to token-first
        x = rearrange(x, 'b c h w -> b (h w) c')
        pred_hand_pose = self.init_hand_pose.expand(batch_size, -1)
        pred_betas = self.init_betas.expand(batch_size, -1)
        pred_cam = self.init_cam.expand(batch_size, -1)
        # TODO: Convert init_hand_pose to aa rep if needed
        if self.joint_rep_type == 'aa':
            raise NotImplementedError
        pred_hand_pose_list = []
        pred_betas_list = []
        pred_cam_list = []
        if self.input_is_mean_shape:
            token = torch.cat([pred_hand_pose, pred_betas, pred_cam], dim=1)[:, None, :]
        else:
            # Input token to transformer is zero token
            token = torch.zeros(batch_size, 1, 1).to(x.device)
        token_out = self.transformer(token, context=x).squeeze(1)  # (B, C)
        # Residual readout from the decoded token.
        pred_hand_pose = self.decpose(token_out) + pred_hand_pose
        pred_betas = self.decshape(token_out) + pred_betas
        pred_cam = self.deccam(token_out) + pred_cam
        pred_hand_pose_list.append(pred_hand_pose)
        pred_betas_list.append(pred_betas)
        pred_cam_list.append(pred_cam)
        # Convert self.joint_rep_type -> rotmat
        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda p: aa_to_rotmat(p.view(-1, 3).contiguous()),
        }[self.joint_rep_type]
        pred_mano_params_list = {
            'hand_pose': torch.cat(
                [joint_conversion_fn(p).view(batch_size, -1, 3, 3)[:, 1:, :, :] for p in pred_hand_pose_list],
                dim=0,
            ),
            'betas': torch.cat(pred_betas_list, dim=0),
            'cam': torch.cat(pred_cam_list, dim=0),
        }
        pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(
            batch_size, cfg.MODEL.hamer_mano_num_hand_joints + 1, 3, 3
        )
        pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
                            'hand_pose': pred_hand_pose[:, 1:],
                            'betas': pred_betas}
        return pred_mano_params, pred_cam, pred_mano_params_list
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the perspective projection of a set of 3D points.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
            Defaults to the origin.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
            Defaults to the identity.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    batch_size = points.shape[0]
    device, dtype = points.device, points.dtype
    if rotation is None:
        rotation = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch_size, 2, device=device, dtype=dtype)
    # Intrinsics K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    K = torch.zeros([batch_size, 3, 3], device=device, dtype=dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center
    # Rigid transform into the camera frame.
    cam_points = torch.einsum('bij,bkj->bki', rotation, points) + translation.unsqueeze(1)
    # Perspective divide by depth.
    cam_points = cam_points / cam_points[:, :, -1].unsqueeze(-1)
    # Map to pixel coordinates and drop the homogeneous coordinate.
    projected = torch.einsum('bij,bkj->bki', K, cam_points)
    return projected[:, :, :-1]
# This module is modified from MANOTransformerDecoderHead of HaMeR (https://github.com/geopavlakos/hamer). All cfg are directly initialized.
class ContactTransformerDecoderHead(nn.Module):
    """ Cross-attention based MANO Transformer decoder

    A single learnable query token cross-attends over the flattened backbone
    feature map and is decoded into 778 per-vertex contact logits
    (778 = MANO mesh vertex count). Output is raw logits; the sigmoid is
    left to the caller (see the commented-out line in forward).
    """
    def __init__(self):
        super().__init__()
        # One zero token of width 1 is the decoder input; `dim` is the
        # transformer's internal embedding width.
        transformer_args = dict(
            num_tokens=1,
            token_dim=1,
            dim=1024,
        )
        # Cross-attention context width must match the channel dimension of
        # the configured backbone's feature map.
        if cfg.MODEL.backbone_type in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']:
            context_dim = 2048
        elif cfg.MODEL.backbone_type in ['vit-l-16']:
            context_dim = 1024
        elif cfg.MODEL.backbone_type in ['vit-b-16']:
            context_dim = 768
        elif cfg.MODEL.backbone_type in ['resnet-18', 'resnet-34']:
            context_dim = 512
        elif cfg.MODEL.backbone_type in ['vit-s-16']:
            context_dim = 384
        elif cfg.MODEL.backbone_type in ['handoccnet']:
            context_dim = 256
        else:
            # Fallback width for any backbone not listed above -- presumably
            # the ViT used by HaMeR; confirm against the backbone definitions.
            context_dim = 1280
        # Same decoder hyper-parameters as the MANO head.
        MANO_HEAD_TRANSFORMER_DECODER_CONFIG = {'depth': 6, 'heads': 8, 'mlp_dim': 1024, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'norm': 'layer', 'context_dim': context_dim}
        transformer_args.update(dict(MANO_HEAD_TRANSFORMER_DECODER_CONFIG))
        self.transformer = TransformerDecoder(
            **transformer_args
        )
        # Linear readout from the decoded token to per-vertex contact logits.
        self.deccontact = nn.Linear(1024, 778)
        CONTACT_MEAN_DIR = cfg.MODEL.contact_means_path # TODO: REPLACE THIS WITH CONTACT MEAN OF ENTIRE DATASETS
        # NOTE(review): CONTACT_MEAN_DIR is read but never used below -- the
        # initial contact is random noise rather than a dataset mean. Confirm
        # whether loading the contact means was intended here.
        init_contact = nn.Parameter(torch.randn(1, 778, requires_grad=True))
        # NOTE(review): registering an nn.Parameter as a *buffer* means it is
        # saved in the state dict but NOT returned by parameters(), so
        # init_contact will not be updated by the optimizer despite
        # requires_grad=True. Verify this is intentional.
        self.register_buffer('init_contact', init_contact)
    def forward(self, x, **kwargs): # x: [b, 1280, 16, 12] (if resnet-50, x: [b, 2048, 8, 8], resnet-34: [b, 512, 8, 8], hrnet-w32: [b, 2048, 8, 8])
        """Predict per-vertex contact logits of shape (B, 778) from a
        (B, C, H, W) backbone feature map."""
        batch_size = x.shape[0]
        device = x.device  # NOTE(review): unused local
        # vit pretrained backbone is channel-first. Change to token-first
        x = rearrange(x, 'b c h w -> b (h w) c')
        # Broadcast the (1, 778) initial contact estimate over the batch.
        init_contact = self.init_contact.expand(batch_size, -1)
        pred_contact = init_contact
        # Single zero token as the decoder query.
        token = torch.zeros(batch_size, 1, 1).to(x.device)
        # Pass through transformer
        token_out = self.transformer(token, context=x) # x: [b, 192, 1280]
        token_out = token_out[:, 0] # (B, C)
        # Readout from token_out
        # Residual prediction on top of the initial contact estimate.
        pred_contact = self.deccontact(token_out) + pred_contact
        # pred_contact = pred_contact.sigmoid()
        return pred_contact