Spaces:
Running
Running
| import os | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from src.models.hmr2 import HMR2 | |
| from huggingface_hub import login | |
# Authenticate with the Hugging Face Hub using the token stored in the
# `hf_token` environment variable (e.g. a Space secret). The call is skipped
# when the variable is unset or empty: `login(token=None)` falls back to an
# interactive prompt, which would block or fail in a headless process at
# import time. Public repos still download fine without an explicit login.
_hf_token = os.getenv('hf_token')
if _hf_token:
    login(token=_hf_token)
class Namespace(dict):
    """A ``dict`` whose keys can also be read as attributes.

    ``ns.key`` behaves like ``ns.get("key")``: a name that is neither a real
    attribute nor a key evaluates to ``None`` instead of raising
    ``AttributeError`` (so ``hasattr`` is always true). Attribute assignment
    is not redirected into the dict; only reads are.
    """

    def __getattr__(self, name):
        # Python only calls this after normal attribute lookup has failed, so
        # real dict methods (keys, items, get, ...) are never shadowed by keys.
        return dict.get(self, name)
def load_hmr2(repo_id="genzoo-org/genzoo", filename="model.ckpt"):
    """Download the HMR2 checkpoint from the Hugging Face Hub and build the model.

    A snapshot of ``repo_id`` is fetched with ``snapshot_download`` (reusing the
    local cache when present), and ``filename`` inside that snapshot is loaded
    as the model's state dict.

    Args:
        repo_id: Hub repository that contains the checkpoint.
        filename: Path of the checkpoint file relative to the repository root.

    Returns:
        A ``(model, model_cfg)`` tuple: the ``HMR2`` instance with the
        checkpoint weights loaded, and the ``Namespace`` config it was built
        from.
    """
    local_dir = snapshot_download(repo_id=repo_id)
    checkpoint_path = os.path.join(local_dir, filename)

    # Architecture hyper-parameters. Presumably these must match the values the
    # checkpoint was trained with (load_state_dict below runs with its default
    # strict=True, so mismatched shapes/keys raise) — confirm before changing.
    model_cfg = Namespace(
        MODEL=Namespace(
            BACKBONE=Namespace(
                TYPE="vit",
            ),
            BBOX_SHAPE=[192, 256],
            IMAGE_MEAN=[0.485, 0.456, 0.406],
            IMAGE_SIZE=256,
            IMAGE_STD=[0.229, 0.224, 0.225],
            SMPL_HEAD=Namespace(
                IN_CHANNELS=2048,
                TRANSFORMER_DECODER=Namespace(
                    context_dim=1280,
                    depth=6,
                    dim_head=64,
                    dropout=0.0,
                    emb_dropout=0.0,
                    heads=8,
                    mlp_dim=1024,
                    norm="layer",
                ),
                TYPE="transformer_decoder",
            ),
        ),
        SMPL=Namespace(
            NUM_BETAS=145,
            NUM_BODY_JOINTS=34,
        ),
        ckpt_path=checkpoint_path,
    )
    model = HMR2(model_cfg)

    # Deserialize onto the CPU: a checkpoint saved with CUDA tensors cannot be
    # restored by torch.load on a machine without CUDA. load_state_dict copies
    # the values into the model's own parameters, so the loaded model is the
    # same either way.
    # NOTE(review): torch.load unpickles arbitrary Python objects and this file
    # is downloaded from the Hub; consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model, model_cfg