| | import torch |
| |
|
| |
|
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; pass numpy objects through.

    Args:
        tensor: a torch tensor, or any object whose type is defined in the
            top-level ``numpy`` module (ndarray or numpy scalar).

    Returns:
        A numpy array for torch input (moved to CPU, sharing memory with the
        CPU tensor), otherwise the input object unchanged.

    Raises:
        ValueError: if the input is neither a torch tensor nor a numpy object.
    """
    if torch.is_tensor(tensor):
        # detach() so tensors with requires_grad=True can be converted;
        # Tensor.numpy() refuses them otherwise. detach() does not copy.
        return tensor.detach().cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array".format(
            type(tensor)))
    return tensor
| |
|
| |
|
def to_torch(ndarray):
    """Convert a numpy array (or numpy scalar) to a torch tensor.

    Torch tensors are returned unchanged. A numpy ndarray is converted with
    ``torch.from_numpy``, so the result shares memory with it. The exception
    is an array with negative strides (e.g. ``arr[::-1]``), which
    ``from_numpy`` cannot represent; such an array is first copied into
    contiguous memory. Numpy scalars (e.g. ``np.float32(1.0)``) become 0-d
    tensors.

    Args:
        ndarray: numpy ndarray, numpy scalar, or torch tensor.

    Returns:
        torch.Tensor.

    Raises:
        ValueError: if the input is neither a numpy object nor a torch tensor.
    """
    if type(ndarray).__module__ == 'numpy':
        import numpy as np

        # Numpy scalar types also live in the 'numpy' module, but
        # torch.from_numpy only accepts ndarray. asarray turns scalars into
        # 0-d arrays and returns real ndarrays as-is (no copy).
        array = np.asarray(ndarray)
        if any(stride < 0 for stride in array.strides):
            array = np.ascontiguousarray(array)
        return torch.from_numpy(array)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor".format(
            type(ndarray)))
    return ndarray
| |
|
| |
|
def cleanexit():
    """Terminate the process immediately with exit status 0.

    The original ``sys.exit(0)`` / ``except SystemExit`` dance always ended in
    ``os._exit(0)``, because ``sys.exit`` unconditionally raises SystemExit.
    This version calls ``os._exit`` directly. ``os._exit`` skips interpreter
    cleanup (atexit handlers, ``finally`` blocks) and does not flush stdio
    buffers. The standard streams are therefore flushed first, best-effort,
    so already-printed output is not lost when stdout is piped or redirected.
    """
    import os
    import sys

    for stream in (sys.stdout, sys.stderr):
        if stream is None:
            continue
        try:
            stream.flush()
        except (OSError, ValueError):
            # Broken pipe or already-closed stream: flushing is best-effort,
            # the process must still exit.
            pass
    os._exit(0)
| |
|
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating only missing CLIP weights.

    The load is non-strict (``strict=False``). Afterwards, every key the model
    expected but did not receive must start with ``'clip_model.'``, and the
    checkpoint must contain no keys the model does not know. Presumably the
    CLIP sub-module's weights are kept out of the checkpoint and loaded
    elsewhere (verify against callers).

    Checks use explicit ``raise`` rather than ``assert`` so they still run
    under ``python -O``. AssertionError is kept as the exception type so any
    caller catching it keeps working.

    Args:
        model: a ``torch.nn.Module``.
        state_dict: mapping of parameter/buffer names to tensors.

    Raises:
        AssertionError: if the state dict has unexpected keys, or the model is
            missing keys outside ``clip_model.``. Parameters matching the
            state dict have already been loaded when this is raised.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if unexpected_keys:
        raise AssertionError(
            "Unexpected keys in state_dict: {}".format(list(unexpected_keys)))
    non_clip_missing = [k for k in missing_keys if not k.startswith('clip_model.')]
    if non_clip_missing:
        raise AssertionError(
            "Missing non-CLIP keys in state_dict: {}".format(non_clip_missing))
| |
|
def freeze_joints(x, joints_to_freeze):
    """Return a detached copy of ``x`` with selected joints held at index 0.

    For every index selected by ``joints_to_freeze`` along dim 1, all entries
    along the last dim are replaced by the entry at position 0 of the last
    dim. Assumes ``x`` is 4-D, presumably (batch, joints, features, frames),
    so selected joints are frozen to their first frame (TODO confirm the
    layout with callers).

    Args:
        x: 4-D torch tensor. It is not modified.
        joints_to_freeze: any index valid for dim 1 (int, slice, list or
            index tensor).

    Returns:
        A new tensor, detached from the autograd graph, with the same shape
        as ``x``.
    """
    source = x.detach()
    frozen = source.clone()
    # Read the frozen values from the untouched source, not from `frozen`.
    # With int/slice indices both sides of the assignment would otherwise be
    # views into the same storage (source contained in destination), which
    # PyTorch's in-place overlap check can reject. The values are identical
    # because `frozen` is still an unmodified clone at this point.
    frozen[:, joints_to_freeze, :, :] = source[:, joints_to_freeze, :, :1]
    return frozen
| |
|