File size: 4,642 Bytes
42f26af
 
 
 
0b51da1
42f26af
 
 
 
 
 
 
2c0063e
 
 
 
42f26af
2c0063e
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
19a4dfc
 
42f26af
 
 
 
 
 
b1aa639
 
42f26af
 
 
 
 
b1aa639
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a4dfc
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a4dfc
 
 
 
 
 
42f26af
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import datetime
import sys
sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
sys.path.insert(0, "/data/yzhouc01/FILIP-MS")

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from massspecgym.models.base import Stage
import os

from flare.data.data_module import TestDataModule
from flare.data.datasets import ContrastiveDataset
from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
from flare.utils.models import get_model

from flare.definitions import TEST_RESULTS_DIR
import yaml
from functools import partial
# Suppress RDKit warnings and errors: raise the RDKit logger threshold to
# CRITICAL so lower-severity messages emitted while molecules are processed
# do not flood the console during testing.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line interface of the test script. Every option is optional; the
# entry point below decides how missing values are filled in.
parser = argparse.ArgumentParser()
# YAML file holding the run parameters.
parser.add_argument(
    "--param_pth",
    type=str,
    default="params_formSpec.yaml",
)
# Explicit checkpoint file; an empty string means "not given".
parser.add_argument(
    "--checkpoint_pth",
    type=str,
    default="",
)
# Substring used to pick a checkpoint file when none is given explicitly.
parser.add_argument(
    "--checkpoint_choice",
    type=str,
    default="train",
    choices=["train", "val"],
)
parser.add_argument(
    "--df_test_pth",
    type=str,
    help="result file name",
)
# Experiment directory to use as-is.
parser.add_argument(
    "--exp_dir",
    type=str,
)
parser.add_argument(
    "--candidates_pth",
    type=str,
)
parser.add_argument(
    "--external_test",
    action="store_true",
    help="whether the test set is external data without labels",
)

def main(params):
    """Run the test stage for the model described by ``params``.

    Args:
        params: Configuration mapping. Keys read directly here are ``seed``,
            ``debug``, ``spectra_view``, ``molecule_view``, ``spec_enc``,
            ``split_pth``, ``batch_size``, ``num_workers``, ``model``,
            ``df_test_path``, ``external_test``, ``accelerator``,
            ``devices`` and ``experiment_dir``. The whole mapping is also
            forwarded to the featurizer, dataset and model factories.
            When ``debug`` is truthy, ``dataset_pth``, ``split_pth`` and
            ``df_test_path`` are overwritten in place.
    """
    # Fix the random seeds before anything else is constructed.
    pl.seed_everything(params['seed'])

    # Debug runs point at a small sample file and a dedicated result file.
    if params['debug']:
        params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
        params['split_pth'] = None
        params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')

    # Featurizers for both views, then the test dataset built from them.
    spectrum_featurizer = get_spec_featurizer(params['spectra_view'], params)
    molecule_featurizer = get_mol_featurizer(params['molecule_view'], params)
    test_dataset = get_test_ms_dataset(
        params['spectra_view'],
        params['molecule_view'],
        spectrum_featurizer,
        molecule_featurizer,
        params,
    )

    # Collation is pinned to the TEST stage.
    batch_collator = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        stage=Stage.TEST,
    )
    datamodule_options = {
        'dataset': test_dataset,
        'collate_fn': batch_collator,
        'split_pth': params['split_pth'],
        'batch_size': params['batch_size'],
        'num_workers': params['num_workers'],
    }
    test_data = TestDataModule(**datamodule_options)

    # Build the model and attach the attributes set for the test run.
    # NOTE(review): presumably the model reads these during its test hooks
    # (result file location, labelled vs. external data) - confirm in the
    # model class.
    net = get_model(params['model'], params)
    net.df_test_path = params['df_test_path']
    net.external_test = params['external_test']

    trainer_options = {
        'accelerator': params['accelerator'],
        'devices': params['devices'],
        'default_root_dir': params['experiment_dir'],
    }
    runner = Trainer(**trainer_options)

    # The data module is prepared and set up explicitly before testing.
    test_data.prepare_data()
    test_data.setup(stage="test")

    runner.test(net, datamodule=test_data)


def resolve_experiment_dir(run_name, results_dir):
    """Find or create the experiment directory for ``run_name``.

    Args:
        run_name: Name of the run. An existing entry of ``results_dir`` whose
            name ends with ``"_" + run_name`` is reused.
        results_dir: Root directory containing the experiment folders
            (string or path-like). It does not have to exist yet.

    Returns:
        The experiment directory as a string. When no existing entry
        matches, ``<results_dir>/<YYYYMMDD>_<run_name>`` is created and
        returned.
    """
    wanted_suffix = "_" + run_name
    # A missing results root just means there is nothing to reuse yet.
    existing = os.listdir(results_dir) if os.path.isdir(results_dir) else []
    for entry in existing:
        if entry.endswith(wanted_suffix):
            return os.path.join(results_dir, entry)

    today = datetime.datetime.now().strftime("%Y%m%d")
    created = os.path.join(results_dir, f"{today}_{run_name}")
    os.makedirs(created, exist_ok=True)
    return created


def find_checkpoint(exp_dir, checkpoint_choice):
    """Return the first matching checkpoint file inside ``exp_dir``.

    A file matches when its name starts with ``"epoch"``, ends with
    ``"ckpt"`` and contains ``checkpoint_choice``.

    Args:
        exp_dir: Directory to scan.
        checkpoint_choice: Substring that must occur in the file name.

    Returns:
        The full path of the first match, or ``''`` when nothing matches.
    """
    for fname in os.listdir(exp_dir):
        if fname.startswith("epoch") and fname.endswith("ckpt") and checkpoint_choice in fname:
            return os.path.join(exp_dir, fname)
    return ''


if __name__ == "__main__":
    # When ``__file__`` is not defined (e.g. an interactive session) parse an
    # empty argument list so the defaults are used instead of ``sys.argv``.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Load the run parameters.
    # NOTE(review): FullLoader can construct more than plain YAML types; only
    # point --param_pth at trusted files (or switch to yaml.safe_load if the
    # parameter files contain plain scalars/lists/dicts only).
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Experiment directory: an explicit --exp_dir wins; otherwise reuse the
    # directory matching the run name or create a dated one. Previously
    # ``exp_dir`` stayed unbound when nothing matched, which raised a
    # NameError before the fallback directory could be created.
    if args.exp_dir:
        exp_dir = args.exp_dir
    else:
        exp_dir = resolve_experiment_dir(params['run_name'], TEST_RESULTS_DIR)
    print("EXPERIMENT directory: ",exp_dir)
    params['experiment_dir'] = exp_dir

    # Checkpoint path: command line first, then the parameter file, then the
    # experiment directory.
    if args.checkpoint_pth:
        params['checkpoint_pth'] = args.checkpoint_pth

    if not params.get('checkpoint_pth'):
        print("No checkpoint provided. Using the checkpoint in the experiment directory")
        params['checkpoint_pth'] = find_checkpoint(exp_dir, args.checkpoint_choice)
    # Explicit error instead of ``assert``: asserts vanish under ``python -O``
    # and the old check let a ``None`` checkpoint through.
    if not params['checkpoint_pth']:
        raise FileNotFoundError(
            f"No checkpoint given and no 'epoch*{args.checkpoint_choice}*ckpt' file found in {exp_dir}"
        )

    # store_true already yields a bool.
    params['external_test'] = args.external_test

    if args.candidates_pth:
        params['candidates_pth'] = args.candidates_pth
    if args.df_test_pth:
        params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
    if not params.get('df_test_path'):
        # Default result file name is derived from the candidates file name
        # (text before its first dot).
        candidates_pth = params.get('candidates_pth')
        if not candidates_pth:
            raise ValueError(
                "Cannot derive a result file name: provide --df_test_pth or a candidates_pth"
            )
        candidates_stem = os.path.basename(candidates_pth).split('.')[0]
        params['df_test_path'] = os.path.join(exp_dir, f"result_{candidates_stem}.pkl")

    main(params)