Spaces:
Sleeping
Sleeping
| import os, sys | |
| import yaml | |
| import torch | |
| from loguru import logger | |
| from configs import constants as _C | |
| from .smpl import SMPL | |
def build_body_model(device, batch_size=1, gender='neutral', **kwargs):
    """Construct an SMPL body model and move it to ``device``.

    Args:
        device: Target device, passed to ``.to()`` on the constructed model.
        batch_size: Forwarded to ``SMPL`` as ``batch_size``.
        gender: Forwarded to ``SMPL`` as ``gender``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The ``SMPL`` instance, built from ``_C.BMODEL.FLDR`` with
        ``create_transl=False``, on ``device``.
    """
    from contextlib import redirect_stdout

    # Silence anything printed to stdout while the model is built (presumably
    # loader chatter from SMPL -- TODO confirm). Unlike reassigning sys.stdout
    # by hand, this closes the devnull handle and restores the *previous*
    # stdout (not sys.__stdout__), even if SMPL(...) or .to() raises.
    with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
        body_model = SMPL(
            model_path=_C.BMODEL.FLDR,
            gender=gender,
            batch_size=batch_size,
            create_transl=False,
        ).to(device)
    return body_model
def build_network(cfg, smpl):
    """Build the ``Network`` model and optionally load pretrained weights.

    Args:
        cfg: Config object. Reads ``cfg.MODEL_CONFIG`` (path to a YAML file
            whose top-level mapping is passed as keyword arguments to
            ``Network``), ``cfg.MODEL.BACKBONE``, ``cfg.DEVICE`` and
            ``cfg.TRAIN.CHECKPOINT``.
        smpl: Body model passed as the first positional argument to
            ``Network``.

    Returns:
        The ``Network`` instance on ``cfg.DEVICE``. If ``cfg.TRAIN.CHECKPOINT``
        names an existing file, the ``'model'`` entry of that checkpoint is
        loaded with ``strict=False`` (minus the ignored ``smpl.*`` keys);
        otherwise a warning is logged and the initial weights are kept.
    """
    # Imported lazily (presumably to avoid an import cycle or a heavy import
    # at module load time -- TODO confirm).
    from .wham import Network

    with open(cfg.MODEL_CONFIG, 'r') as f:
        # An empty YAML file parses to None; treat that as "no extra kwargs".
        model_config = yaml.safe_load(f) or {}
    model_config['d_feat'] = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]
    network = Network(smpl, **model_config).to(cfg.DEVICE)

    # Load Checkpoint
    ckpt_path = cfg.TRAIN.CHECKPOINT
    # Guard an unset path: os.path.isfile(None) raises TypeError.
    if ckpt_path and os.path.isfile(ckpt_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only host;
        # the network already lives on cfg.DEVICE, so load tensors there.
        # NOTE(review): torch>=2.6 defaults to weights_only=True; confirm the
        # checkpoint contains only tensors/primitive containers.
        checkpoint = torch.load(ckpt_path, map_location=cfg.DEVICE)
        # These smpl.* entries are skipped so the freshly built body model's
        # values are kept (presumably they depend on the runtime SMPL setup,
        # e.g. batch size -- TODO confirm).
        ignore_keys = frozenset((
            'smpl.body_pose',
            'smpl.betas',
            'smpl.global_orient',
            'smpl.J_regressor_extra',
            'smpl.J_regressor_eval',
        ))
        model_state_dict = {
            k: v for k, v in checkpoint['model'].items() if k not in ignore_keys
        }
        # strict=False: the ignored keys above (and any others absent from the
        # checkpoint) are left at their initialized values.
        network.load_state_dict(model_state_dict, strict=False)
        logger.info(f"=> loaded checkpoint '{ckpt_path}' ")
    else:
        logger.warning(f"=> Warning! no checkpoint found at '{ckpt_path}'.")
    return network