Spaces:
Runtime error
Runtime error
| # -------------------------------------------------------- | |
| # Based on the 4DHumans code base | |
| # https://github.com/shubham-goel/4D-Humans | |
| # -------------------------------------------------------- | |
| import torch | |
| from typing import Any, Dict, Mapping, Tuple | |
| from yacs.config import CfgNode | |
| from ..utils import SkeletonRenderer, MeshRenderer | |
| from ..utils.geometry import perspective_projection | |
| from .backbones import create_backbone | |
| from .heads import build_smpl_head | |
| from . import SMPL | |
class HMR2(torch.nn.Module):
    """HMR2 human mesh regressor, reduced to the inference path.

    Pipeline (see ``forward_step``): backbone features -> SMPL head
    (SMPL parameters + camera) -> SMPL layer (vertices, 3D joints) ->
    perspective projection of the 3D joints.

    NOTE(review): this class derives from plain ``torch.nn.Module`` rather than
    a PyTorch Lightning ``LightningModule``, so Lightning-only APIs (such as
    ``save_hyperparameters``) are not available on ``self``.
    """

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup HMR2 model.

        Args:
            cfg (CfgNode): Config file as a yacs CfgNode. Used to build the
                backbone (optionally loading ``MODEL.BACKBONE.PRETRAINED_WEIGHTS``),
                the SMPL head, the SMPL layer (from ``cfg.SMPL``) and the renderers.
            init_renderer (bool): If True, build ``self.renderer``
                (SkeletonRenderer) and ``self.mesh_renderer`` (MeshRenderer);
                otherwise both attributes are set to None.
        """
        super().__init__()
        # The original code called self.save_hyperparameters(logger=False,
        # ignore=['init_renderer']) here. That is a LightningModule method, not
        # a torch.nn.Module one, so on this class it raised AttributeError and
        # construction always failed. The config is kept on self.cfg instead.
        self.cfg = cfg

        # Create backbone feature extractor, optionally with pretrained weights.
        self.backbone = create_backbone(cfg)
        pretrained_weights = cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None)
        if pretrained_weights:
            # torch.load unpickles the file: only point this at trusted checkpoints.
            # The checkpoint must hold the backbone weights under 'state_dict'.
            checkpoint = torch.load(pretrained_weights, map_location='cpu')
            self.backbone.load_state_dict(checkpoint['state_dict'])

        # Create SMPL head
        self.smpl_head = build_smpl_head(cfg)

        # Instantiate SMPL model; cfg.SMPL keys are lower-cased into keyword arguments.
        smpl_cfg = {k.lower(): v for k, v in dict(cfg.SMPL).items()}
        self.smpl = SMPL(**smpl_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers.
        # It is not read anywhere in this class, but as a (persistent) buffer it
        # is part of state_dict(), so it is kept for checkpoint compatibility.
        self.register_buffer('initialized', torch.tensor(False))

        # Setup renderers for visualization
        if init_renderer:
            self.renderer = SkeletonRenderer(self.cfg)
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smpl.faces)
        else:
            self.renderer = None
            self.mesh_renderer = None

        # Lightning flag carried over from the adversarial training code this
        # is based on; a plain torch.nn.Module does not read it.
        self.automatic_optimization = False

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network.

        Args:
            batch (Dict): Dictionary containing batch data; only ``batch['img']``
                is read.
            train (bool): Flag indicating whether it is training or validation
                mode. Not used by this implementation; kept for interface
                compatibility.
        Returns:
            Dict: Regression outputs with keys 'pred_cam', 'pred_smpl_params'
                (clones of the head output, taken before any reshaping),
                'pred_cam_t', 'focal_length', 'pred_keypoints_3d',
                'pred_vertices' and 'pred_keypoints_2d'.
        """
        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone.
        # The ViT backbone needs a different aspect ratio, so 32 entries are
        # dropped from each end of the last axis (assumes NCHW, i.e. this trims
        # the width -- TODO confirm against the backbone's expected input size).
        conditioning_feats = self.backbone(x[:, :, :, 32:-32])
        pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)

        # Store useful regression outputs to the output dict
        output = {}
        output['pred_cam'] = pred_cam
        output['pred_smpl_params'] = {k: v.clone() for k, v in pred_smpl_params.items()}

        # Compute camera translation. pred_cam columns are used as (s, tx, ty):
        # t = (tx, ty, 2 * f / (IMAGE_SIZE * s)); the 1e-9 avoids division by zero.
        device = pred_smpl_params['body_pose'].device
        dtype = pred_smpl_params['body_pose'].dtype
        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
        pred_cam_t = torch.stack([pred_cam[:, 1],
                                  pred_cam[:, 2],
                                  2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)],
                                 dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints.
        # Rotations are reshaped to (batch, -1, 3, 3) matrices, hence pose2rot=False.
        pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
        smpl_output = self.smpl(**{k: v.float() for k, v in pred_smpl_params.items()}, pose2rot=False)
        pred_keypoints_3d = smpl_output.joints
        pred_vertices = smpl_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)
        # Focal length is divided by IMAGE_SIZE, presumably so the 2D keypoints
        # come out in crop-normalized coordinates -- confirm in perspective_projection.
        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
                                                   translation=pred_cam_t,
                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
        return output

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode.

        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output (see ``forward_step``)
        """
        return self.forward_step(batch, train=False)