File size: 4,451 Bytes
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
2c0063e
42f26af
2c0063e
42f26af
2c0063e
42f26af
 
2c0063e
 
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
0b51da1
42f26af
 
 
 
 
 
 
 
 
2c0063e
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b51da1
42f26af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import datetime

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


from flare.data.data_module import ContrastiveDataModule

from flare.definitions import TEST_RESULTS_DIR
import yaml
from flare.data.datasets import ContrastiveDataset
from functools import partial

from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from flare.utils.models import get_model
# Suppress RDKit warnings and errors (molecule parsing is noisy otherwise)
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# CLI: single option pointing at the YAML file that holds all run parameters
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")

def main(params):
    """Train a contrastive spectra-molecule model with PyTorch Lightning.

    Args:
        params: Configuration dict loaded from YAML. Expected keys include
            'seed', 'debug', 'spectra_view', 'molecule_view', 'spec_enc',
            'model', 'batch_size', 'num_workers', 'experiment_dir',
            'no_wandb', 'run_name', 'project_name', trainer settings
            ('accelerator', 'devices', 'max_epochs', 'log_every_n_steps',
            'val_check_interval'), and 'early_stopping_patience'.

    Side effects:
        Mutates `params` in debug mode, creates checkpoint files under
        params['experiment_dir'], and optionally logs to Weights & Biases.
    """
    # Seed everything for reproducibility
    pl.seed_everything(params['seed'])

    # Debug mode: point at a small sample dataset and skip candidate/split files
    if params['debug']:
        params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Build featurizers and dataset for the configured spectra/molecule views
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)

    # Data module with a collate_fn pre-bound to the encoder/view configuration
    collate_fn = partial(ContrastiveDataset.collate_fn, spec_enc=params['spec_enc'], spectra_view=params['spectra_view'])
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    model = get_model(params['model'], params)

    # Logger: W&B unless explicitly disabled
    if params['no_wandb']:
        logger = None
    else:
        # NOTE(review): 'dir' and 'log_dir' are not documented WandbLogger
        # kwargs; they are forwarded to wandb.init — confirm they are accepted
        # there, otherwise only 'save_dir' is needed.
        logger = pl.loggers.WandbLogger(
            save_dir=params['experiment_dir'],
            dir=params['experiment_dir'],
            log_dir=params['experiment_dir'],
            name=params['run_name'],
            project=params['project_name'],
            log_model=False,
            config=model.hparams,
        )

    # Checkpointing: one unconditional "latest" checkpoint, plus one top-1
    # checkpoint (and optional early stopping) per monitored metric.
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
    for monitor in model.get_checkpoint_monitors():
        monitor_name = monitor['monitor']
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor_name,
            save_top_k=1,
            mode=monitor['mode'],
            dirpath=params['experiment_dir'],
            # Triple braces embed the metric name as a Lightning format field,
            # e.g. '{epoch}-{val_loss:.2f}'
            filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
            auto_insert_metric_name=True,
        )
        callbacks.append(checkpoint)
        if monitor.get('early_stopping', False):
            early_stopping = EarlyStopping(
                monitor=monitor_name,
                mode=monitor['mode'],
                verbose=True,
                patience=params['early_stopping_patience'],
            )
            callbacks.append(early_stopping)

    # Trainer
    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )

    # Prepare data module so we can validate before training starts
    data_module.prepare_data()
    data_module.setup()

    # Baseline validation pass before any weight updates
    trainer.validate(model, datamodule=data_module)

    # Train
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    # When run without a real file context (e.g. inside a notebook,
    # '__file__' undefined), parse no CLI args and rely on defaults.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Date stamp used to namespace this run's experiment directory
    now_formatted = datetime.datetime.now().strftime("%Y%m%d")

    # Load run configuration. safe_load avoids constructing arbitrary Python
    # objects from the YAML file (FullLoader is unnecessary for plain config).
    with open(args.param_pth) as f:
        params = yaml.safe_load(f)

    experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
    params['experiment_dir'] = experiment_dir

    # Default the test-results path into the experiment dir when the key is
    # missing or empty (.get avoids a KeyError on configs without the key).
    if not params.get('df_test_path'):
        params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")

    main(params)