MatAnyone / matanyone2 /utils /get_default_model.py
PeiqingYang's picture
integrate MatAnyone 1 & 2
77b7c2e
raw
history blame contribute delete
936 Bytes
"""
A helper function to get a default model for quick testing
"""
from omegaconf import open_dict
from hydra import compose, initialize
import torch
from matanyone2.model.matanyone2 import MatAnyone2
def get_matanyone2_model(ckpt_path, device=None) -> MatAnyone2:
    """Build a single-object MatAnyone2 network in eval mode and load its weights.

    The evaluation config ``eval_matanyone_config`` is composed through Hydra
    from the ``../config`` directory, ``ckpt_path`` is stored in it under the
    ``weights`` key, and the checkpoint is then loaded into the model.

    Args:
        ckpt_path: Path of the checkpoint file; written to ``cfg.weights`` and
            read back with ``torch.load``.
        device: Device on which the model and the checkpoint tensors are
            placed. When ``None``, CUDA is used.

    Returns:
        The ``MatAnyone2`` instance after ``load_weights`` has been applied.

    Note:
        Any Hydra global state that is active when this function is called is
        cleared, and no Hydra global state is left initialized afterwards.
    """
    # Imported here so the file's top-level import block does not need to change.
    from hydra.core.global_hydra import GlobalHydra

    # Hydra keeps a single process-wide instance and raises ValueError when
    # `initialize` is called while one is already active (e.g. on a second
    # call to this helper). Drop any previous instance first.
    global_hydra = GlobalHydra.instance()
    if global_hydra.is_initialized():
        global_hydra.clear()

    # Used as a context manager so the global Hydra state is released as soon
    # as the config is composed. `config_path` is relative, so this call must
    # stay in this file.
    with initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config"):
        cfg = compose(config_name="eval_matanyone_config")

    with open_dict(cfg):
        cfg['weights'] = ckpt_path

    # If device is not specified, use CUDA by default.
    target_device = "cuda" if device is None else device

    # Load the network weights
    matanyone2 = MatAnyone2(cfg, single_object=True).to(target_device).eval()
    # `map_location` puts the checkpoint tensors on the model's device even if
    # they were saved from a different one.
    # NOTE(review): torch.load unpickles the file; only pass trusted checkpoints.
    model_weights = torch.load(cfg.weights, map_location=target_device)
    matanyone2.load_weights(model_weights)
    return matanyone2