File size: 936 Bytes
dcc8c59
 
 
 
 
 
 
77b7c2e
dcc8c59
77b7c2e
dcc8c59
 
 
 
 
 
 
5e2bf3b
77b7c2e
5e2bf3b
 
77b7c2e
5e2bf3b
 
77b7c2e
dcc8c59
77b7c2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
A helper function to get a default model for quick testing
"""
from omegaconf import open_dict
from hydra import compose, initialize

import torch
from matanyone2.model.matanyone2 import MatAnyone2

def get_matanyone2_model(ckpt_path, device=None) -> MatAnyone2:
    """Build a single-object MatAnyone2 network and load checkpoint weights into it.

    The network config is composed by Hydra from ``eval_matanyone_config`` in
    the ``../config`` directory (Hydra resolves this path relative to this
    module). Its ``weights`` entry is then overridden with ``ckpt_path``.

    Args:
        ckpt_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the model is moved to and the checkpoint tensors are
            mapped to. When ``None``, the model is moved with ``.cuda()`` and
            the checkpoint is loaded without a ``map_location``.

    Returns:
        The MatAnyone2 model, in eval mode, with the checkpoint weights loaded.
    """
    # Use initialize() as a context manager so Hydra's global state is
    # restored on exit. A bare initialize() call leaves GlobalHydra
    # initialized, and calling this function a second time would then raise.
    with initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config"):
        cfg = compose(config_name="eval_matanyone_config")

    # open_dict lifts OmegaConf's struct flag so 'weights' can be assigned
    # even if the composed config does not declare that key.
    with open_dict(cfg):
        cfg['weights'] = ckpt_path

    # Load the network weights
    if device is None:
        # Default to CUDA. Without map_location, torch.load restores tensors
        # to the device they were saved from.
        matanyone2 = MatAnyone2(cfg, single_object=True).cuda().eval()
        model_weights = torch.load(cfg.weights)
    else:
        matanyone2 = MatAnyone2(cfg, single_object=True).to(device).eval()
        model_weights = torch.load(cfg.weights, map_location=device)

    matanyone2.load_weights(model_weights)

    return matanyone2