| | """ |
| | A helper function to get a default model for quick testing |
| | """ |
| | from omegaconf import open_dict |
| | from hydra import compose, initialize |
| |
|
| | import torch |
| | from matanyone.model.matanyone import MatAnyone |
| |
|
| | def get_matanyone_model(ckpt_path, device=None) -> MatAnyone: |
| | initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config") |
| | cfg = compose(config_name="eval_matanyone_config") |
| | |
| | with open_dict(cfg): |
| | cfg['weights'] = ckpt_path |
| |
|
| | |
| | if device is not None: |
| | matanyone = MatAnyone(cfg, single_object=True).to(device).eval() |
| | model_weights = torch.load(cfg.weights, map_location=device) |
| | else: |
| | matanyone = MatAnyone(cfg, single_object=True).cuda().eval() |
| | model_weights = torch.load(cfg.weights) |
| | |
| | matanyone.load_weights(model_weights) |
| |
|
| | return matanyone |
| |
|