Spaces:
Sleeping
Sleeping
File size: 900 Bytes
326e019 316cf04 326e019 19a4dfc 326e019 bb73124 326e019 bb73124 326e019 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import sys
# sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
# sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
from rdkit import RDLogger
from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
from flare.utils.models import get_model
import yaml
# Quiet RDKit: set its logger threshold to CRITICAL so that lower-severity
# output (warnings and errors) is not emitted.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
# Load the featurizers and the pretrained model
def load_model_components(param_pth='hparams.yaml',
                          checkpoint_pth="pretrained_models/flare.ckpt"):
    """Build the spectrum featurizer, molecule featurizer and model.

    Both paths default to the values that used to be hard-coded here. Relative
    paths are resolved against the current working directory.

    Args:
        param_pth: Path to the YAML hyperparameter file.
        checkpoint_pth: Model checkpoint path. It is stored in the loaded
            hyperparameters as ``params['checkpoint_pth']`` before
            ``get_model`` is called.

    Returns:
        A ``(spec_featurizer, mol_featurizer, model)`` tuple.

    Raises:
        OSError: If ``param_pth`` cannot be opened.
        ValueError: If the YAML file does not contain a mapping (e.g. it is
            empty).
        KeyError: If ``'spectra_view'``, ``'molecule_view'`` or ``'model'``
            is missing from the hyperparameters.
    """
    with open(param_pth, encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept rather than safe_load because it
        # also accepts Python-specific tags (e.g. tuples) that hparams files
        # may contain -- confirm before switching. Only load trusted files.
        params = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(params, dict):
        # An empty file loads as None; fail here with a clear message instead
        # of a TypeError on the first key lookup below.
        raise ValueError(
            f"Expected a mapping of hyperparameters in {param_pth!r}, "
            f"got {type(params).__name__}"
        )

    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # Presumably get_model reads the checkpoint location from params, so it
    # is injected right before the model is built (same order as before, after
    # the featurizers were created from the same dict).
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)
    return spec_featurizer, mol_featurizer, model