File size: 900 Bytes
326e019
316cf04
 
326e019
 
19a4dfc
 
326e019
 
 
 
 
 
 
 
 
 
bb73124
326e019
 
 
 
 
 
 
 
bb73124
326e019
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
# sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
# sys.path.insert(0, "/data/yzhouc01/FILIP-MS")

from rdkit import RDLogger
from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
from flare.utils.models import get_model

import yaml

# Mute RDKit log output: with the threshold raised to CRITICAL, RDKit's
# warning- and error-level messages are no longer emitted.
_RDKIT_LOG_LEVEL = RDLogger.CRITICAL
lg = RDLogger.logger()
lg.setLevel(_RDKIT_LOG_LEVEL)

# Load featurizers and model

def load_model_components(param_pth='hparams.yaml',
                          checkpoint_pth="pretrained_models/flare.ckpt"):
    """Build the spectrum featurizer, molecule featurizer and model from config.

    Reads hyperparameters from a YAML file and builds the featurizers named by
    its ``spectra_view`` and ``molecule_view`` entries. It then stores
    ``checkpoint_pth`` under ``params['checkpoint_pth']`` and builds the model
    named by ``params['model']``.

    Args:
        param_pth: Path to the hyperparameter YAML file. A relative path is
            resolved against the current working directory. Defaults to
            ``'hparams.yaml'`` (the previously hard-coded value).
        checkpoint_pth: Path to the pretrained checkpoint. It is written into
            the params dict handed to ``get_model``. Defaults to
            ``"pretrained_models/flare.ckpt"`` (the previously hard-coded value).

    Returns:
        Tuple ``(spec_featurizer, mol_featurizer, model)``.

    Raises:
        OSError: If ``param_pth`` cannot be opened (e.g. FileNotFoundError).
        TypeError: If the YAML file does not contain a mapping at top level
            (for example, an empty file).
        KeyError: If ``spectra_view``, ``molecule_view`` or ``model`` is
            missing from the hyperparameters.
    """
    with open(param_pth, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct some python-tagged values
        # (e.g. tuples). Keep it only for trusted, locally produced config
        # files. Prefer yaml.safe_load if the file never needs such tags.
        params = yaml.load(f, Loader=yaml.FullLoader)

    # An empty file loads as None, and scalars/lists are not valid configs.
    # Fail with a message naming the file instead of an opaque subscript error.
    if not isinstance(params, dict):
        raise TypeError(
            f"{param_pth} must contain a YAML mapping, "
            f"got {type(params).__name__}"
        )

    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # get_model receives the checkpoint path through the params dict.
    # Presumably it restores weights from it; confirm in flare.utils.models.
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)

    return spec_featurizer, mol_featurizer, model