Spaces:
Runtime error
Runtime error
| import os | |
| import torch.distributed as dist | |
| import torch | |
| import sys | |
| def setup(rank, world_size): | |
| os.environ['MASTER_ADDR'] = 'localhost' | |
| os.environ['MASTER_PORT'] = '1253' | |
| dist.init_process_group("nccl", rank=rank, world_size=world_size) | |
def create_smplx_model(fast_smplx_path,
                       model_path,
                       model_type,
                       gender,
                       ext,
                       batch_size,
                       device):
    """Create a body model via ``smplx.create`` and return it in eval mode.

    Args:
        fast_smplx_path: Directory inserted at the front of ``sys.path``
            before ``smplx`` is imported.
        model_path: Forwarded to ``smplx.create`` as ``model_path``.
        model_type: Forwarded to ``smplx.create`` as ``model_type``.
        gender: Forwarded to ``smplx.create`` as ``gender``.
        ext: Forwarded to ``smplx.create`` as ``ext``.
        batch_size: Forwarded to ``smplx.create`` as ``batch_size``.
        device: Target passed to the created model's ``.to(...)``.

    Returns:
        The object returned by ``smplx.create(...).to(device)``, after
        ``eval()`` has been called on it.
    """
    # Put the caller-supplied directory first on the search path, then import.
    # NOTE(review): presumably this selects a custom smplx build; if smplx is
    # already in sys.modules the cached module is used instead - confirm.
    sys.path.insert(0, fast_smplx_path)
    import smplx

    build_kwargs = {
        'model_path': model_path,
        'model_type': model_type,
        'gender': gender,
        'ext': ext,
        'batch_size': batch_size,
    }
    body_model = smplx.create(**build_kwargs)
    body_model = body_model.to(device)
    body_model.eval()
    return body_model
| def load_checkpoint(model, optimizer, checkpoint_path): | |
| checkpoint = torch.load(checkpoint_path) | |
| model.load_state_dict(checkpoint['model_state_dict']) | |
| optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |