|
|
import pickle |
|
|
import numpy as np |
|
|
from einops import rearrange |
|
|
from inspect import isfunction |
|
|
from typing import Callable, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import smplx |
|
|
from smplx.lbs import vertices2joints |
|
|
from smplx.utils import MANOOutput, to_tensor |
|
|
from smplx.vertex_ids import vertex_ids |
|
|
|
|
|
from lib.core.config import cfg |
|
|
from lib.utils.human_models import mano |
|
|
|
|
|
|
|
|
|
|
|
def exists(val):
    """Report whether ``val`` holds an actual value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    Args:
        val: Preferred value; used whenever it is not ``None``.
        d: Fallback. If it is a plain Python function (per ``inspect.isfunction``)
            it is called with no arguments and its result is returned; any other
            object is returned as-is.

    Returns:
        ``val`` or the resolved fallback.
    """
    if not exists(val):
        # The factory is only invoked when the fallback is actually needed.
        return d() if isfunction(d) else d
    return val
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Channel width of the input (and output) tokens.
        heads: Number of attention heads.
        dim_head: Channel width of each head.
        dropout: Dropout probability applied to the attention weights and, when
            an output projection is present, after that projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # With one head whose width already equals `dim`, the output projection is skipped.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # A single bias-free linear layer yields queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over the tokens of ``x``.

        The rearrange pattern below expects ``x`` to be laid out as
        (batch, tokens, channels); the result has the same layout.
        """
        queries, keys, values = [
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        ]

        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of the values, then merge the heads back into one channel axis.
        merged = rearrange(torch.matmul(weights, values), "b h n d -> b n (h d)")
        return self.to_out(merged)
|
|
|
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head attention whose queries come from ``x`` and keys/values from a context.

    Args:
        dim: Channel width of the query tokens and of the output.
        context_dim: Channel width of the context tokens; ``dim`` is used when ``None``.
        heads: Number of attention heads.
        dim_head: Channel width of each head.
        dropout: Dropout probability applied to the attention weights and, when
            an output projection is present, after that projection.
    """

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # With one head whose width already equals `dim`, the output projection is skipped.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        context_dim = default(context_dim, dim)
        # Keys and values are produced together from the context; queries from `x`.
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, context=None):
        """Attend from the tokens of ``x`` to the tokens of ``context``.

        When ``context`` is ``None`` the module attends to ``x`` itself. The
        rearrange pattern expects both inputs as (batch, tokens, channels).
        """
        context = default(context, x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries = self.to_q(x)
        queries, keys, values = [
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in (queries, keys, values)
        ]

        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of the values, then merge the heads back into one channel axis.
        merged = rearrange(torch.matmul(weights, values), "b h n d -> b n (h d)")
        return self.to_out(merged)
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Channel width of the input and output.
        hidden_dim: Channel width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of pre-norm self-attention + feed-forward layers with residual connections.

    Args:
        dim: Channel width of the tokens.
        depth: Number of (attention, feed-forward) layers.
        heads: Number of attention heads per layer.
        dim_head: Channel width of each attention head.
        mlp_dim: Hidden width of each feed-forward block.
        dropout: Dropout probability passed to attention and feed-forward blocks.
        norm: Normalization type forwarded to ``PreNorm``.
        norm_cond_dim: Conditioning width forwarded to ``PreNorm``.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        norm_options = dict(norm=norm, norm_cond_dim=norm_cond_dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            layer = nn.ModuleList(
                [
                    PreNorm(dim, attention, **norm_options),
                    PreNorm(dim, feed_forward, **norm_options),
                ]
            )
            self.layers.append(layer)

    def forward(self, x: torch.Tensor, *args):
        """Run ``x`` through every layer.

        Extra positional ``args`` are handed unchanged to each ``PreNorm`` wrapper.
        """
        for attention_block, feed_forward_block in self.layers:
            # Residual connection around each sub-block.
            x = attention_block(x, *args) + x
            x = feed_forward_block(x, *args) + x
        return x
|
|
|
|
|
|
|
|
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm followed by a scale and shift predicted from a conditioning vector.

    The predicting linear layer is zero-initialized, so a freshly constructed
    module behaves exactly like its inner ``LayerNorm``.

    Args:
        data_dim: Size of the normalized (last) axis of the data.
        norm_cond_dim: Size of the conditioning vector.

    Raises:
        ValueError: If ``data_dim`` or ``norm_cond_dim`` is not positive.
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        self.norm = torch.nn.LayerNorm(data_dim)
        # One linear layer emits both the scale and the shift (hence 2 * data_dim).
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        for param in (self.linear.weight, self.linear.bias):
            torch.nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` and modulate it with parameters computed from ``t``.

        Args:
            x: Data whose last axis has size ``data_dim``.
            t: Conditioning input; indexed below as a 2-D (batch, norm_cond_dim) tensor.

        Returns:
            ``norm(x) * (1 + scale) + shift``.
        """
        normed = self.norm(x)
        scale, shift = self.linear(t).chunk(2, dim=-1)

        # Insert singleton axes between batch and channel so both broadcast over x.
        extra_axes = x.dim() - 2
        if extra_axes > 0:
            scale = scale.view((scale.shape[0],) + (1,) * extra_axes + (scale.shape[1],))
            shift = shift.view((shift.shape[0],) + (1,) * extra_axes + (shift.shape[1],))

        return normed * (1 + scale) + shift
|
|
|
|
|
|
|
|
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Build the normalization module selected by ``norm``.

    Args:
        norm: ``"batch"``, ``"layer"``, ``"ada"`` or ``None`` (no normalization).
        dim: Feature size handed to the normalization module.
        norm_cond_dim: Conditioning size; must be positive when ``norm == "ada"``.

    Returns:
        ``BatchNorm1d``, ``LayerNorm``, ``AdaptiveLayerNorm1D`` or ``Identity``.

    Raises:
        ValueError: If ``norm`` is not one of the recognized options.
    """
    if norm is None:
        return torch.nn.Identity()
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    if norm == "layer":
        return torch.nn.LayerNorm(dim)
    if norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    raise ValueError(f"Unknown norm: {norm}")
|
|
|
|
|
|
|
|
class PreNorm(nn.Module):
    """Apply a normalization layer to the input before calling the wrapped function.

    Args:
        dim: Feature size handed to the normalization layer.
        fn: Callable invoked on the normalized input.
        norm: Normalization type, resolved by ``normalization_layer``.
        norm_cond_dim: Conditioning size, resolved by ``normalization_layer``.
    """

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Normalize ``x`` and pass it, together with ``kwargs``, to ``fn``.

        Positional ``args`` reach the normalization layer only when it is an
        ``AdaptiveLayerNorm1D``; every other normalization ignores them.
        """
        uses_conditioning = isinstance(self.norm, AdaptiveLayerNorm1D)
        normed = self.norm(x, *args) if uses_conditioning else self.norm(x)
        return self.fn(normed, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
class TransformerCrossAttn(nn.Module):
    """Stack of pre-norm layers, each doing self-attention, cross-attention and feed-forward.

    Args:
        dim: Channel width of the tokens.
        depth: Number of layers.
        heads: Number of attention heads per attention block.
        dim_head: Channel width of each attention head.
        mlp_dim: Hidden width of each feed-forward block.
        dropout: Dropout probability passed to every sub-block.
        norm: Normalization type forwarded to ``PreNorm``.
        norm_cond_dim: Conditioning width forwarded to ``PreNorm``.
        context_dim: Channel width of the cross-attention context.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()
        norm_options = dict(norm=norm, norm_cond_dim=norm_cond_dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self_attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            cross_attention = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            layer = nn.ModuleList(
                [
                    PreNorm(dim, self_attention, **norm_options),
                    PreNorm(dim, cross_attention, **norm_options),
                    PreNorm(dim, feed_forward, **norm_options),
                ]
            )
            self.layers.append(layer)

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        """Run ``x`` through every layer.

        Args:
            x: Input tokens.
            *args: Extra positional inputs handed to each ``PreNorm`` wrapper.
            context: Context shared by all layers; used only when
                ``context_list`` is ``None``.
            context_list: One context per layer.

        Raises:
            ValueError: If ``context_list`` does not have one entry per layer.
        """
        if context_list is None:
            # Every layer attends to the same context.
            context_list = [context] * len(self.layers)

        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for (self_attn, cross_attn, ff), layer_context in zip(self.layers, context_list):
            # Residual connection around each of the three sub-blocks.
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=layer_context) + x
            x = ff(x, *args) + x
        return x
|
|
|
|
|
|
|
|
|
|
|
class DropTokenDropout(nn.Module):
    """Training-time dropout that removes whole token positions from the sequence.

    One Bernoulli(``p``) draw is made per position along axis 1; positions that
    are hit are removed for every batch element, so the output can be shorter
    than the input.

    Args:
        p: Probability of dropping each token position.

    Raises:
        ValueError: If ``p`` is below 0 or above 1.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p > 1 or p < 0:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        """Return ``x`` with randomly selected token positions removed (training only).

        ``x`` is indexed as a 3-D tensor. In eval mode, or when ``p`` is 0, ``x``
        is returned unchanged.
        """
        if not (self.training and self.p > 0):
            return x

        # The mask is sampled from the first batch element's layout and shared by the batch.
        drop = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
        return x[:, ~drop, :] if drop.any() else x
|
|
|
|
|
|
|
|
|
|
|
class ZeroTokenDropout(nn.Module):
    """Training-time dropout that zeroes out whole tokens.

    One Bernoulli(``p``) draw is made per (batch, token) pair; tokens that are
    hit have all of their features set to zero. The sequence length is unchanged.

    Args:
        p: Probability of zeroing each token.

    Raises:
        ValueError: If ``p`` is below 0 or above 1.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        """Return ``x`` with randomly selected tokens zeroed (training only).

        ``x`` is indexed as a 3-D tensor. In eval mode, or when ``p`` is 0, ``x``
        itself is returned. Otherwise a new tensor is returned and ``x`` is left
        untouched.
        """
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Out-of-place on purpose: writing into `x` would silently modify the
            # caller's tensor and fails outright on a leaf tensor that requires grad.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0)
        return x
|
|
|
|
|
|
|
|
|
|
|
class TransformerDecoder(nn.Module):
    """Token embedding + learned positional embedding + cross-attention transformer.

    Args:
        num_tokens: Number of positions in the learned positional embedding.
        token_dim: Channel width of the incoming tokens.
        dim: Channel width used inside the transformer.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward blocks.
        dim_head: Channel width of each attention head.
        dropout: Dropout probability inside the transformer layers.
        emb_dropout: Dropout probability applied to the embedded tokens.
        emb_dropout_type: ``'drop'`` (remove tokens), ``'zero'`` (zero tokens)
            or ``'normal'`` (element-wise ``nn.Dropout``).
        norm: Normalization type forwarded to the transformer.
        norm_cond_dim: Conditioning width forwarded to the transformer.
        context_dim: Channel width of the cross-attention context.
        skip_token_embedding: If ``True`` the tokens are used as-is, which
            requires ``token_dim == dim``.

    Raises:
        ValueError: If ``skip_token_embedding`` is set while ``token_dim != dim``,
            or if ``emb_dropout_type`` is not a recognized option.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail at construction time; otherwise `self.dropout` would be missing
            # and the error would only surface as an AttributeError in forward().
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Embed the tokens, add positional embeddings and run the transformer.

        Args:
            inp: Input tokens; unpacked below as a 3-D tensor.
            *args: Extra positional inputs forwarded to the transformer.
            context: Cross-attention context shared by all layers.
            context_list: Optional per-layer cross-attention contexts.

        Returns:
            The transformer output.
        """
        x = self.to_token_embedding(inp)
        _, n, _ = x.shape

        x = self.dropout(x)
        # NOTE(review): `n` is measured before dropout; with the 'drop' type the
        # sequence may have become shorter than `n` here — confirm intended use.
        # Out-of-place add: with an Identity embedding and a no-op dropout `x` is
        # the caller's `inp`, which an in-place `+=` would modify.
        x = x + self.pos_embedding[:, :n]

        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x
|
|
|
|
|
|
|
|
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations. Any input whose
            element count is a multiple of 6 is accepted; it is flattened into rows of 6.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    # Split each 6-vector into two 3-vectors a1 = x[0:3] and a2 = x[3:6].
    x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: normalize a1, then remove its component from a2 and normalize.
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1, dim=-1)
    # `dim` must be explicit: without it torch.cross uses the first axis of size 3,
    # which is the batch axis whenever exactly three rotations are passed.
    b3 = torch.cross(b1, b2, dim=-1)
    # b1, b2, b3 become the columns of the rotation matrix.
    return torch.stack((b1, b2, b3), dim=-1)
|
|
|
|
|
|
|
|
def aa_to_rotmat(theta: torch.Tensor):
    """
    Turn axis-angle vectors into rotation matrices via an intermediate quaternion.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    # NOTE(review): `quat_to_rotmat` is not defined in the portion of the file
    # visible here — confirm it is defined or imported elsewhere.
    # The small offset keeps the norm non-zero when theta is all zeros.
    magnitude = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = theta / magnitude
    half_angle = magnitude * 0.5
    # Quaternion layout is (cos(angle / 2), sin(angle / 2) * axis).
    quaternion = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quaternion)
|
|
|
|
|
|
|
|
class MANO(smplx.MANOLayer):
    """MANO layer whose output joints are extended and re-ordered."""

    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """
        Extension of the official MANO implementation to support more joints.
        Args:
            Same as MANOLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
        """
        super().__init__(*args, **kwargs)
        # Index list applied to the concatenated joints in forward(); going by its
        # name it maps MANO joint order to OpenPose order — TODO confirm.
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

        if joint_regressor_extra is not None:
            # NOTE: pickle can execute arbitrary code while loading; only point this
            # at regressor files from a trusted source.
            # The context manager closes the file handle even if unpickling fails.
            with open(joint_regressor_extra, 'rb') as regressor_file:
                regressor = pickle.load(regressor_file, encoding='latin1')
            self.register_buffer('joint_regressor_extra', torch.tensor(regressor, dtype=torch.float32))
        self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
        self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> MANOOutput:
        """
        Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
        """
        mano_output = super().forward(*args, **kwargs)
        # Vertices selected by `extra_joints_idxs` are appended as additional joints,
        # then the combined set is re-ordered through `joint_map`.
        extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
        joints = torch.cat([mano_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]
        if hasattr(self, 'joint_regressor_extra'):
            # Joints regressed from the mesh go after the re-ordered set.
            extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        mano_output.joints = joints
        return mano_output
|
|
|
|
|
|
|
|
class MANOTransformerDecoderHead(nn.Module):
    """Cross-attention transformer decoder head that predicts MANO parameters.

    A single query token cross-attends to the flattened feature map. Three linear
    layers turn the decoded token into pose, shape and camera residuals, which are
    added to the mean values loaded from ``cfg.MODEL.hamer_mano_mean_params``.
    """

    def __init__(self):
        super().__init__()

        self.joint_rep_type = '6d'
        self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
        # One rotation per hand joint plus one more; forward() returns the first
        # rotation separately as `global_orient`.
        self.npose = self.joint_rep_dim * (cfg.MODEL.hamer_mano_num_hand_joints + 1)
        self.input_is_mean_shape = False

        # Cross-attention context width for each known backbone; anything not
        # listed falls back to 1280.
        context_dim_by_backbone = [
            (['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48'], 2048),
            (['vit-l-16'], 1024),
            (['vit-b-16'], 768),
            (['resnet-18', 'resnet-34'], 512),
            (['vit-s-16'], 384),
            (['handoccnet'], 256),
        ]
        backbone = cfg.MODEL.backbone_type
        for backbone_names, width in context_dim_by_backbone:
            if backbone in backbone_names:
                context_dim = width
                break
        else:
            context_dim = 1280

        token_width = 1024
        self.transformer = TransformerDecoder(
            num_tokens=1,
            token_dim=1,
            dim=token_width,
            context_dim=context_dim,
            depth=6,
            dim_head=64,
            dropout=0.0,
            emb_dropout=0.0,
            heads=8,
            mlp_dim=1024,
            norm='layer',
        )
        self.decpose = nn.Linear(token_width, self.npose)
        self.decshape = nn.Linear(token_width, 10)
        self.deccam = nn.Linear(token_width, 3)

        # The mean parameters become non-trainable buffers with a leading batch axis of 1.
        mean_params = np.load(cfg.MODEL.hamer_mano_mean_params)
        for buffer_name, key in (
            ('init_hand_pose', 'pose'),
            ('init_betas', 'shape'),
            ('init_cam', 'cam'),
        ):
            mean_value = torch.from_numpy(mean_params[key].astype(np.float32)).unsqueeze(0)
            self.register_buffer(buffer_name, mean_value)

    def forward(self, x, **kwargs):
        """Predict MANO parameters and camera from a feature map.

        Args:
            x: Feature map; the rearrange pattern expects (batch, channels, height, width).
            **kwargs: Accepted but not used.

        Returns:
            Tuple ``(pred_mano_params, pred_cam, pred_mano_params_list)``:
            ``pred_mano_params`` is a dict with rotation matrices under
            ``'global_orient'`` and ``'hand_pose'`` plus ``'betas'``;
            ``pred_cam`` is the camera prediction; ``pred_mano_params_list`` is a
            dict with ``'hand_pose'``, ``'betas'`` and ``'cam'`` concatenated over
            the (single) prediction step.

        Raises:
            NotImplementedError: If ``joint_rep_type`` is ``'aa'``.
        """
        batch_size = x.shape[0]

        # Flatten the spatial grid into a token sequence for cross-attention.
        x = rearrange(x, 'b c h w -> b (h w) c')

        mean_pose = self.init_hand_pose.expand(batch_size, -1)
        mean_betas = self.init_betas.expand(batch_size, -1)
        mean_cam = self.init_cam.expand(batch_size, -1)

        if self.joint_rep_type == 'aa':
            raise NotImplementedError

        if self.input_is_mean_shape:
            token = torch.cat([mean_pose, mean_betas, mean_cam], dim=1)[:, None, :]
        else:
            # A single all-zero query token per batch element.
            token = torch.zeros(batch_size, 1, 1).to(x.device)

        token_out = self.transformer(token, context=x).squeeze(1)

        # The linear heads predict residuals on top of the mean parameters.
        pred_hand_pose = self.decpose(token_out) + mean_pose
        pred_betas = self.decshape(token_out) + mean_betas
        pred_cam = self.deccam(token_out) + mean_cam
        pred_hand_pose_list = [pred_hand_pose]
        pred_betas_list = [pred_betas]
        pred_cam_list = [pred_cam]

        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda pose: aa_to_rotmat(pose.view(-1, 3).contiguous())
        }[self.joint_rep_type]

        pred_mano_params_list = {
            # The first rotation of every prediction is left out of 'hand_pose'.
            'hand_pose': torch.cat(
                [joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_hand_pose_list],
                dim=0,
            ),
            'betas': torch.cat(pred_betas_list, dim=0),
            'cam': torch.cat(pred_cam_list, dim=0),
        }

        rotmats = joint_conversion_fn(pred_hand_pose).view(
            batch_size, cfg.MODEL.hamer_mano_num_hand_joints + 1, 3, 3
        )
        pred_mano_params = {
            'global_orient': rotmats[:, [0]],
            'hand_pose': rotmats[:, 1:],
            'betas': pred_betas,
        }
        return pred_mano_params, pred_cam, pred_mano_params_list
|
|
|
|
|
|
|
|
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Project 3D points onto the image plane with a pinhole camera model.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
            Defaults to zeros when omitted.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
            Defaults to the identity when omitted.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    num_batch = points.shape[0]
    like_points = dict(device=points.device, dtype=points.dtype)

    if rotation is None:
        rotation = torch.eye(3, **like_points).unsqueeze(0).expand(num_batch, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(num_batch, 2, **like_points)

    # Intrinsics: focal lengths on the diagonal, camera center in the last column.
    intrinsics = torch.zeros([num_batch, 3, 3], **like_points)
    intrinsics[:, 0, 0] = focal_length[:, 0]
    intrinsics[:, 1, 1] = focal_length[:, 1]
    intrinsics[:, 2, 2] = 1.
    intrinsics[:, :-1, -1] = camera_center

    # Rotate, then translate, the points.
    camera_points = torch.einsum('bij,bkj->bki', rotation, points) + translation.unsqueeze(1)

    # Perspective divide by the last coordinate.
    depth = camera_points[:, :, -1].unsqueeze(-1)
    normalized_points = camera_points / depth

    # Apply the intrinsics and drop the homogeneous coordinate.
    pixels = torch.einsum('bij,bkj->bki', intrinsics, normalized_points)
    return pixels[:, :, :-1]
|
|
|
|
|
|
|
|
|
|
|
class ContactTransformerDecoderHead(nn.Module):
    """Cross-attention transformer decoder head that predicts a 778-value contact vector.

    A single query token cross-attends to the flattened feature map; a linear
    layer maps the decoded token to 778 values that are added to a fixed,
    randomly initialized offset stored in the ``init_contact`` buffer.
    Presumably there is one value per MANO mesh vertex — TODO confirm.
    """

    def __init__(self):
        super().__init__()

        # Cross-attention context width for each known backbone; anything not
        # listed falls back to 1280.
        backbone = cfg.MODEL.backbone_type
        if backbone in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']:
            context_dim = 2048
        elif backbone in ['vit-l-16']:
            context_dim = 1024
        elif backbone in ['vit-b-16']:
            context_dim = 768
        elif backbone in ['resnet-18', 'resnet-34']:
            context_dim = 512
        elif backbone in ['vit-s-16']:
            context_dim = 384
        elif backbone in ['handoccnet']:
            context_dim = 256
        else:
            context_dim = 1280

        transformer_args = dict(
            num_tokens=1,
            token_dim=1,
            dim=1024,
            depth=6,
            heads=8,
            mlp_dim=1024,
            dim_head=64,
            dropout=0.0,
            emb_dropout=0.0,
            norm='layer',
            context_dim=context_dim,
        )
        self.transformer = TransformerDecoder(**transformer_args)
        self.deccontact = nn.Linear(transformer_args['dim'], 778)

        # Registered as a plain tensor buffer. The previous code wrapped the tensor
        # in nn.Parameter(requires_grad=True) before register_buffer, so it was
        # never handed to an optimizer yet still tracked gradients. The state_dict
        # key is unchanged.
        # NOTE(review): cfg.MODEL.contact_means_path was read here but never used;
        # loading the initial contact from it may have been the intent — confirm.
        self.register_buffer('init_contact', torch.randn(1, 778))

    def forward(self, x, **kwargs):
        """Predict the contact vector from a feature map.

        Args:
            x: Feature map; the rearrange pattern expects (batch, channels, height, width).
            **kwargs: Accepted but not used.

        Returns:
            Tensor of shape (batch, 778): decoded residual plus ``init_contact``.
        """
        batch_size = x.shape[0]

        # Flatten the spatial grid into a token sequence for cross-attention.
        x = rearrange(x, 'b c h w -> b (h w) c')

        init_contact = self.init_contact.expand(batch_size, -1)

        # A single all-zero query token per batch element.
        token = torch.zeros(batch_size, 1, 1).to(x.device)

        token_out = self.transformer(token, context=x)
        token_out = token_out[:, 0]

        # The linear head predicts a residual on top of the stored offset.
        pred_contact = self.deccontact(token_out) + init_contact
        return pred_contact