"""Load the GenZoo HMR2 model from the Hugging Face Hub."""

import os

import torch
from huggingface_hub import login, snapshot_download

from src.models.hmr2 import HMR2

# Authenticate with the Hub only when a token is actually provided.
# ``login(token=None)`` falls back to an interactive prompt, which blocks or
# fails in non-interactive environments (servers, CI, batch jobs). Without a
# token, public repositories can still be downloaded anonymously.
_HF_TOKEN = os.getenv("hf_token")
if _HF_TOKEN:
    login(token=_HF_TOKEN)


class Namespace(dict):
    """Dictionary whose keys can also be read as attributes.

    Attribute reads go through ``dict.get``, so reading a missing key as an
    attribute (``cfg.MISSING``) returns ``None`` instead of raising
    ``AttributeError``. Attribute *assignment* is not redirected:
    ``cfg.X = 1`` sets an instance attribute and does not add a dict key.
    """

    __getattr__ = dict.get


def _build_model_cfg(checkpoint_path):
    """Return the fixed HMR2 architecture config, recording *checkpoint_path*."""
    return Namespace(
        MODEL=Namespace(
            BACKBONE=Namespace(
                TYPE="vit",
            ),
            BBOX_SHAPE=[192, 256],
            # Standard ImageNet normalization statistics.
            IMAGE_MEAN=[0.485, 0.456, 0.406],
            IMAGE_SIZE=256,
            IMAGE_STD=[0.229, 0.224, 0.225],
            SMPL_HEAD=Namespace(
                IN_CHANNELS=2048,
                TRANSFORMER_DECODER=Namespace(
                    context_dim=1280,
                    depth=6,
                    dim_head=64,
                    dropout=0.0,
                    emb_dropout=0.0,
                    heads=8,
                    mlp_dim=1024,
                    norm="layer",
                ),
                TYPE="transformer_decoder",
            ),
        ),
        SMPL=Namespace(
            NUM_BETAS=145,
            NUM_BODY_JOINTS=34,
        ),
        ckpt_path=checkpoint_path,
    )


def load_hmr2(repo_id="genzoo-org/genzoo", filename="model.ckpt"):
    """Download the checkpoint from the Hub and build an HMR2 model from it.

    Args:
        repo_id: Hugging Face Hub repository to download.
        filename: Checkpoint file inside the downloaded repository snapshot.

    Returns:
        ``(model, model_cfg)``: the ``HMR2`` instance with the checkpoint's
        state dict loaded, and the ``Namespace`` config it was built from.
        This function does not switch the model to eval mode.
    """
    # snapshot_download fetches (or reuses from the local cache) the
    # repository snapshot and returns its local directory.
    local_dir = snapshot_download(repo_id=repo_id)
    checkpoint_path = os.path.join(local_dir, filename)

    model_cfg = _build_model_cfg(checkpoint_path)
    model = HMR2(model_cfg)

    # Map tensors to CPU so a checkpoint saved from GPU tensors still loads on
    # a CPU-only host; load_state_dict copies the values into the model's own
    # parameters wherever those live.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model, model_cfg