Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de | |
| from typing import Optional, Dict, Callable | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from loguru import logger | |
| from SMPLX.transfer_model.utils import get_vertices_per_edge | |
| from SMPLX.transfer_model.optimizers import build_optimizer, minimize | |
| from SMPLX.transfer_model.utils import ( | |
| Tensor, batch_rodrigues, apply_deformation_transfer) | |
| from SMPLX.transfer_model.losses import build_loss | |
def summary_closure(gt_vertices, var_dict, body_model, mask_ids=None):
    ''' Computes the logging summary for the current fit.

    Entries of ``var_dict`` whose name contains ``'pose'`` or ``'orient'``
    are converted from axis-angle to rotation matrices with
    ``batch_rodrigues``. Every other entry is forwarded to ``body_model``
    unchanged.

    Args:
        gt_vertices: Target vertices. Indexed as ``[:, mask_ids]`` when a
            mask is given.
        var_dict: Mapping from body-model argument name to tensor.
        body_model: Callable returning an object with a ``vertices``
            attribute.
        mask_ids: Optional indices that restrict the comparison to a subset
            of vertices.

    Returns:
        ``{'Vertex-to-Vertex': value}``, where value is the mean per-vertex
        Euclidean distance multiplied by 1000 (presumably a metre-to-mm
        conversion; confirm against the model units).
    '''
    def _decode(name, value):
        # Only rotation-like variables are stored as axis-angles.
        if 'pose' not in name and 'orient' not in name:
            return value
        rotmats = batch_rodrigues(value.reshape(-1, 3))
        return rotmats.reshape(len(value), -1, 3, 3)

    model_kwargs = {name: _decode(name, value)
                    for name, value in var_dict.items()}
    output = body_model(return_full_pose=True, get_skin=True, **model_kwargs)

    predicted = output.vertices
    target = gt_vertices
    if mask_ids is not None:
        predicted, target = predicted[:, mask_ids], target[:, mask_ids]

    per_vertex_dist = (predicted - target).pow(2).sum(dim=-1).sqrt()
    return {'Vertex-to-Vertex': per_vertex_dist.mean() * 1000}
def build_model_forward_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    ''' Builds a zero-argument callable that runs ``body_model`` on ``var_dict``.

    Each call decodes every ``var_dict`` entry whose name contains
    ``'pose'`` or ``'orient'`` from axis-angle to rotation matrices. The
    remaining entries are passed through unchanged. The model is then called
    with ``return_full_pose=True, get_skin=True``.

    When ``per_part`` is True, the entry named ``part_key`` is always decoded.
    Its joint ``jidx`` is then overwritten with the rotation encoded by
    ``part``, so only that single joint is driven by ``part``.

    Args:
        body_model: The body model to evaluate.
        var_dict: Model variables keyed by body-model argument name.
        per_part: Whether to substitute ``part`` into ``var_dict[part_key]``.
        part_key: Name of the variable containing the substituted joint.
        jidx: Joint index (second dimension) to overwrite.
        part: Axis-angle values for the substituted joint.

    Returns:
        A callable that returns the body model output.

    Raises:
        AssertionError: If ``per_part`` is True and any of ``part``,
            ``part_key`` or ``jidx`` is None.
    '''
    if per_part:
        assert (part is not None and part_key is not None
                and jidx is not None), (
            'When per-part is True, "part", "part_key", "jidx" must not be'
            ' None.'
        )

    def to_rotmats(axis_angles):
        return batch_rodrigues(
            axis_angles.reshape(-1, 3)).reshape(len(axis_angles), -1, 3, 3)

    def model_forward():
        model_kwargs = {}
        for name, value in var_dict.items():
            if per_part and name == part_key:
                rotmats = to_rotmats(value)
                # Replace only the optimized joint; the others stay as-is.
                rotmats[:, jidx] = batch_rodrigues(
                    part.reshape(-1, 3)).reshape(-1, 3, 3)
                model_kwargs[name] = rotmats
            elif 'pose' in name or 'orient' in name:
                model_kwargs[name] = to_rotmats(value)
            else:
                model_kwargs[name] = value
        return body_model(return_full_pose=True, get_skin=True,
                          **model_kwargs)

    return model_forward
def build_edge_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    edge_loss: nn.Module,
    optimizer_dict,
    gt_vertices: Tensor,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    ''' Builds the closure for the edge objective.

    Args:
        body_model: Body model evaluated through
            ``build_model_forward_closure``.
        var_dict: Model variables keyed by body-model argument name.
        edge_loss: Module called as ``edge_loss(est_vertices, gt_vertices)``.
        optimizer_dict: Dict with the keys ``'optimizer'`` and
            ``'create_graph'``.
        gt_vertices: Target vertices.
        per_part: If True, ``part`` replaces joint ``jidx`` of
            ``var_dict[part_key]``, and it is the only tensor differentiated
            on the ``create_graph`` path. If False, that path differentiates
            only entries whose name contains ``'pose'``.
        part_key: Name of the variable holding the optimized joint.
        jidx: Index of the optimized joint.
        part: Axis-angle tensor being optimized in per-part mode.

    Returns:
        ``closure(backward=True)``. It evaluates the edge loss and, when
        ``backward`` is True, zeroes the optimizer gradients and fills in new
        ones before returning the loss.
    '''
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']

    params_to_opt = ([part] if per_part else
                     [value for name, value in var_dict.items()
                      if 'pose' in name])

    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)

    def closure(backward=True):
        if backward:
            optimizer.zero_grad()
        loss = edge_loss(model_forward().vertices, gt_vertices)
        if not backward:
            return loss
        if create_graph:
            # Use this instead of .backward to avoid GPU memory leaks
            grads = torch.autograd.grad(
                loss, params_to_opt, create_graph=True)
            torch.autograd.backward(params_to_opt, grads, create_graph=True)
        else:
            loss.backward()
        return loss

    return closure
def build_vertex_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    optimizer_dict,
    gt_vertices: Tensor,
    vertex_loss: nn.Module,
    mask_ids=None,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None,
    params_to_opt: Optional[Tensor] = None,
) -> Callable:
    ''' Builds the closure for the vertex objective.

    Args:
        body_model: Body model evaluated through
            ``build_model_forward_closure``.
        var_dict: Model variables keyed by body-model argument name.
        optimizer_dict: Dict with the keys ``'optimizer'`` and
            ``'create_graph'``.
        gt_vertices: Target vertices.
        vertex_loss: Module called as ``vertex_loss(est, gt)``.
        mask_ids: Optional vertex indices. Both vertex sets are indexed as
            ``[:, mask_ids]`` before the loss is computed.
        per_part, part_key, jidx, part: Forwarded to
            ``build_model_forward_closure``.
        params_to_opt: Sequence of tensors differentiated on the
            ``create_graph`` path. Defaults to every value of ``var_dict``.

    Returns:
        ``closure(backward=True)``. It evaluates the vertex loss and, when
        ``backward`` is True, zeroes the optimizer gradients and fills in new
        ones before returning the loss.
    '''
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']
    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)
    if params_to_opt is None:
        params_to_opt = list(var_dict.values())

    def select(verts):
        return verts if mask_ids is None else verts[:, mask_ids]

    def closure(backward=True):
        if backward:
            optimizer.zero_grad()
        est_vertices = model_forward().vertices
        loss = vertex_loss(select(est_vertices), select(gt_vertices))
        if not backward:
            return loss
        if create_graph:
            # Use this instead of .backward to avoid GPU memory leaks
            grads = torch.autograd.grad(
                loss, params_to_opt, create_graph=True)
            torch.autograd.backward(params_to_opt, grads, create_graph=True)
        else:
            loss.backward()
        return loss

    return closure
def get_variables(
    batch_size: int,
    body_model: nn.Module,
    dtype: torch.dtype = torch.float32
) -> Dict[str, Tensor]:
    ''' Creates zero-initialized, gradient-enabled fitting variables.

    The variable set depends on ``body_model.name()``:

    * ``'SMPL'``, ``'SMPL+H'``, ``'SMPL-X'``: ``transl`` (B, 3),
      ``global_orient`` (B, 1, 3), ``body_pose`` (B, NUM_BODY_JOINTS, 3)
      and ``betas`` (B, num_betas).
    * ``'SMPL+H'``, ``'SMPL-X'``: additionally ``left_hand_pose`` and
      ``right_hand_pose``, each (B, NUM_HAND_JOINTS, 3).
    * ``'SMPL-X'``: additionally ``jaw_pose``, ``leye_pose``, ``reye_pose``
      (each B, 1, 3) and ``expression`` (B, num_expression_coeffs).

    Any other name yields an empty dict. Tensors are placed on the device of
    the model's first buffer.

    Args:
        batch_size: Leading dimension B of every tensor.
        body_model: Model providing ``name()``, buffers and joint/coefficient
            counts.
        dtype: Dtype of the created tensors.

    Returns:
        Dict from variable name to a leaf tensor with ``requires_grad=True``,
        in the insertion order listed above.
    '''
    device = next(body_model.buffers()).device
    model_name = body_model.name()

    def zeros(*shape):
        return torch.zeros([batch_size, *shape], device=device, dtype=dtype)

    var_dict = {}
    if model_name in ('SMPL', 'SMPL+H', 'SMPL-X'):
        var_dict['transl'] = zeros(3)
        var_dict['global_orient'] = zeros(1, 3)
        var_dict['body_pose'] = zeros(body_model.NUM_BODY_JOINTS, 3)
        var_dict['betas'] = zeros(body_model.num_betas)
    if model_name in ('SMPL+H', 'SMPL-X'):
        var_dict['left_hand_pose'] = zeros(body_model.NUM_HAND_JOINTS, 3)
        var_dict['right_hand_pose'] = zeros(body_model.NUM_HAND_JOINTS, 3)
    if model_name == 'SMPL-X':
        for single_joint in ('jaw_pose', 'leye_pose', 'reye_pose'):
            var_dict[single_joint] = zeros(1, 3)
        var_dict['expression'] = zeros(body_model.num_expression_coeffs)

    # The optimizers update these tensors directly.
    for tensor in var_dict.values():
        tensor.requires_grad_(True)
    return var_dict
def _faces_touching_vertices(faces, vertex_ids):
    ''' Returns the sorted indices of faces that use any vertex in ``vertex_ids``.

    Args:
        faces: Iterable of faces, each an iterable of integer vertex indices.
        vertex_ids: Iterable of vertex indices (Python ints, numpy integers
            or 0-d tensors).

    Returns:
        List of face indices in increasing order, without duplicates.
    '''
    # A set gives O(1) membership tests. This keeps the selection linear in
    # the number of faces, whereas concatenating per-vertex lists with
    # ``sum`` is quadratic.
    selected = {int(vid) for vid in vertex_ids}
    return [face_idx for face_idx, face in enumerate(faces)
            if any(int(vid) in selected for vid in face)]


def run_fitting(
    batch: Dict[str, Tensor],
    body_model: nn.Module,
    def_matrix: Tensor,
    mask_ids,
    summary_steps: int = 100,
    interactive: bool = True,
    optim_cfg: Optional[Dict] = None,
    edge_fitting_cfg: Optional[Dict] = None,
    vertex_fitting_cfg: Optional[Dict] = None,
) -> Dict[str, Tensor]:
    ''' Fits ``body_model`` to deformation-transferred target vertices.

    The targets are ``apply_deformation_transfer(def_matrix, vertices,
    faces)``. Fitting runs in up to three stages, each driven by
    ``minimize``:

    1. An edge loss initializes the pose variables, either one joint at a
       time (``per_part``) or all variables at once.
    2. If a ``'transl'`` variable exists, a vertex loss optimizes it alone.
    3. A vertex loss optimizes every variable jointly.

    Args:
        batch: Dict with ``'vertices'`` (a tensor whose length is the batch
            size) and ``'faces'``.
        body_model: Body model being fitted.
        def_matrix: Deformation-transfer matrix passed to
            ``apply_deformation_transfer``.
        mask_ids: Optional vertex indices. When given, the vertex losses use
            only those vertices and the edge loss uses only the faces that
            touch them.
        summary_steps: Forwarded to ``minimize``.
        interactive: Forwarded to ``minimize``.
        optim_cfg: Configuration for ``build_optimizer`` and ``minimize``.
            Defaults to the ``'trust-ncg'`` setup below.
        edge_fitting_cfg: Extra kwargs for the edge loss. Its ``'per_part'``
            entry (default True if absent) selects the stage-1 mode. Defaults
            to ``{'per_part': False, 'reduction': 'mean'}``.
        vertex_fitting_cfg: Kwargs for the vertex loss (default ``{}``).

    Returns:
        The fitted variables, plus the model outputs ``vertices``, ``joints``,
        ``betas``, ``global_orient``, ``body_pose``, ``left_hand_pose``,
        ``right_hand_pose``, ``full_pose`` (None when the output lacks the
        attribute) and ``faces``. Every tensor is converted to a detached CPU
        numpy array.
    '''
    vertices = batch['vertices']
    faces = batch['faces']
    batch_size = len(vertices)
    dtype, device = vertices.dtype, vertices.device

    # These defaults reproduce the configuration that used to be hard-coded
    # here; callers may override any of them.
    if optim_cfg is None:
        optim_cfg = {
            'type': 'trust-ncg', 'lr': 1.0, 'gtol': 1e-06, 'ftol': -1.0,
            'maxiters': 100,
            'lbfgs': {'line_search_fn': 'strong_wolfe', 'max_iter': 50},
            'sgd': {'momentum': 0.9, 'nesterov': True},
            'adam': {'betas': [0.9, 0.999], 'eps': 1e-08, 'amsgrad': False},
            'trust_ncg': {'max_trust_radius': 1000.0,
                          'initial_trust_radius': 0.05, 'eta': 0.15,
                          'gtol': 1e-05},
        }
    if edge_fitting_cfg is None:
        edge_fitting_cfg = {'per_part': False, 'reduction': 'mean'}
    if vertex_fitting_cfg is None:
        vertex_fitting_cfg = {}

    # Get the parameters from the model
    var_dict = get_variables(batch_size, body_model)

    def_vertices = apply_deformation_transfer(def_matrix, vertices, faces)

    if mask_ids is None:
        f_sel = np.ones_like(body_model.faces[:, 0], dtype=np.bool_)
    else:
        f_sel = _faces_touching_vertices(body_model.faces, mask_ids)
    vpe = get_vertices_per_edge(
        body_model.v_template.detach().cpu().numpy(), body_model.faces[f_sel])

    def log_closure():
        return summary_closure(def_vertices, var_dict, body_model,
                               mask_ids=mask_ids)

    def run_minimize(optimizer_dict, closure, params):
        # All stages share the same logging and stopping configuration.
        minimize(optimizer_dict['optimizer'], closure,
                 params=params,
                 summary_closure=log_closure,
                 summary_steps=summary_steps,
                 interactive=interactive,
                 **optim_cfg)

    edge_loss = build_loss(type='vertex-edge', gt_edges=vpe, est_edges=vpe,
                           **edge_fitting_cfg)
    edge_loss = edge_loss.to(device=device)

    vertex_loss = build_loss(**vertex_fitting_cfg)
    vertex_loss = vertex_loss.to(device=device)

    per_part = edge_fitting_cfg.get('per_part', True)
    logger.info(f'Per-part: {per_part}')

    # Stage 1: optimize the edge-based loss to initialize the pose.
    if per_part:
        for key, var in tqdm(var_dict.items(), desc='Parts'):
            if 'pose' not in key:
                continue
            for jidx in tqdm(range(var.shape[1]), desc='Joints'):
                part = torch.zeros(
                    [batch_size, 3], dtype=dtype, device=device,
                    requires_grad=True)
                # Build the optimizer for the current part
                optimizer_dict = build_optimizer([part], optim_cfg)
                closure = build_edge_closure(
                    body_model, var_dict, edge_loss, optimizer_dict,
                    def_vertices, per_part=per_part, part_key=key, jidx=jidx,
                    part=part)
                run_minimize(optimizer_dict, closure, [part])
                # Write the optimized joint back into its variable.
                with torch.no_grad():
                    var[:, jidx] = part
    else:
        optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg)
        closure = build_edge_closure(
            body_model, var_dict, edge_loss, optimizer_dict,
            def_vertices, per_part=per_part)
        run_minimize(optimizer_dict, closure, list(var_dict.values()))

    # Stage 2: optimize the translation alone. ``get_variables`` stores it
    # under the key 'transl'.
    if 'transl' in var_dict:
        transl = var_dict['transl']
        optimizer_dict = build_optimizer([transl], optim_cfg)
        closure = build_vertex_closure(
            body_model, var_dict,
            optimizer_dict,
            def_vertices,
            vertex_loss=vertex_loss,
            mask_ids=mask_ids,
            per_part=False,
            params_to_opt=[transl],
        )
        run_minimize(optimizer_dict, closure, [transl])

    # Stage 3: optimize all model parameters with the vertex-based loss.
    optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg)
    closure = build_vertex_closure(
        body_model, var_dict,
        optimizer_dict,
        def_vertices,
        vertex_loss=vertex_loss,
        per_part=False,
        mask_ids=mask_ids)
    run_minimize(optimizer_dict, closure, list(var_dict.values()))

    # Final forward pass with the fitted variables. No gradients are needed
    # because every result is detached below.
    with torch.no_grad():
        body_model_output = build_model_forward_closure(
            body_model, var_dict, per_part=False)()

    keys = ["vertices", "joints", "betas", "global_orient", "body_pose",
            "left_hand_pose", "right_hand_pose", "full_pose"]
    for key in keys:
        # Some model outputs (e.g. body-only models) may lack the hand
        # attributes; record None instead of raising AttributeError.
        var_dict[key] = getattr(body_model_output, key, None)
    var_dict['faces'] = body_model.faces

    # Convert tensors to numpy. Non-tensors (faces, None) are kept as-is.
    for key, value in var_dict.items():
        if torch.is_tensor(value):
            var_dict[key] = value.detach().cpu().numpy()
    return var_dict