| """ | |
| A helper function to get a default model for quick testing | |
| """ | |
| from omegaconf import open_dict | |
| from hydra import compose, initialize | |
| import torch | |
| from matanyone2.model.matanyone2 import MatAnyone2 | |
def get_matanyone2_model(ckpt_path, device=None) -> MatAnyone2:
    """Build a single-object MatAnyone2 network in eval mode and load its weights.

    The network is configured from the ``eval_matanyone_config`` Hydra config
    in ``../config`` (resolved relative to this file by Hydra), with the
    ``weights`` entry overridden by ``ckpt_path``.

    Args:
        ckpt_path: Path to the checkpoint file; it is read with ``torch.load``.
        device: Target device (e.g. ``"cuda:0"``, ``"cpu"`` or a ``torch.device``).
            When ``None``, the current CUDA device is used (same as ``.cuda()``).

    Returns:
        The MatAnyone2 model moved to ``device``, switched to eval mode, with the
        checkpoint applied via ``load_weights``.
    """
    # Use ``initialize`` as a context manager so the global Hydra state is
    # restored on exit. Calling it bare leaves GlobalHydra initialized, and a
    # second call to this function would then raise ValueError.
    with initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config"):
        cfg = compose(config_name="eval_matanyone_config")
    with open_dict(cfg):
        cfg['weights'] = ckpt_path

    if device is None:
        # ``.to('cuda')`` targets the current CUDA device, exactly like ``.cuda()``.
        device = 'cuda'

    matanyone2 = MatAnyone2(cfg, single_object=True).to(device).eval()
    # Always pass map_location so tensors saved on another device (a different
    # GPU index, or CPU) are not restored there before being loaded into the model.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_weights = torch.load(cfg.weights, map_location=device)
    matanyone2.load_weights(model_weights)
    return matanyone2